From 0d12ca98548524a2f1da5f4a7c5f4f0831732483 Mon Sep 17 00:00:00 2001 From: fengmengqi Date: Thu, 2 Apr 2026 15:14:31 +0800 Subject: [PATCH] =?UTF-8?q?=20feat:=20Claw=20Code=20=E9=87=8D=E5=91=BD?= =?UTF-8?q?=E5=90=8D=E5=8F=8A=E6=9E=B6=E6=9E=84=E5=8D=87=E7=BA=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 品牌重塑: - Claude Code → Claw Code - .claude → .claw 配置目录 - CLAUDE_* → CLAW_* 环境变量 新增功能: - 多 Provider 架构 (ClawApi/Xai/OpenAI) - 插件系统 (生命周期/钩子/工具扩展) - LSP 集成 (诊断/代码智能) - Hook 系统 (PreToolUse/PostToolUse) - 独立 CLI (claw-cli) - HTTP Server (Axum/SSE) - Slash Commands 扩展 (branch/worktree/commit/pr/plugin等) 优化改进: - Compaction 支持增量压缩 - 全局工具注册表 - 配置文件统一为 .claw.json --- .claude/sessions/session-1775007453382.json | 1 - .claude/sessions/session-1775007484031.json | 1 - .claude/sessions/session-1775007490104.json | 1 - .claude/sessions/session-1775007981374.json | 1 - .claude/sessions/session-1775008007069.json | 1 - .claude/sessions/session-1775008071886.json | 1 - .github/workflows/ci.yml | 36 + CONTRIBUTING.md | 43 + Cargo.lock | 394 +- Cargo.toml | 4 + README.md | 256 +- crates/api/Cargo.toml | 2 +- crates/api/src/client.rs | 1029 +--- crates/api/src/error.rs | 43 +- crates/api/src/lib.rs | 10 +- crates/api/src/providers/claw_provider.rs | 1046 ++++ crates/api/src/providers/mod.rs | 239 + crates/api/src/providers/openai_compat.rs | 1050 ++++ crates/api/src/sse.rs | 60 + crates/api/src/types.rs | 11 + crates/api/tests/client_integration.rs | 74 +- crates/api/tests/openai_compat_integration.rs | 415 ++ .../api/tests/provider_client_integration.rs | 86 + crates/claw-cli/Cargo.toml | 27 + crates/claw-cli/src/app.rs | 402 ++ crates/claw-cli/src/args.rs | 104 + crates/claw-cli/src/init.rs | 432 ++ crates/claw-cli/src/input.rs | 1195 ++++ crates/claw-cli/src/main.rs | 5090 +++++++++++++++++ crates/claw-cli/src/render.rs | 797 +++ crates/commands/Cargo.toml | 2 + crates/commands/src/lib.rs | 2245 +++++++- crates/lsp/Cargo.toml | 16 + 
crates/lsp/src/client.rs | 463 ++ crates/lsp/src/error.rs | 62 + crates/lsp/src/lib.rs | 283 + crates/lsp/src/manager.rs | 191 + crates/lsp/src/types.rs | 186 + crates/plugins/Cargo.toml | 13 + .../example-bundled/.claw-plugin/plugin.json | 10 + .../bundled/example-bundled/hooks/post.sh | 2 + .../bundled/example-bundled/hooks/pre.sh | 2 + .../sample-hooks/.claw-plugin/plugin.json | 10 + .../bundled/sample-hooks/hooks/post.sh | 2 + .../plugins/bundled/sample-hooks/hooks/pre.sh | 2 + crates/plugins/src/hooks.rs | 395 ++ crates/plugins/src/lib.rs | 2943 ++++++++++ crates/runtime/Cargo.toml | 4 +- crates/runtime/src/bootstrap.rs | 2 +- crates/runtime/src/compact.rs | 237 +- crates/runtime/src/config.rs | 381 +- crates/runtime/src/conversation.rs | 239 +- crates/runtime/src/file_ops.rs | 2 +- crates/runtime/src/hooks.rs | 357 ++ crates/runtime/src/lib.rs | 14 +- crates/runtime/src/mcp.rs | 4 +- crates/runtime/src/mcp_client.rs | 14 +- crates/runtime/src/mcp_stdio.rs | 31 +- crates/runtime/src/oauth.rs | 10 +- crates/runtime/src/prompt.rs | 102 +- crates/runtime/src/remote.rs | 14 +- crates/runtime/src/session.rs | 12 +- crates/runtime/src/usage.rs | 9 +- crates/server/Cargo.toml | 20 + crates/server/src/lib.rs | 442 ++ crates/tools/Cargo.toml | 5 +- crates/tools/src/lib.rs | 1120 +++- docs/releases/0.1.0.md | 51 + 68 files changed, 21305 insertions(+), 1443 deletions(-) delete mode 100644 .claude/sessions/session-1775007453382.json delete mode 100644 .claude/sessions/session-1775007484031.json delete mode 100644 .claude/sessions/session-1775007490104.json delete mode 100644 .claude/sessions/session-1775007981374.json delete mode 100644 .claude/sessions/session-1775008007069.json delete mode 100644 .claude/sessions/session-1775008071886.json create mode 100644 .github/workflows/ci.yml create mode 100644 CONTRIBUTING.md create mode 100644 crates/api/src/providers/claw_provider.rs create mode 100644 crates/api/src/providers/mod.rs create mode 100644 
crates/api/src/providers/openai_compat.rs create mode 100644 crates/api/tests/openai_compat_integration.rs create mode 100644 crates/api/tests/provider_client_integration.rs create mode 100644 crates/claw-cli/Cargo.toml create mode 100644 crates/claw-cli/src/app.rs create mode 100644 crates/claw-cli/src/args.rs create mode 100644 crates/claw-cli/src/init.rs create mode 100644 crates/claw-cli/src/input.rs create mode 100644 crates/claw-cli/src/main.rs create mode 100644 crates/claw-cli/src/render.rs create mode 100644 crates/lsp/Cargo.toml create mode 100644 crates/lsp/src/client.rs create mode 100644 crates/lsp/src/error.rs create mode 100644 crates/lsp/src/lib.rs create mode 100644 crates/lsp/src/manager.rs create mode 100644 crates/lsp/src/types.rs create mode 100644 crates/plugins/Cargo.toml create mode 100644 crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json create mode 100644 crates/plugins/bundled/example-bundled/hooks/post.sh create mode 100644 crates/plugins/bundled/example-bundled/hooks/pre.sh create mode 100644 crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json create mode 100644 crates/plugins/bundled/sample-hooks/hooks/post.sh create mode 100644 crates/plugins/bundled/sample-hooks/hooks/pre.sh create mode 100644 crates/plugins/src/hooks.rs create mode 100644 crates/plugins/src/lib.rs create mode 100644 crates/runtime/src/hooks.rs create mode 100644 crates/server/Cargo.toml create mode 100644 crates/server/src/lib.rs create mode 100644 docs/releases/0.1.0.md diff --git a/.claude/sessions/session-1775007453382.json b/.claude/sessions/session-1775007453382.json deleted file mode 100644 index d45e491..0000000 --- a/.claude/sessions/session-1775007453382.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/.claude/sessions/session-1775007484031.json b/.claude/sessions/session-1775007484031.json deleted file mode 100644 index d45e491..0000000 --- 
a/.claude/sessions/session-1775007484031.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/.claude/sessions/session-1775007490104.json b/.claude/sessions/session-1775007490104.json deleted file mode 100644 index d45e491..0000000 --- a/.claude/sessions/session-1775007490104.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/.claude/sessions/session-1775007981374.json b/.claude/sessions/session-1775007981374.json deleted file mode 100644 index d45e491..0000000 --- a/.claude/sessions/session-1775007981374.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/.claude/sessions/session-1775008007069.json b/.claude/sessions/session-1775008007069.json deleted file mode 100644 index d45e491..0000000 --- a/.claude/sessions/session-1775008007069.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/.claude/sessions/session-1775008071886.json b/.claude/sessions/session-1775008071886.json deleted file mode 100644 index d45e491..0000000 --- a/.claude/sessions/session-1775008071886.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..73459b8 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,36 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + rust: + name: ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + - macos-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Run cargo check + run: cargo check --workspace + + - name: Run cargo test + run: cargo test --workspace + + - name: Run release build + run: cargo build --release 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..759fb9e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,43 @@ +# 贡献指南 + +感谢你为 Claw Code 做出贡献。 + +## 开发设置 + +- 安装稳定的 Rust 工具链。 +- 在此 Rust 工作区的仓库根目录下进行开发。如果你从父仓库根目录开始,请先执行 `cd rust/`。 + +## 构建 + +```bash +cargo build +cargo build --release +``` + +## 测试与验证 + +在开启 Pull Request 之前,请运行完整的 Rust 验证集: + +```bash +cargo fmt --all --check +cargo clippy --workspace --all-targets -- -D warnings +cargo check --workspace +cargo test --workspace +``` + +如果你更改了行为,请在同一个 Pull Request 中添加或更新相关的测试。 + +## 代码风格 + +- 遵循所修改 crate 中的现有模式,而不是引入新的风格。 +- 使用 `rustfmt` 格式化代码。 +- 确保你修改的工作区目标的 `clippy` 检查通过。 +- 优先采用针对性的 diff,而不是顺便进行的重构。 + +## Pull Request + +- 从 `main` 分支拉取新分支。 +- 确保每个 Pull Request 的范围仅限于一个明确的更改。 +- 说明更改动机、实现摘要以及你运行的验证。 +- 在请求审查之前,确保本地检查已通过。 +- 如果审查反馈导致行为更改,请重新运行相关的验证命令。 diff --git a/Cargo.lock b/Cargo.lock index 9030127..443b79d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -28,12 +28,86 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + 
"form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.22.1" @@ -49,6 +123,12 @@ dependencies = [ "serde", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.11.0" @@ -98,11 +178,40 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "claw-cli" +version = "0.1.0" +dependencies = [ + "api", + "commands", + "compat-harness", + "crossterm", + "plugins", + "pulldown-cmark", + "runtime", + "rustyline", + "serde_json", + "syntect", + "tokio", + "tools", +] + +[[package]] +name = "clipboard-win" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" +dependencies = [ + "error-code", +] + [[package]] name = "commands" version = "0.1.0" dependencies = [ + "plugins", "runtime", + "serde_json", ] [[package]] @@ -138,11 +247,11 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags", + "bitflags 2.11.0", "crossterm_winapi", "mio", "parking_lot", - "rustix", + "rustix 0.38.44", "signal-hook", "signal-hook-mio", "winapi", @@ -197,6 +306,12 @@ dependencies = [ "syn", ] +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + [[package]] name = "equivalent" version = "1.0.2" @@ -213,6 +328,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "error-code" +version = "3.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" + +[[package]] +name = "fd-lock" +version = "4.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" +dependencies = [ + "cfg-if", + "rustix 1.1.4", + "windows-sys 0.59.0", +] + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -229,6 +361,15 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "fluent-uri" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17c704e9dbe1ddd863da1e6ff3567795087b1eb201ce80d8fa81162e1516500d" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "fnv" version = "1.0.7" @@ -266,6 +407,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -286,6 +438,7 @@ 
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -351,6 +504,15 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "http" version = "1.4.0" @@ -390,6 +552,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.9.0" @@ -403,6 +571,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -614,6 +783,12 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "litemap" version = "0.8.1" @@ -641,12 +816,48 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lsp" +version = "0.1.0" +dependencies = [ + "lsp-types", + "serde", + "serde_json", + "tokio", + "url", +] + +[[package]] +name = "lsp-types" +version = 
"0.97.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53353550a17c04ac46c585feb189c2db82154fc84b79c7a66c96c2c644f66071" +dependencies = [ + "bitflags 1.3.2", + "fluent-uri", + "serde", + "serde_json", + "serde_repr", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -669,6 +880,27 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "num-conv" version = "0.2.1" @@ -687,7 +919,7 @@ version = "6.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" dependencies = [ - "bitflags", + "bitflags 2.11.0", "libc", "once_cell", "onig_sys", @@ -757,6 +989,14 @@ dependencies = [ "time", ] +[[package]] +name = "plugins" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -796,7 +1036,7 @@ 
version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" dependencies = [ - "bitflags", + "bitflags 2.11.0", "getopts", "memchr", "pulldown-cmark-escape", @@ -888,6 +1128,16 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + [[package]] name = "rand" version = "0.9.2" @@ -923,7 +1173,7 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags", + "bitflags 2.11.0", ] [[package]] @@ -985,12 +1235,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", ] @@ -1014,6 +1266,8 @@ name = "runtime" version = "0.1.0" dependencies = [ "glob", + "lsp", + "plugins", "regex", "serde", "serde_json", @@ -1034,11 +1288,24 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags", + "bitflags 2.11.0", "errno", "libc", - "linux-raw-sys", - "windows-sys 0.52.0", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys 0.12.1", + 
"windows-sys 0.61.2", ] [[package]] @@ -1098,6 +1365,28 @@ dependencies = [ "tools", ] +[[package]] +name = "rustyline" +version = "15.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "clipboard-win", + "fd-lock", + "home", + "libc", + "log", + "memchr", + "nix", + "radix_trie", + "unicode-segmentation", + "unicode-width", + "utf8parse", + "windows-sys 0.59.0", +] + [[package]] name = "ryu" version = "1.0.23" @@ -1162,6 +1451,28 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1174,6 +1485,19 @@ dependencies = [ "serde", ] +[[package]] +name = "server" +version = "0.1.0" +dependencies = [ + "async-stream", + "axum", + "reqwest", + "runtime", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "sha2" version = "0.10.9" @@ -1427,14 +1751,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tools" version = "0.1.0" dependencies = [ + "api", + "plugins", "reqwest", "runtime", "serde", "serde_json", + "tokio", ] [[package]] @@ -1450,6 +1790,7 @@ 
dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -1458,7 +1799,7 @@ version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ - "bitflags", + "bitflags 2.11.0", "bytes", "futures-util", "http", @@ -1488,6 +1829,7 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-core", ] @@ -1525,6 +1867,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-segmentation" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" + [[package]] name = "unicode-width" version = "0.2.2" @@ -1555,6 +1903,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "version_check" version = "0.9.5" @@ -1650,6 +2004,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.93" @@ -1725,6 +2092,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = 
"windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" diff --git a/Cargo.toml b/Cargo.toml index 4a2f4d4..aa2f4ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,10 @@ edition = "2021" license = "MIT" publish = false +[workspace.dependencies] +lsp-types = "0.97" +serde_json = "1" + [workspace.lints.rust] unsafe_code = "forbid" diff --git a/README.md b/README.md index 26a0d4a..fb4ef6b 100644 --- a/README.md +++ b/README.md @@ -1,230 +1,122 @@ -# Rusty Claude CLI +# Claw Code -`rust/` contains the Rust workspace for the integrated `rusty-claude-cli` deliverable. -It is intended to be something you can clone, build, and run directly. +Claw Code 是一个使用安全 Rust 实现的本地编程代理(coding-agent)命令行工具。它的设计灵感来自 **Claude Code**,并作为一个**净室实现(clean-room implementation)**开发:旨在提供强大的本地代理体验,但它**不是** Claude Code 的直接移植或复制。 -## Workspace layout +Rust 工作区是当前主要的产品界面。`claw` 二进制文件在单个工作区内提供交互式会话、单次提示、工作区感知工具、本地代理工作流以及支持插件的操作。 -```text -rust/ -├── Cargo.toml -├── Cargo.lock -├── README.md -└── crates/ - ├── api/ # Anthropic API client + SSE streaming support - ├── commands/ # Shared slash-command metadata/help surfaces - ├── compat-harness/ # Upstream TS manifest extraction harness - ├── runtime/ # Session/runtime/config/prompt orchestration - ├── rusty-claude-cli/ # Main CLI binary - └── tools/ # Built-in tool implementations -``` +## 当前状态 -## Prerequisites +- **版本:** `0.1.0` +- **发布阶段:** 初始公开发布,源码编译分发 +- **主要实现:** 本仓库中的 Rust 工作区 +- **平台焦点:** macOS 和 Linux 开发工作站 -- Rust toolchain installed (`rustup`, stable toolchain) -- Network access and Anthropic credentials for live prompt/REPL usage +## 安装、构建与运行 -## Build +### 准备工作 -From the repository root: +- Rust 稳定版工具链 +- Cargo +- 你想使用的模型的提供商凭据 + +### 身份验证 + +兼容 Anthropic 的模型: ```bash -cd rust -cargo build --release -p 
rusty-claude-cli +export ANTHROPIC_API_KEY="..." +# 使用兼容的端点时可选 +export ANTHROPIC_BASE_URL="https://api.anthropic.com" ``` -The optimized binary will be written to: +Grok 模型: ```bash -./target/release/rusty-claude-cli +export XAI_API_KEY="..." +# 使用兼容的端点时可选 +export XAI_BASE_URL="https://api.x.ai" ``` -## Test - -Run the verified workspace test suite used for release-readiness: +也可以使用 OAuth 登录: ```bash -cd rust -cargo test --workspace --exclude compat-harness +cargo run --bin claw -- login ``` -## Quick start - -### Show help +### 本地安装 ```bash -cd rust -cargo run -p rusty-claude-cli -- --help +cargo install --path crates/claw-cli --locked ``` -### Print version +### 从源码构建 ```bash -cd rust -cargo run -p rusty-claude-cli -- --version +cargo build --release -p claw-cli ``` -### Login with OAuth +### 运行 -Configure `settings.json` with an `oauth` block containing `clientId`, `authorizeUrl`, `tokenUrl`, optional `callbackPort`, and optional `scopes`, then run: +在工作区内运行: ```bash -cd rust -cargo run -p rusty-claude-cli -- login +cargo run --bin claw -- --help +cargo run --bin claw -- +cargo run --bin claw -- prompt "总结此工作区" +cargo run --bin claw -- --model sonnet "审查最新更改" ``` -This opens the browser, listens on the configured localhost callback, exchanges the auth code for tokens, and stores OAuth credentials in `~/.claude/credentials.json` (or `$CLAUDE_CONFIG_HOME/credentials.json`). - -### Logout +运行发布版本: ```bash -cd rust -cargo run -p rusty-claude-cli -- logout +./target/release/claw +./target/release/claw prompt "解释 crates/runtime" ``` -This removes only the stored OAuth credentials and preserves unrelated JSON fields in `credentials.json`. 
+## 支持的功能 -### Self-update +- 交互式 REPL 和单次提示执行 +- 已保存会话的检查和恢复流程 +- 内置工作区工具:shell、文件读/写/编辑、搜索、网页获取/搜索、待办事项和笔记本更新 +- 斜杠命令:状态、压缩、配置检查、差异(diff)、导出、会话管理和版本报告 +- 本地代理和技能发现:通过 `claw agents` 和 `claw skills` +- 通过命令行和斜杠命令界面发现并管理插件 +- OAuth 登录/注销,以及从命令行选择模型/提供商 +- 工作区感知的指令/配置加载(`CLAW.md`、配置文件、权限、插件设置) -```bash -cd rust -cargo run -p rusty-claude-cli -- self-update -``` +## 当前限制 -The command checks the latest GitHub release for `instructkr/clawd-code`, compares it to the current binary version, downloads the matching binary asset plus checksum manifest, verifies SHA-256, replaces the current executable, and prints the release changelog. If no published release or matching asset exists, it exits safely with an explanatory message. +- 目前公开发布**仅限源码构建**;此工作区尚未设置 crates.io 发布 +- GitHub CI 验证 `cargo check`、`cargo test` 和发布构建,但尚未提供自动化的发布打包 +- 当前 CI 目标为 Ubuntu 和 macOS;Windows 的发布就绪性仍待建立 +- 一些实时提供商集成覆盖是可选的,因为它们需要外部凭据和网络访问 +- 命令界面可能会在 `0.x` 系列期间继续演进 -## Usage examples +## 实现现状 -### 1) Prompt mode +Rust 工作区是当前的产品实现。目前包含以下 crate: -Send one prompt, stream the answer, then exit: +- `claw-cli` — 面向用户的二进制文件 +- `api` — 提供商客户端和流式处理 +- `runtime` — 会话、配置、权限、提示词和运行时循环 +- `tools` — 内置工具实现 +- `commands` — 斜杠命令注册和处理程序 +- `plugins` — 插件发现、注册和生命周期支持 +- `lsp` — 语言服务器协议支持类型和进程助手 +- `server` 和 `compat-harness` — 支持服务和兼容性工具 -```bash -cd rust -cargo run -p rusty-claude-cli -- prompt "Summarize the architecture of this repository" -``` +## 路线图 -Use a specific model: +- 发布打包好的构件,用于公共安装 +- 添加可重复的发布工作流和长期维护的变更日志(changelog)规范 +- 将平台验证扩展到当前 CI 矩阵之外 +- 添加更多以任务为中心的示例和操作员文档 +- 继续加强 Rust 实现的功能覆盖并磨炼用户体验(UX) -```bash -cd rust -cargo run -p rusty-claude-cli -- --model claude-sonnet-4-20250514 prompt "List the key crates in this workspace" -``` +## 发行版本说明 -Restrict enabled tools in an interactive session: +- 0.1.0 发行说明草案:[`docs/releases/0.1.0.md`](docs/releases/0.1.0.md) -```bash -cd rust -cargo run -p rusty-claude-cli -- --allowedTools read,glob -``` +## 许可 -Bootstrap Claude project files for the current repo: - 
-```bash -cd rust -cargo run -p rusty-claude-cli -- init -``` - -### 2) REPL mode - -Start the interactive shell: - -```bash -cd rust -cargo run -p rusty-claude-cli -- -``` - -Inside the REPL, useful commands include: - -```text -/help -/status -/model claude-sonnet-4-20250514 -/permissions workspace-write -/cost -/compact -/memory -/config -/init -/diff -/version -/export notes.txt -/sessions -/session list -/exit -``` - -### 3) Resume an existing session - -Inspect or maintain a saved session file without entering the REPL: - -```bash -cd rust -cargo run -p rusty-claude-cli -- --resume session-123456 /status /compact /cost -``` - -You can also inspect memory/config state for a restored session: - -```bash -cd rust -cargo run -p rusty-claude-cli -- --resume ~/.claude/sessions/session-123456.json /memory /config -``` - -## Available commands - -### Top-level CLI commands - -- `prompt ` — run one prompt non-interactively -- `--resume [/commands...]` — inspect or maintain a saved session stored under `~/.claude/sessions/` -- `dump-manifests` — print extracted upstream manifest counts -- `bootstrap-plan` — print the current bootstrap skeleton -- `system-prompt [--cwd PATH] [--date YYYY-MM-DD]` — render the synthesized system prompt -- `self-update` — update the installed binary from the latest GitHub release when a matching asset is available -- `--help` / `-h` — show CLI help -- `--version` / `-V` — print the CLI version and build info locally (no API call) -- `--output-format text|json` — choose non-interactive prompt output rendering -- `--allowedTools ` — restrict enabled tools for interactive sessions and prompt-mode tool use - -### Interactive slash commands - -- `/help` — show command help -- `/status` — show current session status -- `/compact` — compact local session history -- `/model [model]` — inspect or switch the active model -- `/permissions [read-only|workspace-write|danger-full-access]` — inspect or switch permissions -- `/clear [--confirm]` — clear 
the current local session -- `/cost` — show token usage totals -- `/resume ` — load a saved session into the REPL -- `/config [env|hooks|model]` — inspect discovered Claude config -- `/memory` — inspect loaded instruction memory files -- `/init` — bootstrap `.claude.json`, `.claude/`, `CLAUDE.md`, and local ignore rules -- `/diff` — show the current git diff for the workspace -- `/version` — print version and build metadata locally -- `/export [file]` — export the current conversation transcript -- `/sessions` — list recent managed local sessions from `~/.claude/sessions/` -- `/session [list|switch ]` — inspect or switch managed local sessions -- `/exit` — leave the REPL - -## Environment variables - -### Anthropic/API - -- `ANTHROPIC_API_KEY` — highest-precedence API credential -- `ANTHROPIC_AUTH_TOKEN` — bearer-token override used when no API key is set -- Persisted OAuth credentials in `~/.claude/credentials.json` — used when neither env var is set -- `ANTHROPIC_BASE_URL` — override the Anthropic API base URL -- `ANTHROPIC_MODEL` — default model used by selected live integration tests - -### CLI/runtime - -- `RUSTY_CLAUDE_PERMISSION_MODE` — default REPL permission mode (`read-only`, `workspace-write`, or `danger-full-access`) -- `CLAUDE_CONFIG_HOME` — override Claude config discovery root -- `CLAUDE_CODE_REMOTE` — enable remote-session bootstrap handling when supported -- `CLAUDE_CODE_REMOTE_SESSION_ID` — remote session identifier when using remote mode -- `CLAUDE_CODE_UPSTREAM` — override the upstream TS source path for compat-harness extraction -- `CLAWD_WEB_SEARCH_BASE_URL` — override the built-in web search service endpoint used by tooling - -## Notes - -- `compat-harness` exists to compare the Rust port against the upstream TypeScript codebase and is intentionally excluded from the requested release test run. 
-- The CLI currently focuses on a practical integrated workflow: prompt execution, REPL operation, session inspection/resume, config discovery, and tool/runtime plumbing. +有关许可详情,请参阅仓库根目录。 diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index c5e152e..b9923a8 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -9,7 +9,7 @@ publish.workspace = true reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } runtime = { path = "../runtime" } serde = { version = "1", features = ["derive"] } -serde_json = "1" +serde_json.workspace = true tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } [lints] diff --git a/crates/api/src/client.rs b/crates/api/src/client.rs index 110a80b..b596777 100644 --- a/crates/api/src/client.rs +++ b/crates/api/src/client.rs @@ -1,1006 +1,141 @@ -use std::collections::VecDeque; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use runtime::{ - load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, - OAuthTokenExchangeRequest, -}; -use serde::Deserialize; - use crate::error::ApiError; -use crate::sse::SseParser; +use crate::providers::claw_provider::{self, AuthSource, ClawApiClient}; +use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; +use crate::providers::{self, Provider, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; -const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; -const ANTHROPIC_VERSION: &str = "2023-06-01"; -const REQUEST_ID_HEADER: &str = "request-id"; -const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; -const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); -const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); -const DEFAULT_MAX_RETRIES: u32 = 2; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum AuthSource { - None, - ApiKey(String), - BearerToken(String), - ApiKeyAndBearer { - api_key: String, - 
bearer_token: String, - }, +async fn send_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.send_message(request).await } -impl AuthSource { - pub fn from_env() -> Result { - let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; - let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; - match (api_key, auth_token) { - (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - (Some(api_key), None) => Ok(Self::ApiKey(api_key)), - (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), - (None, None) => Err(ApiError::MissingApiKey), - } - } - - #[must_use] - pub fn api_key(&self) -> Option<&str> { - match self { - Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), - Self::None | Self::BearerToken(_) => None, - } - } - - #[must_use] - pub fn bearer_token(&self) -> Option<&str> { - match self { - Self::BearerToken(token) - | Self::ApiKeyAndBearer { - bearer_token: token, - .. 
- } => Some(token), - Self::None | Self::ApiKey(_) => None, - } - } - - #[must_use] - pub fn masked_authorization_header(&self) -> &'static str { - if self.bearer_token().is_some() { - "Bearer [REDACTED]" - } else { - "" - } - } - - pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - if let Some(api_key) = self.api_key() { - request_builder = request_builder.header("x-api-key", api_key); - } - if let Some(token) = self.bearer_token() { - request_builder = request_builder.bearer_auth(token); - } - request_builder - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] -pub struct OAuthTokenSet { - pub access_token: String, - pub refresh_token: Option, - pub expires_at: Option, - #[serde(default)] - pub scopes: Vec, -} - -impl From for AuthSource { - fn from(value: OAuthTokenSet) -> Self { - Self::BearerToken(value.access_token) - } +async fn stream_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.stream_message(request).await } #[derive(Debug, Clone)] -pub struct AnthropicClient { - http: reqwest::Client, - auth: AuthSource, - base_url: String, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, +pub enum ProviderClient { + ClawApi(ClawApiClient), + Xai(OpenAiCompatClient), + OpenAi(OpenAiCompatClient), } -impl AnthropicClient { - #[must_use] - pub fn new(api_key: impl Into) -> Self { - Self { - http: reqwest::Client::new(), - auth: AuthSource::ApiKey(api_key.into()), - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, +impl ProviderClient { + pub fn from_model(model: &str) -> Result { + Self::from_model_with_default_auth(model, None) + } + + pub fn from_model_with_default_auth( + model: &str, + default_auth: Option, + ) -> Result { + let resolved_model = providers::resolve_model_alias(model); + match providers::detect_provider_kind(&resolved_model) { + 
ProviderKind::ClawApi => Ok(Self::ClawApi(match default_auth { + Some(auth) => ClawApiClient::from_auth(auth), + None => ClawApiClient::from_env()?, + })), + ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( + OpenAiCompatConfig::xai(), + )?)), + ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env( + OpenAiCompatConfig::openai(), + )?)), } } #[must_use] - pub fn from_auth(auth: AuthSource) -> Self { - Self { - http: reqwest::Client::new(), - auth, - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, + pub const fn provider_kind(&self) -> ProviderKind { + match self { + Self::ClawApi(_) => ProviderKind::ClawApi, + Self::Xai(_) => ProviderKind::Xai, + Self::OpenAi(_) => ProviderKind::OpenAi, } } - pub fn from_env() -> Result { - Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) - } - - #[must_use] - pub fn with_auth_source(mut self, auth: AuthSource) -> Self { - self.auth = auth; - self - } - - #[must_use] - pub fn with_auth_token(mut self, auth_token: Option) -> Self { - match ( - self.auth.api_key().map(ToOwned::to_owned), - auth_token.filter(|token| !token.is_empty()), - ) { - (Some(api_key), Some(bearer_token)) => { - self.auth = AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }; - } - (Some(api_key), None) => { - self.auth = AuthSource::ApiKey(api_key); - } - (None, Some(bearer_token)) => { - self.auth = AuthSource::BearerToken(bearer_token); - } - (None, None) => { - self.auth = AuthSource::None; - } - } - self - } - - #[must_use] - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = base_url.into(); - self - } - - #[must_use] - pub fn with_retry_policy( - mut self, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, - ) -> Self { - self.max_retries = max_retries; - self.initial_backoff = initial_backoff; - self.max_backoff = 
max_backoff; - self - } - - #[must_use] - pub fn auth_source(&self) -> &AuthSource { - &self.auth - } - pub async fn send_message( &self, request: &MessageRequest, ) -> Result { - let request = MessageRequest { - stream: false, - ..request.clone() - }; - let response = self.send_with_retry(&request).await?; - let request_id = request_id_from_headers(response.headers()); - let mut response = response - .json::() - .await - .map_err(ApiError::from)?; - if response.request_id.is_none() { - response.request_id = request_id; + match self { + Self::ClawApi(client) => send_via_provider(client, request).await, + Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await, } - Ok(response) } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { - let response = self - .send_with_retry(&request.clone().with_streaming()) - .await?; - Ok(MessageStream { - request_id: request_id_from_headers(response.headers()), - response, - parser: SseParser::new(), - pending: VecDeque::new(), - done: false, - }) - } - - pub async fn exchange_oauth_code( - &self, - config: &OAuthConfig, - request: &OAuthTokenExchangeRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - pub async fn refresh_oauth_token( - &self, - config: &OAuthConfig, - request: &OAuthRefreshRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - async fn send_with_retry( - &self, - request: &MessageRequest, - ) -> 
Result { - let mut attempts = 0; - let mut last_error: Option; - - loop { - attempts += 1; - match self.send_raw_request(request).await { - Ok(response) => match expect_success(response).await { - Ok(response) => return Ok(response), - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - }, - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - } - - if attempts > self.max_retries { - break; - } - - tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; - } - - Err(ApiError::RetriesExhausted { - attempts, - last_error: Box::new(last_error.expect("retry loop must capture an error")), - }) - } - - async fn send_raw_request( - &self, - request: &MessageRequest, - ) -> Result { - let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); - let resolved_base_url = self.base_url.trim_end_matches('/'); - eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}"); - eprintln!("[anthropic-client] request_url={request_url}"); - let request_builder = self - .http - .post(&request_url) - .header("anthropic-version", ANTHROPIC_VERSION) - .header("content-type", "application/json"); - let mut request_builder = self.auth.apply(request_builder); - - eprintln!( - "[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json", - if self.auth.api_key().is_some() { - "[REDACTED]" - } else { - "" - }, - self.auth.masked_authorization_header() - ); - - request_builder = request_builder.json(request); - request_builder.send().await.map_err(ApiError::from) - } - - fn backoff_for_attempt(&self, attempt: u32) -> Result { - let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { - return Err(ApiError::BackoffOverflow { - attempt, - base_delay: self.initial_backoff, - }); - }; - Ok(self - 
.initial_backoff - .checked_mul(multiplier) - .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) - } -} - -impl AuthSource { - pub fn from_env_or_saved() -> Result { - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - Some(bearer_token) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(Self::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(Self::BearerToken(bearer_token)); - } - match load_saved_oauth_token() { - Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { - if token_set.refresh_token.is_some() { - Err(ApiError::Auth( - "saved OAuth token is expired; load runtime OAuth config to refresh it" - .to_string(), - )) - } else { - Err(ApiError::ExpiredOAuthToken) - } - } - Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), - Ok(None) => Err(ApiError::MissingApiKey), - Err(error) => Err(error), + match self { + Self::ClawApi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::ClawApi), + Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::OpenAiCompat), } } } -#[must_use] -pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { - token_set - .expires_at - .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) -} - -pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { - let Some(token_set) = load_saved_oauth_token()? else { - return Ok(None); - }; - resolve_saved_oauth_token_set(config, token_set).map(Some) -} - -pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result -where - F: FnOnce() -> Result, ApiError>, -{ - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? 
{ - Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(AuthSource::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(AuthSource::BearerToken(bearer_token)); - } - - let Some(token_set) = load_saved_oauth_token()? else { - return Err(ApiError::MissingApiKey); - }; - if !oauth_token_is_expired(&token_set) { - return Ok(AuthSource::BearerToken(token_set.access_token)); - } - if token_set.refresh_token.is_none() { - return Err(ApiError::ExpiredOAuthToken); - } - - let Some(config) = load_oauth_config()? else { - return Err(ApiError::Auth( - "saved OAuth token is expired; runtime OAuth config is missing".to_string(), - )); - }; - Ok(AuthSource::from(resolve_saved_oauth_token_set( - &config, token_set, - )?)) -} - -fn resolve_saved_oauth_token_set( - config: &OAuthConfig, - token_set: OAuthTokenSet, -) -> Result { - if !oauth_token_is_expired(&token_set) { - return Ok(token_set); - } - let Some(refresh_token) = token_set.refresh_token.clone() else { - return Err(ApiError::ExpiredOAuthToken); - }; - let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); - let refreshed = client_runtime_block_on(async { - client - .refresh_oauth_token( - config, - &OAuthRefreshRequest::from_config( - config, - refresh_token, - Some(token_set.scopes.clone()), - ), - ) - .await - })?; - let resolved = OAuthTokenSet { - access_token: refreshed.access_token, - refresh_token: refreshed.refresh_token.or(token_set.refresh_token), - expires_at: refreshed.expires_at, - scopes: refreshed.scopes, - }; - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: resolved.access_token.clone(), - refresh_token: resolved.refresh_token.clone(), - expires_at: resolved.expires_at, - scopes: resolved.scopes.clone(), - }) - .map_err(ApiError::from)?; - Ok(resolved) -} - -fn client_runtime_block_on(future: F) -> Result -where - F: std::future::Future>, -{ - 
tokio::runtime::Runtime::new() - .map_err(ApiError::from)? - .block_on(future) -} - -fn load_saved_oauth_token() -> Result, ApiError> { - let token_set = load_oauth_credentials().map_err(ApiError::from)?; - Ok(token_set.map(|token_set| OAuthTokenSet { - access_token: token_set.access_token, - refresh_token: token_set.refresh_token, - expires_at: token_set.expires_at, - scopes: token_set.scopes, - })) -} - -fn now_unix_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| duration.as_secs()) -} - -fn read_env_non_empty(key: &str) -> Result, ApiError> { - match std::env::var(key) { - Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), - Err(error) => Err(ApiError::from(error)), - } -} - -#[cfg(test)] -fn read_api_key() -> Result { - let auth = AuthSource::from_env_or_saved()?; - auth.api_key() - .or_else(|| auth.bearer_token()) - .map(ToOwned::to_owned) - .ok_or(ApiError::MissingApiKey) -} - -#[cfg(test)] -fn read_auth_token() -> Option { - read_env_non_empty("ANTHROPIC_AUTH_TOKEN") - .ok() - .and_then(std::convert::identity) -} - -pub fn read_base_url() -> String { - std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) -} - -fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { - headers - .get(REQUEST_ID_HEADER) - .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) - .and_then(|value| value.to_str().ok()) - .map(ToOwned::to_owned) -} - #[derive(Debug)] -pub struct MessageStream { - request_id: Option, - response: reqwest::Response, - parser: SseParser, - pending: VecDeque, - done: bool, +pub enum MessageStream { + ClawApi(claw_provider::MessageStream), + OpenAiCompat(openai_compat::MessageStream), } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { - self.request_id.as_deref() + match self { + Self::ClawApi(stream) => stream.request_id(), + Self::OpenAiCompat(stream) => stream.request_id(), + 
} } pub async fn next_event(&mut self) -> Result, ApiError> { - loop { - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - - if self.done { - let remaining = self.parser.finish()?; - self.pending.extend(remaining); - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - return Ok(None); - } - - match self.response.chunk().await? { - Some(chunk) => { - self.pending.extend(self.parser.push(&chunk)?); - } - None => { - self.done = true; - } - } + match self { + Self::ClawApi(stream) => stream.next_event().await, + Self::OpenAiCompat(stream) => stream.next_event().await, } } } -async fn expect_success(response: reqwest::Response) -> Result { - let status = response.status(); - if status.is_success() { - return Ok(response); - } - - let body = response.text().await.unwrap_or_else(|_| String::new()); - let parsed_error = serde_json::from_str::(&body).ok(); - let retryable = is_retryable_status(status); - - Err(ApiError::Api { - status, - error_type: parsed_error - .as_ref() - .map(|error| error.error.error_type.clone()), - message: parsed_error - .as_ref() - .map(|error| error.error.message.clone()), - body, - retryable, - }) +pub use claw_provider::{ + oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet, +}; +#[must_use] +pub fn read_base_url() -> String { + claw_provider::read_base_url() } -const fn is_retryable_status(status: reqwest::StatusCode) -> bool { - matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorEnvelope { - error: AnthropicErrorBody, -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorBody { - #[serde(rename = "type")] - error_type: String, - message: String, +#[must_use] +pub fn read_xai_base_url() -> String { + openai_compat::read_base_url(OpenAiCompatConfig::xai()) } #[cfg(test)] mod tests { - use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; - use std::io::{Read, Write}; 
- use std::net::TcpListener; - use std::sync::{Mutex, OnceLock}; - use std::thread; - use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind}; - use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; - - use crate::client::{ - now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, - resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, - }; - use crate::types::{ContentBlockDelta, MessageRequest}; - - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .expect("env lock") - } - - fn temp_config_home() -> std::path::PathBuf { - std::env::temp_dir().join(format!( - "api-oauth-test-{}-{}", - std::process::id(), - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time") - .as_nanos() - )) - } - - fn sample_oauth_config(token_url: String) -> OAuthConfig { - OAuthConfig { - client_id: "runtime-client".to_string(), - authorize_url: "https://console.test/oauth/authorize".to_string(), - token_url, - callback_port: Some(4545), - manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), - scopes: vec!["org:read".to_string(), "user:write".to_string()], - } - } - - fn spawn_token_server(response_body: &'static str) -> String { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); - let address = listener.local_addr().expect("local addr"); - thread::spawn(move || { - let (mut stream, _) = listener.accept().expect("accept connection"); - let mut buffer = [0_u8; 4096]; - let _ = stream.read(&mut buffer).expect("read request"); - let response = format!( - "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", - response_body.len(), - response_body - ); - stream - .write_all(response.as_bytes()) - .expect("write response"); - }); - format!("http://{address}/oauth/token") + 
#[test] + fn resolves_existing_and_grok_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); } #[test] - fn read_api_key_requires_presence() { - let _guard = env_lock(); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - let error = super::read_api_key().expect_err("missing key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - } - - #[test] - fn read_api_key_requires_non_empty_value() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); - std::env::remove_var("ANTHROPIC_API_KEY"); - let error = super::read_api_key().expect_err("empty key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn read_api_key_prefers_api_key_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + fn provider_detection_prefers_model_family() { + assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai); assert_eq!( - super::read_api_key().expect("api key should load"), - "legacy-key" - ); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn read_auth_token_reads_auth_token_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn oauth_token_maps_to_bearer_auth_source() { - let auth = AuthSource::from(OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(123), - scopes: vec!["scope:a".to_string()], - }); - 
assert_eq!(auth.bearer_token(), Some("access-token")); - assert_eq!(auth.api_key(), None); - } - - #[test] - fn auth_source_from_env_combines_api_key_and_bearer_token() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); - let auth = AuthSource::from_env().expect("env auth"); - assert_eq!(auth.api_key(), Some("legacy-key")); - assert_eq!(auth.bearer_token(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn auth_source_from_saved_oauth_when_env_absent() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = AuthSource::from_env_or_saved().expect("saved auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn oauth_token_expiry_uses_expires_at_timestamp() { - assert!(oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(1), - scopes: Vec::new(), - })); - assert!(!oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(now_unix_timestamp() + 60), - scopes: Vec::new(), - })); - } - - #[test] - fn resolve_saved_oauth_token_refreshes_expired_credentials() { - let _guard = env_lock(); - let config_home = 
temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "refreshed-token"); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) - .expect("startup auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - 
std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let error = - resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); - assert!( - matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) - ); - - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "expired-access-token"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - 
"{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn message_request_stream_helper_sets_stream_true() { - let request = MessageRequest { - model: "claude-opus-4-6".to_string(), - max_tokens: 64, - messages: vec![], - system: None, - tools: None, - tool_choice: None, - stream: false, - }; - - assert!(request.with_streaming().stream); - } - - #[test] - fn backoff_doubles_until_maximum() { - let client = AnthropicClient::new("test-key").with_retry_policy( - 3, - Duration::from_millis(10), - Duration::from_millis(25), - ); - assert_eq!( - client.backoff_for_attempt(1).expect("attempt 1"), - Duration::from_millis(10) - ); - assert_eq!( - client.backoff_for_attempt(2).expect("attempt 2"), - Duration::from_millis(20) - ); - assert_eq!( - client.backoff_for_attempt(3).expect("attempt 3"), - Duration::from_millis(25) - ); - } - - #[test] - fn retryable_statuses_are_detected() { - assert!(super::is_retryable_status( - reqwest::StatusCode::TOO_MANY_REQUESTS - )); - assert!(super::is_retryable_status( - reqwest::StatusCode::INTERNAL_SERVER_ERROR - )); - assert!(!super::is_retryable_status( - reqwest::StatusCode::UNAUTHORIZED - )); - } - - #[test] - fn tool_delta_variant_round_trips() { - let delta = ContentBlockDelta::InputJsonDelta { - partial_json: "{\"city\":\"Paris\"}".to_string(), - }; - 
let encoded = serde_json::to_string(&delta).expect("delta should serialize"); - let decoded: ContentBlockDelta = - serde_json::from_str(&encoded).expect("delta should deserialize"); - assert_eq!(decoded, delta); - } - - #[test] - fn request_id_uses_primary_or_fallback_header() { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_primary") - ); - - headers.clear(); - headers.insert( - ALT_REQUEST_ID_HEADER, - "req_fallback".parse().expect("header"), - ); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_fallback") - ); - } - - #[test] - fn auth_source_applies_headers() { - let auth = AuthSource::ApiKeyAndBearer { - api_key: "test-key".to_string(), - bearer_token: "proxy-token".to_string(), - }; - let request = auth - .apply(reqwest::Client::new().post("https://example.test")) - .build() - .expect("request build"); - let headers = request.headers(); - assert_eq!( - headers.get("x-api-key").and_then(|v| v.to_str().ok()), - Some("test-key") - ); - assert_eq!( - headers.get("authorization").and_then(|v| v.to_str().ok()), - Some("Bearer proxy-token") + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::ClawApi ); } } diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index 2c31691..7649889 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -4,7 +4,10 @@ use std::time::Duration; #[derive(Debug)] pub enum ApiError { - MissingApiKey, + MissingCredentials { + provider: &'static str, + env_vars: &'static [&'static str], + }, ExpiredOAuthToken, Auth(String), InvalidApiKeyEnv(VarError), @@ -30,13 +33,21 @@ pub enum ApiError { } impl ApiError { + #[must_use] + pub const fn missing_credentials( + provider: &'static str, + env_vars: &'static [&'static str], + ) -> Self { + Self::MissingCredentials { provider, env_vars } + } + #[must_use] pub fn 
is_retryable(&self) -> bool { match self { Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(), Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), - Self::MissingApiKey + Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) | Self::InvalidApiKeyEnv(_) @@ -51,12 +62,11 @@ impl ApiError { impl Display for ApiError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::MissingApiKey => { - write!( - f, - "ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API" - ) - } + Self::MissingCredentials { provider, env_vars } => write!( + f, + "missing {provider} credentials; export {} before calling the {provider} API", + env_vars.join(" or ") + ), Self::ExpiredOAuthToken => { write!( f, @@ -65,10 +75,7 @@ impl Display for ApiError { } Self::Auth(message) => write!(f, "auth error: {message}"), Self::InvalidApiKeyEnv(error) => { - write!( - f, - "failed to read ANTHROPIC_AUTH_TOKEN / ANTHROPIC_API_KEY: {error}" - ) + write!(f, "failed to read credential environment variable: {error}") } Self::Http(error) => write!(f, "http error: {error}"), Self::Io(error) => write!(f, "io error: {error}"), @@ -81,20 +88,14 @@ impl Display for ApiError { .. 
} => match (error_type, message) { (Some(error_type), Some(message)) => { - write!( - f, - "anthropic api returned {status} ({error_type}): {message}" - ) + write!(f, "api returned {status} ({error_type}): {message}") } - _ => write!(f, "anthropic api returned {status}: {body}"), + _ => write!(f, "api returned {status}: {body}"), }, Self::RetriesExhausted { attempts, last_error, - } => write!( - f, - "anthropic api failed after {attempts} attempts: {last_error}" - ), + } => write!(f, "api failed after {attempts} attempts: {last_error}"), Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"), Self::BackoffOverflow { attempt, diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index a91344b..3306f53 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -1,13 +1,19 @@ mod client; mod error; +mod providers; mod sse; mod types; pub use client::{ - oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, - resolve_startup_auth_source, AnthropicClient, AuthSource, MessageStream, OAuthTokenSet, + oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token, + resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient, }; pub use error::ApiError; +pub use providers::claw_provider::{AuthSource, ClawApiClient, ClawApiClient as ApiClient}; +pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; +pub use providers::{ + detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind, +}; pub use sse::{parse_frame, SseParser}; pub use types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, diff --git a/crates/api/src/providers/claw_provider.rs b/crates/api/src/providers/claw_provider.rs new file mode 100644 index 0000000..d9046cd --- /dev/null +++ b/crates/api/src/providers/claw_provider.rs @@ -0,0 +1,1046 @@ +use std::collections::VecDeque; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use runtime::{ + 
load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, + OAuthTokenExchangeRequest, +}; +use serde::Deserialize; + +use crate::error::ApiError; + +use super::{Provider, ProviderFuture}; +use crate::sse::SseParser; +use crate::types::{MessageRequest, MessageResponse, StreamEvent}; + +pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. 
+ } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + #[serde(default)] + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + +#[derive(Debug, Clone)] +pub struct ClawApiClient { + http: reqwest::Client, + auth: AuthSource, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl ClawApiClient { + #[must_use] + pub fn new(api_key: impl Into) -> Self { + Self { + http: reqwest::Client::new(), + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env() -> Result { + Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self + } + + #[must_use] + pub fn with_auth_token(mut self, 
auth_token: Option) -> Self { + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } + self + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let mut response = response + .json::() + .await + .map_err(ApiError::from)?; + if response.request_id.is_none() { + response.request_id = request_id; + } + Ok(response) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: SseParser::new(), + pending: VecDeque::new(), + done: false, + }) + } + + pub async fn exchange_oauth_code( + &self, + config: &OAuthConfig, + request: &OAuthTokenExchangeRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", 
"application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + pub async fn refresh_oauth_token( + &self, + config: &OAuthConfig, + request: &OAuthRefreshRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + let mut last_error: Option; + + loop { + attempts += 1; + match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + } + + if attempts > self.max_retries { + break; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + } + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error.expect("retry loop must capture an error")), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let request_builder = self + .http + .post(&request_url) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); + + request_builder = request_builder.json(request); + 
request_builder.send().await.map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl AuthSource { + pub fn from_env_or_saved() -> Result { + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(Self::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(Self::BearerToken(bearer_token)); + } + match load_saved_oauth_token() { + Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { + if token_set.refresh_token.is_some() { + Err(ApiError::Auth( + "saved OAuth token is expired; load runtime OAuth config to refresh it" + .to_string(), + )) + } else { + Err(ApiError::ExpiredOAuthToken) + } + } + Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), + Ok(None) => Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + Err(error) => Err(error), + } + } +} + +#[must_use] +pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { + token_set + .expires_at + .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) +} + +pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { + let Some(token_set) = load_saved_oauth_token()? 
else { + return Ok(None); + }; + resolve_saved_oauth_token_set(config, token_set).map(Some) +} + +pub fn has_auth_from_env_or_saved() -> Result { + Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some() + || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some() + || load_saved_oauth_token()?.is_some()) +} + +pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result +where + F: FnOnce() -> Result, ApiError>, +{ + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(AuthSource::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(AuthSource::BearerToken(bearer_token)); + } + + let Some(token_set) = load_saved_oauth_token()? else { + return Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )); + }; + if !oauth_token_is_expired(&token_set) { + return Ok(AuthSource::BearerToken(token_set.access_token)); + } + if token_set.refresh_token.is_none() { + return Err(ApiError::ExpiredOAuthToken); + } + + let Some(config) = load_oauth_config()? 
else { + return Err(ApiError::Auth( + "saved OAuth token is expired; runtime OAuth config is missing".to_string(), + )); + }; + Ok(AuthSource::from(resolve_saved_oauth_token_set( + &config, token_set, + )?)) +} + +fn resolve_saved_oauth_token_set( + config: &OAuthConfig, + token_set: OAuthTokenSet, +) -> Result { + if !oauth_token_is_expired(&token_set) { + return Ok(token_set); + } + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Err(ApiError::ExpiredOAuthToken); + }; + let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let refreshed = client_runtime_block_on(async { + client + .refresh_oauth_token( + config, + &OAuthRefreshRequest::from_config( + config, + refresh_token, + Some(token_set.scopes.clone()), + ), + ) + .await + })?; + let resolved = OAuthTokenSet { + access_token: refreshed.access_token, + refresh_token: refreshed.refresh_token.or(token_set.refresh_token), + expires_at: refreshed.expires_at, + scopes: refreshed.scopes, + }; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: resolved.access_token.clone(), + refresh_token: resolved.refresh_token.clone(), + expires_at: resolved.expires_at, + scopes: resolved.scopes.clone(), + }) + .map_err(ApiError::from)?; + Ok(resolved) +} + +fn client_runtime_block_on(future: F) -> Result +where + F: std::future::Future>, +{ + tokio::runtime::Runtime::new() + .map_err(ApiError::from)? 
+ .block_on(future) +} + +fn load_saved_oauth_token() -> Result, ApiError> { + let token_set = load_oauth_credentials().map_err(ApiError::from)?; + Ok(token_set.map(|token_set| OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })) +} + +fn now_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env_or_saved()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )) +} + +#[cfg(test)] +fn read_auth_token() -> Option { + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) +} + +#[must_use] +pub fn read_base_url() -> String { + std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +impl Provider for ClawApiClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct 
MessageStream { + request_id: Option, + response: reqwest::Response, + parser: SseParser, + pending: VecDeque, + done: bool, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + let remaining = self.parser.finish()?; + self.pending.extend(remaining); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.pending.extend(self.parser.push(&chunk)?); + } + None => { + self.done = true; + } + } + } + } +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_else(|_| String::new()); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .map(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .map(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +#[derive(Debug, Deserialize)] +struct ApiErrorEnvelope { + error: ApiErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ApiErrorBody { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +#[cfg(test)] +mod tests { + use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::sync::{Mutex, OnceLock}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use runtime::{clear_oauth_credentials, save_oauth_credentials, 
OAuthConfig}; + + use super::{ + now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, + resolve_startup_auth_source, AuthSource, ClawApiClient, OAuthTokenSet, + }; + use crate::types::{ContentBlockDelta, MessageRequest}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "api-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + + fn cleanup_temp_config_home(config_home: &std::path::Path) { + match std::fs::remove_dir_all(config_home) { + Ok(()) => {} + Err(error) if error.kind() == std::io::ErrorKind::NotFound => {} + Err(error) => panic!("cleanup temp dir: {error}"), + } + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + + #[test] 
+ fn read_api_key_requires_presence() { + let _guard = env_lock(); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("CLAW_CONFIG_HOME"); + let error = super::read_api_key().expect_err("missing key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + } + + #[test] + fn read_api_key_requires_non_empty_value() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); + std::env::remove_var("ANTHROPIC_API_KEY"); + let error = super::read_api_key().expect_err("empty key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn read_api_key_prefers_api_key_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + assert_eq!( + super::read_api_key().expect("api key should load"), + "legacy-key" + ); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn read_auth_token_reads_auth_token_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let 
auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn auth_source_from_saved_oauth_when_env_absent() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = AuthSource::from_env_or_saved().expect("saved auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn oauth_token_expiry_uses_expires_at_timestamp() { + assert!(oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + scopes: Vec::new(), + })); + assert!(!oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(now_unix_timestamp() + 60), + scopes: Vec::new(), + })); + } + + #[test] + fn resolve_saved_oauth_token_refreshes_expired_credentials() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + 
scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "refreshed-token"); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) + .expect("startup auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + 
save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let error = + resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); + assert!( + matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) + ); + + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "expired-access-token"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token 
set"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn message_request_stream_helper_sets_stream_true() { + let request = MessageRequest { + model: "claude-opus-4-6".to_string(), + max_tokens: 64, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + + assert!(request.with_streaming().stream); + } + + #[test] + fn backoff_doubles_until_maximum() { + let client = ClawApiClient::new("test-key").with_retry_policy( + 3, + Duration::from_millis(10), + Duration::from_millis(25), + ); + assert_eq!( + client.backoff_for_attempt(1).expect("attempt 1"), + Duration::from_millis(10) + ); + assert_eq!( + client.backoff_for_attempt(2).expect("attempt 2"), + Duration::from_millis(20) + ); + assert_eq!( + client.backoff_for_attempt(3).expect("attempt 3"), + Duration::from_millis(25) + ); + } + + #[test] + fn retryable_statuses_are_detected() { + assert!(super::is_retryable_status( + reqwest::StatusCode::TOO_MANY_REQUESTS + )); + assert!(super::is_retryable_status( + reqwest::StatusCode::INTERNAL_SERVER_ERROR + )); + assert!(!super::is_retryable_status( + reqwest::StatusCode::UNAUTHORIZED + )); + } + + #[test] + fn tool_delta_variant_round_trips() { + let delta = ContentBlockDelta::InputJsonDelta { + partial_json: "{\"city\":\"Paris\"}".to_string(), + }; + let encoded = serde_json::to_string(&delta).expect("delta should serialize"); + let decoded: ContentBlockDelta = + serde_json::from_str(&encoded).expect("delta should deserialize"); + assert_eq!(decoded, delta); + } + + #[test] + fn request_id_uses_primary_or_fallback_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_primary") + ); + 
+ headers.clear(); + headers.insert( + ALT_REQUEST_ID_HEADER, + "req_fallback".parse().expect("header"), + ); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_fallback") + ); + } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } +} diff --git a/crates/api/src/providers/mod.rs b/crates/api/src/providers/mod.rs new file mode 100644 index 0000000..192afd6 --- /dev/null +++ b/crates/api/src/providers/mod.rs @@ -0,0 +1,239 @@ +use std::future::Future; +use std::pin::Pin; + +use crate::error::ApiError; +use crate::types::{MessageRequest, MessageResponse}; + +pub mod claw_provider; +pub mod openai_compat; + +pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; + +pub trait Provider { + type Stream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse>; + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream>; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProviderKind { + ClawApi, + Xai, + OpenAi, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ProviderMetadata { + pub provider: ProviderKind, + pub auth_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ + ( + "opus", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: 
claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "sonnet", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "haiku", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-opus-4-6", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-sonnet-4-6", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-haiku-4-5-20251213", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "grok", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-2", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + 
base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), +]; + +#[must_use] +pub fn resolve_model_alias(model: &str) -> String { + let trimmed = model.trim(); + let lower = trimmed.to_ascii_lowercase(); + MODEL_REGISTRY + .iter() + .find_map(|(alias, metadata)| { + (*alias == lower).then_some(match metadata.provider { + ProviderKind::ClawApi => match *alias { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => trimmed, + }, + ProviderKind::Xai => match *alias { + "grok" | "grok-3" => "grok-3", + "grok-mini" | "grok-3-mini" => "grok-3-mini", + "grok-2" => "grok-2", + _ => trimmed, + }, + ProviderKind::OpenAi => trimmed, + }) + }) + .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) +} + +#[must_use] +pub fn metadata_for_model(model: &str) -> Option { + let canonical = resolve_model_alias(model); + let lower = canonical.to_ascii_lowercase(); + if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) { + return Some(*metadata); + } + if lower.starts_with("grok") { + return Some(ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }); + } + None +} + +#[must_use] +pub fn detect_provider_kind(model: &str) -> ProviderKind { + if let Some(metadata) = metadata_for_model(model) { + return metadata.provider; + } + if claw_provider::has_auth_from_env_or_saved().unwrap_or(false) { + return ProviderKind::ClawApi; + } + if openai_compat::has_api_key("OPENAI_API_KEY") { + return ProviderKind::OpenAi; + } + if openai_compat::has_api_key("XAI_API_KEY") { + return ProviderKind::Xai; + } + ProviderKind::ClawApi +} + +#[must_use] +pub fn max_tokens_for_model(model: &str) -> u32 { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } +} + +#[cfg(test)] +mod tests { + use 
super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + + #[test] + fn resolves_grok_aliases() { + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); + } + + #[test] + fn detects_provider_from_model_name_first() { + assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); + assert_eq!( + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::ClawApi + ); + } + + #[test] + fn keeps_existing_max_token_heuristic() { + assert_eq!(max_tokens_for_model("opus"), 32_000); + assert_eq!(max_tokens_for_model("grok-3"), 64_000); + } +} diff --git a/crates/api/src/providers/openai_compat.rs b/crates/api/src/providers/openai_compat.rs new file mode 100644 index 0000000..e8210ae --- /dev/null +++ b/crates/api/src/providers/openai_compat.rs @@ -0,0 +1,1050 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::time::Duration; + +use serde::Deserialize; +use serde_json::{json, Value}; + +use crate::error::ApiError; +use crate::types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, + ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, +}; + +use super::{Provider, ProviderFuture}; + +pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; +pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OpenAiCompatConfig { + pub provider_name: &'static str, + pub 
api_key_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; +const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; + +impl OpenAiCompatConfig { + #[must_use] + pub const fn xai() -> Self { + Self { + provider_name: "xAI", + api_key_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: DEFAULT_XAI_BASE_URL, + } + } + + #[must_use] + pub const fn openai() -> Self { + Self { + provider_name: "OpenAI", + api_key_env: "OPENAI_API_KEY", + base_url_env: "OPENAI_BASE_URL", + default_base_url: DEFAULT_OPENAI_BASE_URL, + } + } + #[must_use] + pub fn credential_env_vars(self) -> &'static [&'static str] { + match self.provider_name { + "xAI" => XAI_ENV_VARS, + "OpenAI" => OPENAI_ENV_VARS, + _ => &[], + } + } +} + +#[derive(Debug, Clone)] +pub struct OpenAiCompatClient { + http: reqwest::Client, + api_key: String, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl OpenAiCompatClient { + #[must_use] + pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { + Self { + http: reqwest::Client::new(), + api_key: api_key.into(), + base_url: read_base_url(config), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env(config: OpenAiCompatConfig) -> Result { + let Some(api_key) = read_env_non_empty(config.api_key_env)? 
else { + return Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )); + }; + Ok(Self::new(api_key, config)) + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let payload = response.json::().await?; + let mut normalized = normalize_response(&request.model, payload)?; + if normalized.request_id.is_none() { + normalized.request_id = request_id; + } + Ok(normalized) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: OpenAiSseParser::new(), + pending: VecDeque::new(), + done: false, + state: StreamState::new(request.model.clone()), + }) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + + let last_error = loop { + attempts += 1; + let retryable_error = match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + 
}; + + if attempts > self.max_retries { + break retryable_error; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + }; + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = chat_completions_endpoint(&self.base_url); + self.http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&self.api_key) + .json(&build_chat_completion_request(request)) + .send() + .await + .map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl Provider for OpenAiCompatClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: OpenAiSseParser, + pending: VecDeque, + done: bool, + state: StreamState, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + self.pending.extend(self.state.finish()?); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return 
Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + for parsed in self.parser.push(&chunk)? { + self.pending.extend(self.state.ingest_chunk(parsed)?); + } + } + None => { + self.done = true; + } + } + } + } +} + +#[derive(Debug, Default)] +struct OpenAiSseParser { + buffer: Vec, +} + +impl OpenAiSseParser { + fn new() -> Self { + Self::default() + } + + fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { + self.buffer.extend_from_slice(chunk); + let mut events = Vec::new(); + + while let Some(frame) = next_sse_frame(&mut self.buffer) { + if let Some(event) = parse_sse_frame(&frame)? { + events.push(event); + } + } + + Ok(events) + } +} + +#[derive(Debug)] +struct StreamState { + model: String, + message_started: bool, + text_started: bool, + text_finished: bool, + finished: bool, + stop_reason: Option, + usage: Option, + tool_calls: BTreeMap, +} + +impl StreamState { + fn new(model: String) -> Self { + Self { + model, + message_started: false, + text_started: false, + text_finished: false, + finished: false, + stop_reason: None, + usage: None, + tool_calls: BTreeMap::new(), + } + } + + fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { + let mut events = Vec::new(); + if !self.message_started { + self.message_started = true; + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: MessageResponse { + id: chunk.id.clone(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: chunk.model.clone().unwrap_or_else(|| self.model.clone()), + stop_reason: None, + stop_sequence: None, + usage: Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }, + request_id: None, + }, + })); + } + + if let Some(usage) = chunk.usage { + self.usage = Some(Usage { + input_tokens: usage.prompt_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: usage.completion_tokens, + }); + } + + 
for choice in chunk.choices { + if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { + if !self.text_started { + self.text_started = true; + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Text { + text: String::new(), + }, + })); + } + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::TextDelta { text: content }, + })); + } + + for tool_call in choice.delta.tool_calls { + let state = self.tool_calls.entry(tool_call.index).or_default(); + state.apply(tool_call); + let block_index = state.block_index(); + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + } else { + continue; + } + } + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + } + + if let Some(finish_reason) = choice.finish_reason { + self.stop_reason = Some(normalize_finish_reason(&finish_reason)); + if finish_reason == "tool_calls" { + for state in self.tool_calls.values_mut() { + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + } + } + } + + Ok(events) + } + + fn finish(&mut self) -> Result, ApiError> { + if self.finished { + return Ok(Vec::new()); + } + self.finished = true; + + let mut events = Vec::new(); + if self.text_started && !self.text_finished { + self.text_finished = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: 0, + })); + } + + for state in self.tool_calls.values_mut() { + if !state.started { + if let 
Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + } + } + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + + if self.message_started { + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some( + self.stop_reason + .clone() + .unwrap_or_else(|| "end_turn".to_string()), + ), + stop_sequence: None, + }, + usage: self.usage.clone().unwrap_or(Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }), + })); + events.push(StreamEvent::MessageStop(MessageStopEvent {})); + } + Ok(events) + } +} + +#[derive(Debug, Default)] +struct ToolCallState { + openai_index: u32, + id: Option, + name: Option, + arguments: String, + emitted_len: usize, + started: bool, + stopped: bool, +} + +impl ToolCallState { + fn apply(&mut self, tool_call: DeltaToolCall) { + self.openai_index = tool_call.index; + if let Some(id) = tool_call.id { + self.id = Some(id); + } + if let Some(name) = tool_call.function.name { + self.name = Some(name); + } + if let Some(arguments) = tool_call.function.arguments { + self.arguments.push_str(&arguments); + } + } + + const fn block_index(&self) -> u32 { + self.openai_index + 1 + } + + fn start_event(&self) -> Result, ApiError> { + let Some(name) = self.name.clone() else { + return Ok(None); + }; + let id = self + .id + .clone() + .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); + Ok(Some(ContentBlockStartEvent { + index: self.block_index(), + content_block: OutputContentBlock::ToolUse { + id, + name, + input: json!({}), + }, + })) + } + + fn delta_event(&mut self) -> Option { + if self.emitted_len >= self.arguments.len() { + 
return None; + } + let delta = self.arguments[self.emitted_len..].to_string(); + self.emitted_len = self.arguments.len(); + Some(ContentBlockDeltaEvent { + index: self.block_index(), + delta: ContentBlockDelta::InputJsonDelta { + partial_json: delta, + }, + }) + } +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + id: String, + model: String, + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatChoice { + message: ChatMessage, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatMessage { + role: String, + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolCall { + id: String, + function: ResponseToolFunction, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct OpenAiUsage { + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionChunk { + id: String, + #[serde(default)] + model: Option, + #[serde(default)] + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChunkChoice { + delta: ChunkDelta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Default, Deserialize)] +struct ChunkDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeltaToolCall { + #[serde(default)] + index: u32, + #[serde(default)] + id: Option, + #[serde(default)] + function: DeltaFunction, +} + +#[derive(Debug, Default, Deserialize)] +struct DeltaFunction { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct ErrorEnvelope { + error: ErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ErrorBody { + #[serde(rename = "type")] 
+ error_type: Option, + message: Option, +} + +fn build_chat_completion_request(request: &MessageRequest) -> Value { + let mut messages = Vec::new(); + if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { + messages.push(json!({ + "role": "system", + "content": system, + })); + } + for message in &request.messages { + messages.extend(translate_message(message)); + } + + let mut payload = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": request.stream, + }); + + if let Some(tools) = &request.tools { + payload["tools"] = + Value::Array(tools.iter().map(openai_tool_definition).collect::>()); + } + if let Some(tool_choice) = &request.tool_choice { + payload["tool_choice"] = openai_tool_choice(tool_choice); + } + + payload +} + +fn translate_message(message: &InputMessage) -> Vec { + match message.role.as_str() { + "assistant" => { + let mut text = String::new(); + let mut tool_calls = Vec::new(); + for block in &message.content { + match block { + InputContentBlock::Text { text: value } => text.push_str(value), + InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({ + "id": id, + "type": "function", + "function": { + "name": name, + "arguments": input.to_string(), + } + })), + InputContentBlock::ToolResult { .. } => {} + } + } + if text.is_empty() && tool_calls.is_empty() { + Vec::new() + } else { + vec![json!({ + "role": "assistant", + "content": (!text.is_empty()).then_some(text), + "tool_calls": tool_calls, + })] + } + } + _ => message + .content + .iter() + .filter_map(|block| match block { + InputContentBlock::Text { text } => Some(json!({ + "role": "user", + "content": text, + })), + InputContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } => Some(json!({ + "role": "tool", + "tool_call_id": tool_use_id, + "content": flatten_tool_result_content(content), + "is_error": is_error, + })), + InputContentBlock::ToolUse { .. 
} => None, + }) + .collect(), + } +} + +fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { + content + .iter() + .map(|block| match block { + ToolResultContentBlock::Text { text } => text.clone(), + ToolResultContentBlock::Json { value } => value.to_string(), + }) + .collect::>() + .join("\n") +} + +fn openai_tool_definition(tool: &ToolDefinition) -> Value { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + }) +} + +fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { + match tool_choice { + ToolChoice::Auto => Value::String("auto".to_string()), + ToolChoice::Any => Value::String("required".to_string()), + ToolChoice::Tool { name } => json!({ + "type": "function", + "function": { "name": name }, + }), + } +} + +fn normalize_response( + model: &str, + response: ChatCompletionResponse, +) -> Result { + let choice = response + .choices + .into_iter() + .next() + .ok_or(ApiError::InvalidSseFrame( + "chat completion response missing choices", + ))?; + let mut content = Vec::new(); + if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) { + content.push(OutputContentBlock::Text { text }); + } + for tool_call in choice.message.tool_calls { + content.push(OutputContentBlock::ToolUse { + id: tool_call.id, + name: tool_call.function.name, + input: parse_tool_arguments(&tool_call.function.arguments), + }); + } + + Ok(MessageResponse { + id: response.id, + kind: "message".to_string(), + role: choice.message.role, + content, + model: response.model.if_empty_then(model.to_string()), + stop_reason: choice + .finish_reason + .map(|value| normalize_finish_reason(&value)), + stop_sequence: None, + usage: Usage { + input_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.prompt_tokens), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: response + .usage + .as_ref() + .map_or(0, |usage| 
usage.completion_tokens), + }, + request_id: None, + }) +} + +fn parse_tool_arguments(arguments: &str) -> Value { + serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments })) +} + +fn next_sse_frame(buffer: &mut Vec) -> Option { + let separator = buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = buffer.drain(..position + separator_len).collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) +} + +fn parse_sse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() { + return Ok(None); + } + + let mut data_lines = Vec::new(); + for line in trimmed.lines() { + if line.starts_with(':') { + continue; + } + if let Some(data) = line.strip_prefix("data:") { + data_lines.push(data.trim_start()); + } + } + if data_lines.is_empty() { + return Ok(None); + } + let payload = data_lines.join("\n"); + if payload == "[DONE]" { + return Ok(None); + } + serde_json::from_str(&payload) + .map(Some) + .map_err(ApiError::from) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[must_use] +pub fn has_api_key(key: &str) -> bool { + read_env_non_empty(key) + .ok() + .and_then(std::convert::identity) + .is_some() +} + +#[must_use] +pub fn read_base_url(config: OpenAiCompatConfig) -> String { + std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) +} + +fn chat_completions_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if 
trimmed.ends_with("/chat/completions") { + trimmed.to_string() + } else { + format!("{trimmed}/chat/completions") + } +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_default(); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .and_then(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .and_then(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +fn normalize_finish_reason(value: &str) -> String { + match value { + "stop" => "end_turn", + "tool_calls" => "tool_use", + other => other, + } + .to_string() +} + +trait StringExt { + fn if_empty_then(self, fallback: String) -> String; +} + +impl StringExt for String { + fn if_empty_then(self, fallback: String) -> String { + if self.is_empty() { + fallback + } else { + self + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason, + openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, + }; + use crate::error::ApiError; + use crate::types::{ + InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, + ToolResultContentBlock, + }; + use serde_json::json; + use std::sync::{Mutex, OnceLock}; + + #[test] + fn request_translation_uses_openai_compatible_shape() { + let payload = 
build_chat_completion_request(&MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![ + InputContentBlock::Text { + text: "hello".to_string(), + }, + InputContentBlock::ToolResult { + tool_use_id: "tool_1".to_string(), + content: vec![ToolResultContentBlock::Json { + value: json!({"ok": true}), + }], + is_error: false, + }, + ], + }], + system: Some("be helpful".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Get weather".to_string()), + input_schema: json!({"type": "object"}), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: false, + }); + + assert_eq!(payload["messages"][0]["role"], json!("system")); + assert_eq!(payload["messages"][1]["role"], json!("user")); + assert_eq!(payload["messages"][2]["role"], json!("tool")); + assert_eq!(payload["tools"][0]["type"], json!("function")); + assert_eq!(payload["tool_choice"], json!("auto")); + } + + #[test] + fn tool_choice_translation_supports_required_function() { + assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); + assert_eq!( + openai_tool_choice(&ToolChoice::Tool { + name: "weather".to_string(), + }), + json!({"type": "function", "function": {"name": "weather"}}) + ); + } + + #[test] + fn parses_tool_arguments_fallback() { + assert_eq!( + parse_tool_arguments("{\"city\":\"Paris\"}"), + json!({"city": "Paris"}) + ); + assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"})); + } + + #[test] + fn missing_xai_api_key_is_provider_specific() { + let _lock = env_lock(); + std::env::remove_var("XAI_API_KEY"); + let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai()) + .expect_err("missing key should error"); + assert!(matches!( + error, + ApiError::MissingCredentials { + provider: "xAI", + .. 
+ } + )); + } + + #[test] + fn endpoint_builder_accepts_base_urls_and_full_endpoints() { + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), + "https://api.x.ai/v1/chat/completions" + ); + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + #[test] + fn normalizes_stop_reasons() { + assert_eq!(normalize_finish_reason("stop"), "end_turn"); + assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); + } +} diff --git a/crates/api/src/sse.rs b/crates/api/src/sse.rs index d7334cd..5f54e50 100644 --- a/crates/api/src/sse.rs +++ b/crates/api/src/sse.rs @@ -216,4 +216,64 @@ mod tests { )) ); } + + #[test] + fn parses_thinking_content_block_start() { + let frame = concat!( + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\",\"signature\":null}}\n\n" + ); + + let event = parse_frame(frame).expect("frame should parse"); + assert_eq!( + event, + Some(StreamEvent::ContentBlockStart( + crate::types::ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Thinking { + thinking: String::new(), + signature: None, + }, + }, + )) + ); + } + + #[test] + fn parses_thinking_related_deltas() { + let thinking = concat!( + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"step 1\"}}\n\n" + ); + let signature = concat!( + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"signature_delta\",\"signature\":\"sig_123\"}}\n\n" + ); + + let 
thinking_event = parse_frame(thinking).expect("thinking delta should parse"); + let signature_event = parse_frame(signature).expect("signature delta should parse"); + + assert_eq!( + thinking_event, + Some(StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::ThinkingDelta { + thinking: "step 1".to_string(), + }, + } + )) + ); + assert_eq!( + signature_event, + Some(StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::SignatureDelta { + signature: "sig_123".to_string(), + }, + } + )) + ); + } } diff --git a/crates/api/src/types.rs b/crates/api/src/types.rs index 45d5c08..c060be6 100644 --- a/crates/api/src/types.rs +++ b/crates/api/src/types.rs @@ -135,6 +135,15 @@ pub enum OutputContentBlock { name: String, input: Value, }, + Thinking { + #[serde(default)] + thinking: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + signature: Option, + }, + RedactedThinking { + data: Value, + }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -190,6 +199,8 @@ pub struct ContentBlockDeltaEvent { pub enum ContentBlockDelta { TextDelta { text: String }, InputJsonDelta { partial_json: String }, + ThinkingDelta { thinking: String }, + SignatureDelta { signature: String }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] diff --git a/crates/api/tests/client_integration.rs b/crates/api/tests/client_integration.rs index c37fa99..3b6a3c3 100644 --- a/crates/api/tests/client_integration.rs +++ b/crates/api/tests/client_integration.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use std::time::Duration; use api::{ - AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, - InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, - StreamEvent, ToolChoice, ToolDefinition, + ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, + 
ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -20,8 +20,8 @@ async fn send_message_posts_json_and_parses_response() { "\"id\":\"msg_test\",", "\"type\":\"message\",", "\"role\":\"assistant\",", - "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],", - "\"model\":\"claude-3-7-sonnet-latest\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claw\"}],", + "\"model\":\"claude-sonnet-4-6\",", "\"stop_reason\":\"end_turn\",", "\"stop_sequence\":null,", "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},", @@ -34,7 +34,7 @@ async fn send_message_posts_json_and_parses_response() { ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) .with_base_url(server.base_url()); let response = client @@ -48,7 +48,7 @@ async fn send_message_posts_json_and_parses_response() { assert_eq!( response.content, vec![OutputContentBlock::Text { - text: "Hello from Claude".to_string(), + text: "Hello from Claw".to_string(), }] ); @@ -68,7 +68,7 @@ async fn send_message_posts_json_and_parses_response() { serde_json::from_str(&request.body).expect("request body should be json"); assert_eq!( body.get("model").and_then(serde_json::Value::as_str), - Some("claude-3-7-sonnet-latest") + Some("claude-sonnet-4-6") ); assert!(body.get("stream").is_none()); assert_eq!(body["tools"][0]["name"], json!("get_weather")); @@ -80,7 +80,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { let state = Arc::new(Mutex::new(Vec::::new())); let sse = concat!( "event: message_start\n", - "data: 
{\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", "event: content_block_start\n", "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", "event: content_block_delta\n", @@ -104,7 +104,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) .with_base_url(server.base_url()); let mut stream = client @@ -176,13 +176,13 @@ async fn retries_retryable_failures_before_succeeding() { http_response( "200 OK", "application/json", - "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", ), ], ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_base_url(server.base_url()) .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2)); @@ -195,6 +195,47 @@ async fn retries_retryable_failures_before_succeeding() { 
assert_eq!(state.lock().await.len(), 2); } +#[tokio::test] +async fn provider_client_dispatches_api_requests() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + )], + ) + .await; + + let client = ProviderClient::from_model_with_default_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("test-key".to_string())), + ) + .expect("api provider client should be constructed"); + let client = match client { + ProviderClient::ClawApi(client) => { + ProviderClient::ClawApi(client.with_base_url(server.base_url())) + } + other => panic!("expected default provider, got {other:?}"), + }; + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 5); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/v1/messages"); + assert_eq!( + request.headers.get("x-api-key").map(String::as_str), + Some("test-key") + ); +} + #[tokio::test] async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -215,7 +256,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_base_url(server.base_url()) .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2)); @@ -246,11 +287,10 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY and network access"] async 
fn live_stream_smoke_test() { - let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set"); + let client = ApiClient::from_env().expect("ANTHROPIC_API_KEY must be set"); let mut stream = client .stream_message(&MessageRequest { - model: std::env::var("ANTHROPIC_MODEL") - .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()), + model: std::env::var("CLAW_MODEL").unwrap_or_else(|_| "claude-sonnet-4-6".to_string()), max_tokens: 32, messages: vec![InputMessage::user_text( "Reply with exactly: hello from rust", @@ -410,7 +450,7 @@ fn http_response_with_headers( fn sample_request(stream: bool) -> MessageRequest { MessageRequest { - model: "claude-3-7-sonnet-latest".to_string(), + model: "claude-sonnet-4-6".to_string(), max_tokens: 64, messages: vec![InputMessage { role: "user".to_string(), diff --git a/crates/api/tests/openai_compat_integration.rs b/crates/api/tests/openai_compat_integration.rs new file mode 100644 index 0000000..b345b1f --- /dev/null +++ b/crates/api/tests/openai_compat_integration.rs @@ -0,0 +1,415 @@ +use std::collections::HashMap; +use std::ffi::OsString; +use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; + +use api::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, +}; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +#[tokio::test] +async fn send_message_uses_openai_compatible_endpoint_and_auth() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_test\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + 
"\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.model, "grok-3"); + assert_eq!(response.total_tokens(), 16); + assert_eq!( + response.content, + vec![OutputContentBlock::Text { + text: "Hello from Grok".to_string(), + }] + ); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["model"], json!("grok-3")); + assert_eq!(body["messages"][0]["role"], json!("system")); + assert_eq!(body["tools"][0]["type"], json!("function")); +} + +#[tokio::test] +async fn send_message_accepts_full_chat_completions_endpoint_override() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_full_endpoint\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let endpoint_url = format!("{}/chat/completions", server.base_url()); + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(endpoint_url); + let response = client + .send_message(&sample_request(false)) + 
.await + .expect("request should succeed"); + + assert_eq!(response.total_tokens(), 10); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); +} + +#[tokio::test] +async fn stream_message_normalizes_text_and_multiple_tool_calls() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("x-request-id", "req_grok_stream")], + )], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_grok_stream")); + + let mut events = Vec::new(); + while let Some(event) = stream.next_event().await.expect("event should parse") { + events.push(event); + } + + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::Text { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::TextDelta { .. }, + .. 
+ }) + )); + assert!(matches!( + events[3], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 1, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[4], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 1, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[5], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 2, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[6], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 2, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[7], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 }) + )); + assert!(matches!( + events[8], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 }) + )); + assert!(matches!( + events[9], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 }) + )); + assert!(matches!(events[10], StreamEvent::MessageDelta(_))); + assert!(matches!(events[11], StreamEvent::MessageStop(_))); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert!(request.body.contains("\"stream\":true")); +} + +#[tokio::test] +async fn provider_client_dispatches_xai_requests_from_env() { + let _lock = env_lock(); + let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key"); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}", + )], + ) + .await; + let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url()); + 
+    let client =
+        ProviderClient::from_model("grok").expect("xAI provider client should be constructed");
+    assert!(matches!(client, ProviderClient::Xai(_)));
+
+    let response = client
+        .send_message(&sample_request(false))
+        .await
+        .expect("provider-dispatched request should succeed");
+
+    assert_eq!(response.total_tokens(), 13);
+
+    let captured = state.lock().await;
+    let request = captured.first().expect("captured request");
+    assert_eq!(request.path, "/chat/completions");
+    assert_eq!(
+        request.headers.get("authorization").map(String::as_str),
+        Some("Bearer xai-test-key")
+    );
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct CapturedRequest {
+    path: String,
+    headers: HashMap<String, String>,
+    body: String,
+}
+
+struct TestServer {
+    base_url: String,
+    join_handle: tokio::task::JoinHandle<()>,
+}
+
+impl TestServer {
+    fn base_url(&self) -> String {
+        self.base_url.clone()
+    }
+}
+
+impl Drop for TestServer {
+    fn drop(&mut self) {
+        self.join_handle.abort();
+    }
+}
+
+async fn spawn_server(
+    state: Arc<Mutex<Vec<CapturedRequest>>>,
+    responses: Vec<String>,
+) -> TestServer {
+    let listener = TcpListener::bind("127.0.0.1:0")
+        .await
+        .expect("listener should bind");
+    let address = listener.local_addr().expect("listener addr");
+    let join_handle = tokio::spawn(async move {
+        for response in responses {
+            let (mut socket, _) = listener.accept().await.expect("accept");
+            let mut buffer = Vec::new();
+            let mut header_end = None;
+            loop {
+                let mut chunk = [0_u8; 1024];
+                let read = socket.read(&mut chunk).await.expect("read request");
+                if read == 0 {
+                    break;
+                }
+                buffer.extend_from_slice(&chunk[..read]);
+                if let Some(position) = find_header_end(&buffer) {
+                    header_end = Some(position);
+                    break;
+                }
+            }
+
+            let header_end = header_end.expect("headers should exist");
+            let (header_bytes, remaining) = buffer.split_at(header_end);
+            let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers");
+            let mut lines = header_text.split("\r\n");
+            let request_line =
lines.next().expect("request line");
+            let path = request_line
+                .split_whitespace()
+                .nth(1)
+                .expect("path")
+                .to_string();
+            let mut headers = HashMap::new();
+            let mut content_length = 0_usize;
+            for line in lines {
+                if line.is_empty() {
+                    continue;
+                }
+                let (name, value) = line.split_once(':').expect("header");
+                let value = value.trim().to_string();
+                if name.eq_ignore_ascii_case("content-length") {
+                    content_length = value.parse().expect("content length");
+                }
+                headers.insert(name.to_ascii_lowercase(), value);
+            }
+
+            let mut body = remaining[4..].to_vec();
+            while body.len() < content_length {
+                let mut chunk = vec![0_u8; content_length - body.len()];
+                let read = socket.read(&mut chunk).await.expect("read body");
+                if read == 0 {
+                    break;
+                }
+                body.extend_from_slice(&chunk[..read]);
+            }
+
+            state.lock().await.push(CapturedRequest {
+                path,
+                headers,
+                body: String::from_utf8(body).expect("utf8 body"),
+            });
+
+            socket
+                .write_all(response.as_bytes())
+                .await
+                .expect("write response");
+        }
+    });
+
+    TestServer {
+        base_url: format!("http://{address}"),
+        join_handle,
+    }
+}
+
+fn find_header_end(bytes: &[u8]) -> Option<usize> {
+    bytes.windows(4).position(|window| window == b"\r\n\r\n")
+}
+
+fn http_response(status: &str, content_type: &str, body: &str) -> String {
+    http_response_with_headers(status, content_type, body, &[])
+}
+
+fn http_response_with_headers(
+    status: &str,
+    content_type: &str,
+    body: &str,
+    headers: &[(&str, &str)],
+) -> String {
+    let mut extra_headers = String::new();
+    for (name, value) in headers {
+        use std::fmt::Write as _;
+        write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write");
+    }
+    format!(
+        "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}",
+        body.len()
+    )
+}
+
+fn sample_request(stream: bool) -> MessageRequest {
+    MessageRequest {
+        model: "grok-3".to_string(),
+        max_tokens: 64,
+        messages: vec![InputMessage {
+            role:
"user".to_string(), + content: vec![InputContentBlock::Text { + text: "Say hello".to_string(), + }], + }], + system: Some("Use tools when needed".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Fetches weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream, + } +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct ScopedEnvVar { + key: &'static str, + previous: Option, +} + +impl ScopedEnvVar { + fn set(key: &'static str, value: impl AsRef) -> Self { + let previous = std::env::var_os(key); + std::env::set_var(key, value); + Self { key, previous } + } +} + +impl Drop for ScopedEnvVar { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/crates/api/tests/provider_client_integration.rs b/crates/api/tests/provider_client_integration.rs new file mode 100644 index 0000000..abeebdd --- /dev/null +++ b/crates/api/tests/provider_client_integration.rs @@ -0,0 +1,86 @@ +use std::ffi::OsString; +use std::sync::{Mutex, OnceLock}; + +use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind}; + +#[test] +fn provider_client_routes_grok_aliases_through_xai() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + + let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve"); + + assert_eq!(client.provider_kind(), ProviderKind::Xai); +} + +#[test] +fn provider_client_reports_missing_xai_credentials_for_grok_models() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None); + + let 
error = ProviderClient::from_model("grok-3")
+        .expect_err("grok requests without XAI_API_KEY should fail fast");
+
+    match error {
+        ApiError::MissingCredentials { provider, env_vars } => {
+            assert_eq!(provider, "xAI");
+            assert_eq!(env_vars, &["XAI_API_KEY"]);
+        }
+        other => panic!("expected missing xAI credentials, got {other:?}"),
+    }
+}
+
+#[test]
+fn provider_client_uses_explicit_auth_without_env_lookup() {
+    let _lock = env_lock();
+    let _api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None);
+    let _auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None);
+
+    let client = ProviderClient::from_model_with_default_auth(
+        "claude-sonnet-4-6",
+        Some(AuthSource::ApiKey("claw-test-key".to_string())),
+    )
+    .expect("explicit auth should avoid env lookup");
+
+    assert_eq!(client.provider_kind(), ProviderKind::ClawApi);
+}
+
+#[test]
+fn read_xai_base_url_prefers_env_override() {
+    let _lock = env_lock();
+    let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1"));
+
+    assert_eq!(read_xai_base_url(), "https://example.xai.test/v1");
+}
+
+fn env_lock() -> std::sync::MutexGuard<'static, ()> {
+    static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
+    LOCK.get_or_init(|| Mutex::new(()))
+        .lock()
+        .unwrap_or_else(|poisoned| poisoned.into_inner())
+}
+
+struct EnvVarGuard {
+    key: &'static str,
+    original: Option<OsString>,
+}
+
+impl EnvVarGuard {
+    fn set(key: &'static str, value: Option<&str>) -> Self {
+        let original = std::env::var_os(key);
+        match value {
+            Some(value) => std::env::set_var(key, value),
+            None => std::env::remove_var(key),
+        }
+        Self { key, original }
+    }
+}
+
+impl Drop for EnvVarGuard {
+    fn drop(&mut self) {
+        match &self.original {
+            Some(value) => std::env::set_var(self.key, value),
+            None => std::env::remove_var(self.key),
+        }
+    }
+}
diff --git a/crates/claw-cli/Cargo.toml b/crates/claw-cli/Cargo.toml
new file mode 100644
index 0000000..074718a
--- /dev/null
+++ b/crates/claw-cli/Cargo.toml
@@ -0,0 +1,27 @@
+[package]
+name =
"claw-cli" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[[bin]] +name = "claw" +path = "src/main.rs" + +[dependencies] +api = { path = "../api" } +commands = { path = "../commands" } +compat-harness = { path = "../compat-harness" } +crossterm = "0.28" +pulldown-cmark = "0.13" +rustyline = "15" +runtime = { path = "../runtime" } +plugins = { path = "../plugins" } +serde_json.workspace = true +syntect = "5" +tokio = { version = "1", features = ["rt-multi-thread", "time"] } +tools = { path = "../tools" } + +[lints] +workspace = true diff --git a/crates/claw-cli/src/app.rs b/crates/claw-cli/src/app.rs new file mode 100644 index 0000000..85e754f --- /dev/null +++ b/crates/claw-cli/src/app.rs @@ -0,0 +1,402 @@ +use std::io::{self, Write}; +use std::path::PathBuf; + +use crate::args::{OutputFormat, PermissionMode}; +use crate::input::{LineEditor, ReadOutcome}; +use crate::render::{Spinner, TerminalRenderer}; +use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionConfig { + pub model: String, + pub permission_mode: PermissionMode, + pub config: Option, + pub output_format: OutputFormat, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionState { + pub turns: usize, + pub compacted_messages: usize, + pub last_model: String, + pub last_usage: UsageSummary, +} + +impl SessionState { + #[must_use] + pub fn new(model: impl Into) -> Self { + Self { + turns: 0, + compacted_messages: 0, + last_model: model.into(), + last_usage: UsageSummary::default(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CommandResult { + Continue, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SlashCommand { + Help, + Status, + Compact, + Unknown(String), +} + +impl SlashCommand { + #[must_use] + pub fn parse(input: &str) -> Option { + let trimmed = input.trim(); + if !trimmed.starts_with('/') { + 
return None; + } + + let command = trimmed + .trim_start_matches('/') + .split_whitespace() + .next() + .unwrap_or_default(); + Some(match command { + "help" => Self::Help, + "status" => Self::Status, + "compact" => Self::Compact, + other => Self::Unknown(other.to_string()), + }) + } +} + +struct SlashCommandHandler { + command: SlashCommand, + summary: &'static str, +} + +const SLASH_COMMAND_HANDLERS: &[SlashCommandHandler] = &[ + SlashCommandHandler { + command: SlashCommand::Help, + summary: "Show command help", + }, + SlashCommandHandler { + command: SlashCommand::Status, + summary: "Show current session status", + }, + SlashCommandHandler { + command: SlashCommand::Compact, + summary: "Compact local session history", + }, +]; + +pub struct CliApp { + config: SessionConfig, + renderer: TerminalRenderer, + state: SessionState, + conversation_client: ConversationClient, + conversation_history: Vec, +} + +impl CliApp { + pub fn new(config: SessionConfig) -> Result { + let state = SessionState::new(config.model.clone()); + let conversation_client = ConversationClient::from_env(config.model.clone())?; + Ok(Self { + config, + renderer: TerminalRenderer::new(), + state, + conversation_client, + conversation_history: Vec::new(), + }) + } + + pub fn run_repl(&mut self) -> io::Result<()> { + let mut editor = LineEditor::new("› ", Vec::new()); + println!("Claw Code interactive mode"); + println!("Type /help for commands. Shift+Enter or Ctrl+J inserts a newline."); + + loop { + match editor.read_line()? 
{ + ReadOutcome::Submit(input) => { + if input.trim().is_empty() { + continue; + } + self.handle_submission(&input, &mut io::stdout())?; + } + ReadOutcome::Cancel => continue, + ReadOutcome::Exit => break, + } + } + + Ok(()) + } + + pub fn run_prompt(&mut self, prompt: &str, out: &mut impl Write) -> io::Result<()> { + self.render_response(prompt, out) + } + + pub fn handle_submission( + &mut self, + input: &str, + out: &mut impl Write, + ) -> io::Result { + if let Some(command) = SlashCommand::parse(input) { + return self.dispatch_slash_command(command, out); + } + + self.state.turns += 1; + self.render_response(input, out)?; + Ok(CommandResult::Continue) + } + + fn dispatch_slash_command( + &mut self, + command: SlashCommand, + out: &mut impl Write, + ) -> io::Result { + match command { + SlashCommand::Help => Self::handle_help(out), + SlashCommand::Status => self.handle_status(out), + SlashCommand::Compact => self.handle_compact(out), + SlashCommand::Unknown(name) => { + writeln!(out, "Unknown slash command: /{name}")?; + Ok(CommandResult::Continue) + } + _ => { + writeln!(out, "Slash command unavailable in this mode")?; + Ok(CommandResult::Continue) + } + } + } + + fn handle_help(out: &mut impl Write) -> io::Result { + writeln!(out, "Available commands:")?; + for handler in SLASH_COMMAND_HANDLERS { + let name = match handler.command { + SlashCommand::Help => "/help", + SlashCommand::Status => "/status", + SlashCommand::Compact => "/compact", + _ => continue, + }; + writeln!(out, " {name:<9} {}", handler.summary)?; + } + Ok(CommandResult::Continue) + } + + fn handle_status(&mut self, out: &mut impl Write) -> io::Result { + writeln!( + out, + "status: turns={} model={} permission-mode={:?} output-format={:?} last-usage={} in/{} out config={}", + self.state.turns, + self.state.last_model, + self.config.permission_mode, + self.config.output_format, + self.state.last_usage.input_tokens, + self.state.last_usage.output_tokens, + self.config + .config + .as_ref() + 
.map_or_else(|| String::from(""), |path| path.display().to_string()) + )?; + Ok(CommandResult::Continue) + } + + fn handle_compact(&mut self, out: &mut impl Write) -> io::Result { + self.state.compacted_messages += self.state.turns; + self.state.turns = 0; + self.conversation_history.clear(); + writeln!( + out, + "Compacted session history into a local summary ({} messages total compacted).", + self.state.compacted_messages + )?; + Ok(CommandResult::Continue) + } + + fn handle_stream_event( + renderer: &TerminalRenderer, + event: StreamEvent, + stream_spinner: &mut Spinner, + tool_spinner: &mut Spinner, + saw_text: &mut bool, + turn_usage: &mut UsageSummary, + out: &mut impl Write, + ) { + match event { + StreamEvent::TextDelta(delta) => { + if !*saw_text { + let _ = + stream_spinner.finish("Streaming response", renderer.color_theme(), out); + *saw_text = true; + } + let _ = write!(out, "{delta}"); + let _ = out.flush(); + } + StreamEvent::ToolCallStart { name, input } => { + if *saw_text { + let _ = writeln!(out); + } + let _ = tool_spinner.tick( + &format!("Running tool `{name}` with {input}"), + renderer.color_theme(), + out, + ); + } + StreamEvent::ToolCallResult { + name, + output, + is_error, + } => { + let label = if is_error { + format!("Tool `{name}` failed") + } else { + format!("Tool `{name}` completed") + }; + let _ = tool_spinner.finish(&label, renderer.color_theme(), out); + let rendered_output = format!("### Tool `{name}`\n\n```text\n{output}\n```\n"); + let _ = renderer.stream_markdown(&rendered_output, out); + } + StreamEvent::Usage(usage) => { + *turn_usage = usage; + } + } + } + + fn write_turn_output( + &self, + summary: &runtime::TurnSummary, + out: &mut impl Write, + ) -> io::Result<()> { + match self.config.output_format { + OutputFormat::Text => { + writeln!( + out, + "\nToken usage: {} input / {} output", + self.state.last_usage.input_tokens, self.state.last_usage.output_tokens + )?; + } + OutputFormat::Json => { + writeln!( + out, + "{}", 
+ serde_json::json!({ + "message": summary.assistant_text, + "usage": { + "input_tokens": self.state.last_usage.input_tokens, + "output_tokens": self.state.last_usage.output_tokens, + } + }) + )?; + } + OutputFormat::Ndjson => { + writeln!( + out, + "{}", + serde_json::json!({ + "type": "message", + "text": summary.assistant_text, + "usage": { + "input_tokens": self.state.last_usage.input_tokens, + "output_tokens": self.state.last_usage.output_tokens, + } + }) + )?; + } + } + Ok(()) + } + + fn render_response(&mut self, input: &str, out: &mut impl Write) -> io::Result<()> { + let mut stream_spinner = Spinner::new(); + stream_spinner.tick( + "Opening conversation stream", + self.renderer.color_theme(), + out, + )?; + + let mut turn_usage = UsageSummary::default(); + let mut tool_spinner = Spinner::new(); + let mut saw_text = false; + let renderer = &self.renderer; + + let result = + self.conversation_client + .run_turn(&mut self.conversation_history, input, |event| { + Self::handle_stream_event( + renderer, + event, + &mut stream_spinner, + &mut tool_spinner, + &mut saw_text, + &mut turn_usage, + out, + ); + }); + + let summary = match result { + Ok(summary) => summary, + Err(error) => { + stream_spinner.fail( + "Streaming response failed", + self.renderer.color_theme(), + out, + )?; + return Err(io::Error::other(error)); + } + }; + self.state.last_usage = summary.usage.clone(); + if saw_text { + writeln!(out)?; + } else { + stream_spinner.finish("Streaming response", self.renderer.color_theme(), out)?; + } + + self.write_turn_output(&summary, out)?; + let _ = turn_usage; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use crate::args::{OutputFormat, PermissionMode}; + + use super::{CommandResult, SessionConfig, SlashCommand}; + + #[test] + fn parses_required_slash_commands() { + assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help)); + assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status)); + assert_eq!( + 
SlashCommand::parse("/compact now"),
            Some(SlashCommand::Compact)
        );
    }

    #[test]
    fn help_output_lists_commands() {
        let mut out = Vec::new();
        let result = super::CliApp::handle_help(&mut out).expect("help succeeds");
        assert_eq!(result, CommandResult::Continue);
        let output = String::from_utf8_lossy(&out);
        assert!(output.contains("/help"));
        assert!(output.contains("/status"));
        assert!(output.contains("/compact"));
    }

    #[test]
    fn session_state_tracks_config_values() {
        let config = SessionConfig {
            model: "sonnet".into(),
            permission_mode: PermissionMode::DangerFullAccess,
            config: Some(PathBuf::from("settings.toml")),
            output_format: OutputFormat::Text,
        };

        assert_eq!(config.model, "sonnet");
        assert_eq!(config.permission_mode, PermissionMode::DangerFullAccess);
        assert_eq!(config.config, Some(PathBuf::from("settings.toml")));
    }
}
diff --git a/crates/claw-cli/src/args.rs b/crates/claw-cli/src/args.rs
new file mode 100644
index 0000000..3c204a9
--- /dev/null
+++ b/crates/claw-cli/src/args.rs
@@ -0,0 +1,104 @@
use std::path::PathBuf;

use clap::{Parser, Subcommand, ValueEnum};

/// Top-level command-line arguments for the `claw-cli` binary.
#[derive(Debug, Clone, Parser, PartialEq, Eq)]
#[command(name = "claw-cli", version, about = "Claw Code CLI")]
pub struct Cli {
    /// Model identifier used for conversation turns.
    #[arg(long, default_value = "claude-opus-4-6")]
    pub model: String,

    /// Permission level applied to tool execution.
    #[arg(long, value_enum, default_value_t = PermissionMode::DangerFullAccess)]
    pub permission_mode: PermissionMode,

    /// Optional path to a configuration file.
    // NOTE(review): reconstructed as `Option<PathBuf>` — the generic argument
    // was stripped in transit; the `--config /tmp/config.toml` test below
    // compares via `as_deref()` against a `Path`.
    #[arg(long)]
    pub config: Option<PathBuf>,

    /// Output encoding for turn summaries.
    #[arg(long, value_enum, default_value_t = OutputFormat::Text)]
    pub output_format: OutputFormat,

    // NOTE(review): reconstructed as `Option<Command>` — the generic argument
    // was stripped in transit; tests assert `Some(Command::Prompt { .. })`.
    #[command(subcommand)]
    pub command: Option<Command>,
}

/// Subcommands supported by the CLI.
#[derive(Debug, Clone, Subcommand, PartialEq, Eq)]
pub enum Command {
    /// Read upstream TS sources and print extracted counts
    DumpManifests,
    /// Print the current bootstrap phase skeleton
    BootstrapPlan,
    /// Start the OAuth login flow
    Login,
    /// Clear saved OAuth credentials
    Logout,
    /// Run a non-interactive prompt and
exit + Prompt { prompt: Vec }, +} + +#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)] +pub enum PermissionMode { + ReadOnly, + WorkspaceWrite, + DangerFullAccess, +} + +#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)] +pub enum OutputFormat { + Text, + Json, + Ndjson, +} + +#[cfg(test)] +mod tests { + use clap::Parser; + + use super::{Cli, Command, OutputFormat, PermissionMode}; + + #[test] + fn parses_requested_flags() { + let cli = Cli::parse_from([ + "claw-cli", + "--model", + "claude-haiku-4-5-20251213", + "--permission-mode", + "read-only", + "--config", + "/tmp/config.toml", + "--output-format", + "ndjson", + "prompt", + "hello", + "world", + ]); + + assert_eq!(cli.model, "claude-haiku-4-5-20251213"); + assert_eq!(cli.permission_mode, PermissionMode::ReadOnly); + assert_eq!( + cli.config.as_deref(), + Some(std::path::Path::new("/tmp/config.toml")) + ); + assert_eq!(cli.output_format, OutputFormat::Ndjson); + assert_eq!( + cli.command, + Some(Command::Prompt { + prompt: vec!["hello".into(), "world".into()] + }) + ); + } + + #[test] + fn parses_login_and_logout_commands() { + let login = Cli::parse_from(["claw-cli", "login"]); + assert_eq!(login.command, Some(Command::Login)); + + let logout = Cli::parse_from(["claw-cli", "logout"]); + assert_eq!(logout.command, Some(Command::Logout)); + } + + #[test] + fn defaults_to_danger_full_access_permission_mode() { + let cli = Cli::parse_from(["claw-cli"]); + assert_eq!(cli.permission_mode, PermissionMode::DangerFullAccess); + } +} diff --git a/crates/claw-cli/src/init.rs b/crates/claw-cli/src/init.rs new file mode 100644 index 0000000..f4db53a --- /dev/null +++ b/crates/claw-cli/src/init.rs @@ -0,0 +1,432 @@ +use std::fs; +use std::path::{Path, PathBuf}; + +const STARTER_CLAW_JSON: &str = concat!( + "{\n", + " \"permissions\": {\n", + " \"defaultMode\": \"dontAsk\"\n", + " }\n", + "}\n", +); +const GITIGNORE_COMMENT: &str = "# Claw Code local artifacts"; +const GITIGNORE_ENTRIES: [&str; 2] = 
[".claw/settings.local.json", ".claw/sessions/"]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum InitStatus { + Created, + Updated, + Skipped, +} + +impl InitStatus { + #[must_use] + pub(crate) fn label(self) -> &'static str { + match self { + Self::Created => "created", + Self::Updated => "updated", + Self::Skipped => "skipped (already exists)", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct InitArtifact { + pub(crate) name: &'static str, + pub(crate) status: InitStatus, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct InitReport { + pub(crate) project_root: PathBuf, + pub(crate) artifacts: Vec, +} + +impl InitReport { + #[must_use] + pub(crate) fn render(&self) -> String { + let mut lines = vec![ + "Init".to_string(), + format!(" Project {}", self.project_root.display()), + ]; + for artifact in &self.artifacts { + lines.push(format!( + " {:<16} {}", + artifact.name, + artifact.status.label() + )); + } + lines.push(" Next step Review and tailor the generated guidance".to_string()); + lines.join("\n") + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[allow(clippy::struct_excessive_bools)] +struct RepoDetection { + rust_workspace: bool, + rust_root: bool, + python: bool, + package_json: bool, + typescript: bool, + nextjs: bool, + react: bool, + vite: bool, + nest: bool, + src_dir: bool, + tests_dir: bool, + rust_dir: bool, +} + +pub(crate) fn initialize_repo(cwd: &Path) -> Result> { + let mut artifacts = Vec::new(); + + let claw_dir = cwd.join(".claw"); + artifacts.push(InitArtifact { + name: ".claw/", + status: ensure_dir(&claw_dir)?, + }); + + let claw_json = cwd.join(".claw.json"); + artifacts.push(InitArtifact { + name: ".claw.json", + status: write_file_if_missing(&claw_json, STARTER_CLAW_JSON)?, + }); + + let gitignore = cwd.join(".gitignore"); + artifacts.push(InitArtifact { + name: ".gitignore", + status: ensure_gitignore_entries(&gitignore)?, + }); + + let claw_md = cwd.join("CLAW.md"); + 
let content = render_init_claw_md(cwd);
    artifacts.push(InitArtifact {
        name: "CLAW.md",
        status: write_file_if_missing(&claw_md, &content)?,
    });

    Ok(InitReport {
        project_root: cwd.to_path_buf(),
        artifacts,
    })
}

/// Creates `path` as a directory unless it already exists.
// NOTE(review): error type reconstructed as `std::io::Error` — the whole
// generic list was stripped in transit and the body only propagates `fs`
// errors; confirm against the original source.
fn ensure_dir(path: &Path) -> Result<InitStatus, std::io::Error> {
    if path.is_dir() {
        return Ok(InitStatus::Skipped);
    }
    fs::create_dir_all(path)?;
    Ok(InitStatus::Created)
}

/// Writes `content` to `path` unless the file already exists.
fn write_file_if_missing(path: &Path, content: &str) -> Result<InitStatus, std::io::Error> {
    if path.exists() {
        return Ok(InitStatus::Skipped);
    }
    fs::write(path, content)?;
    Ok(InitStatus::Created)
}

/// Ensures the Claw comment and ignore entries are present in `.gitignore`,
/// creating the file when absent and only appending missing lines otherwise.
fn ensure_gitignore_entries(path: &Path) -> Result<InitStatus, std::io::Error> {
    if !path.exists() {
        let mut lines = vec![GITIGNORE_COMMENT.to_string()];
        lines.extend(GITIGNORE_ENTRIES.iter().map(|entry| (*entry).to_string()));
        fs::write(path, format!("{}\n", lines.join("\n")))?;
        return Ok(InitStatus::Created);
    }

    let existing = fs::read_to_string(path)?;
    // NOTE(review): turbofish reconstructed as `::<Vec<_>>` from the mangled
    // `collect::>()`.
    let mut lines = existing.lines().map(ToOwned::to_owned).collect::<Vec<_>>();
    let mut changed = false;

    if !lines.iter().any(|line| line == GITIGNORE_COMMENT) {
        lines.push(GITIGNORE_COMMENT.to_string());
        changed = true;
    }

    for entry in GITIGNORE_ENTRIES {
        if !lines.iter().any(|line| line == entry) {
            lines.push(entry.to_string());
            changed = true;
        }
    }

    if !changed {
        return Ok(InitStatus::Skipped);
    }

    fs::write(path, format!("{}\n", lines.join("\n")))?;
    Ok(InitStatus::Updated)
}

/// Renders the starter `CLAW.md` guidance from detected repo markers.
pub(crate) fn render_init_claw_md(cwd: &Path) -> String {
    let detection = detect_repo(cwd);
    let mut lines = vec![
        "# CLAW.md".to_string(),
        String::new(),
        "This file provides guidance to Claw Code (clawcode.dev) when working with code in this repository.".to_string(),
        String::new(),
    ];

    let detected_languages = detected_languages(&detection);
    let detected_frameworks = detected_frameworks(&detection);
    lines.push("## Detected stack".to_string());
    if detected_languages.is_empty() {
        lines.push("- No specific language markers were detected
yet; document the primary language and verification commands once the project structure settles.".to_string()); + } else { + lines.push(format!("- Languages: {}.", detected_languages.join(", "))); + } + if detected_frameworks.is_empty() { + lines.push("- Frameworks: none detected from the supported starter markers.".to_string()); + } else { + lines.push(format!( + "- Frameworks/tooling markers: {}.", + detected_frameworks.join(", ") + )); + } + lines.push(String::new()); + + let verification_lines = verification_lines(cwd, &detection); + if !verification_lines.is_empty() { + lines.push("## Verification".to_string()); + lines.extend(verification_lines); + lines.push(String::new()); + } + + let structure_lines = repository_shape_lines(&detection); + if !structure_lines.is_empty() { + lines.push("## Repository shape".to_string()); + lines.extend(structure_lines); + lines.push(String::new()); + } + + let framework_lines = framework_notes(&detection); + if !framework_lines.is_empty() { + lines.push("## Framework notes".to_string()); + lines.extend(framework_lines); + lines.push(String::new()); + } + + lines.push("## Working agreement".to_string()); + lines.push("- Prefer small, reviewable changes and keep generated bootstrap files aligned with actual repo workflows.".to_string()); + lines.push("- Keep shared defaults in `.claw.json`; reserve `.claw/settings.local.json` for machine-local overrides.".to_string()); + lines.push("- Do not overwrite existing `CLAW.md` content automatically; update it intentionally when repo workflows change.".to_string()); + lines.push(String::new()); + + lines.join("\n") +} + +fn detect_repo(cwd: &Path) -> RepoDetection { + let package_json_contents = fs::read_to_string(cwd.join("package.json")) + .unwrap_or_default() + .to_ascii_lowercase(); + RepoDetection { + rust_workspace: cwd.join("rust").join("Cargo.toml").is_file(), + rust_root: cwd.join("Cargo.toml").is_file(), + python: cwd.join("pyproject.toml").is_file() + || 
cwd.join("requirements.txt").is_file()
            || cwd.join("setup.py").is_file(),
        package_json: cwd.join("package.json").is_file(),
        typescript: cwd.join("tsconfig.json").is_file()
            || package_json_contents.contains("typescript"),
        nextjs: package_json_contents.contains("\"next\""),
        react: package_json_contents.contains("\"react\""),
        vite: package_json_contents.contains("\"vite\""),
        nest: package_json_contents.contains("@nestjs"),
        src_dir: cwd.join("src").is_dir(),
        tests_dir: cwd.join("tests").is_dir(),
        rust_dir: cwd.join("rust").is_dir(),
    }
}

/// Languages inferred from the detection markers, in stable order.
fn detected_languages(detection: &RepoDetection) -> Vec<&'static str> {
    let mut languages = Vec::new();
    if detection.rust_workspace || detection.rust_root {
        languages.push("Rust");
    }
    if detection.python {
        languages.push("Python");
    }
    if detection.typescript {
        languages.push("TypeScript");
    } else if detection.package_json {
        // Plain package.json with no TypeScript marker counts as JS/Node.
        languages.push("JavaScript/Node.js");
    }
    languages
}

/// Framework markers inferred from `package.json` contents.
fn detected_frameworks(detection: &RepoDetection) -> Vec<&'static str> {
    let mut frameworks = Vec::new();
    if detection.nextjs {
        frameworks.push("Next.js");
    }
    if detection.react {
        frameworks.push("React");
    }
    if detection.vite {
        frameworks.push("Vite");
    }
    if detection.nest {
        frameworks.push("NestJS");
    }
    frameworks
}

/// Verification bullet lines for the generated CLAW.md.
// NOTE(review): return type reconstructed as `Vec<String>` — the generic
// argument was stripped in transit; the body pushes `String`s.
fn verification_lines(cwd: &Path, detection: &RepoDetection) -> Vec<String> {
    let mut lines = Vec::new();
    if detection.rust_workspace {
        lines.push("- Run Rust verification from `rust/`: `cargo fmt`, `cargo clippy --workspace --all-targets -- -D warnings`, `cargo test --workspace`".to_string());
    } else if detection.rust_root {
        lines.push("- Run Rust verification from the repo root: `cargo fmt`, `cargo clippy --workspace --all-targets -- -D warnings`, `cargo test --workspace`".to_string());
    }
    if detection.python {
        if cwd.join("pyproject.toml").is_file() {
            lines.push("- Run the Python project checks declared in `pyproject.toml` (for example: `pytest`, `ruff
check`, and `mypy` when configured).".to_string()); + } else { + lines.push( + "- Run the repo's Python test/lint commands before shipping changes.".to_string(), + ); + } + } + if detection.package_json { + lines.push("- Run the JavaScript/TypeScript checks from `package.json` before shipping changes (`npm test`, `npm run lint`, `npm run build`, or the repo equivalent).".to_string()); + } + if detection.tests_dir && detection.src_dir { + lines.push("- `src/` and `tests/` are both present; update both surfaces together when behavior changes.".to_string()); + } + lines +} + +fn repository_shape_lines(detection: &RepoDetection) -> Vec { + let mut lines = Vec::new(); + if detection.rust_dir { + lines.push( + "- `rust/` contains the Rust workspace and active CLI/runtime implementation." + .to_string(), + ); + } + if detection.src_dir { + lines.push("- `src/` contains source files that should stay consistent with generated guidance and tests.".to_string()); + } + if detection.tests_dir { + lines.push("- `tests/` contains validation surfaces that should be reviewed alongside code changes.".to_string()); + } + lines +} + +fn framework_notes(detection: &RepoDetection) -> Vec { + let mut lines = Vec::new(); + if detection.nextjs { + lines.push("- Next.js detected: preserve routing/data-fetching conventions and verify production builds after changing app structure.".to_string()); + } + if detection.react && !detection.nextjs { + lines.push("- React detected: keep component behavior covered with focused tests and avoid unnecessary prop/API churn.".to_string()); + } + if detection.vite { + lines.push("- Vite detected: validate the production bundle after changing build-sensitive configuration or imports.".to_string()); + } + if detection.nest { + lines.push("- NestJS detected: keep module/provider boundaries explicit and verify controller/service wiring after refactors.".to_string()); + } + lines +} + +#[cfg(test)] +mod tests { + use super::{initialize_repo, 
render_init_claw_md}; + use std::fs; + use std::path::Path; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir() -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("claw-init-{nanos}")) + } + + #[test] + fn initialize_repo_creates_expected_files_and_gitignore_entries() { + let root = temp_dir(); + fs::create_dir_all(root.join("rust")).expect("create rust dir"); + fs::write(root.join("rust").join("Cargo.toml"), "[workspace]\n").expect("write cargo"); + + let report = initialize_repo(&root).expect("init should succeed"); + let rendered = report.render(); + assert!(rendered.contains(".claw/ created")); + assert!(rendered.contains(".claw.json created")); + assert!(rendered.contains(".gitignore created")); + assert!(rendered.contains("CLAW.md created")); + assert!(root.join(".claw").is_dir()); + assert!(root.join(".claw.json").is_file()); + assert!(root.join("CLAW.md").is_file()); + assert_eq!( + fs::read_to_string(root.join(".claw.json")).expect("read claw json"), + concat!( + "{\n", + " \"permissions\": {\n", + " \"defaultMode\": \"dontAsk\"\n", + " }\n", + "}\n", + ) + ); + let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore"); + assert!(gitignore.contains(".claw/settings.local.json")); + assert!(gitignore.contains(".claw/sessions/")); + let claw_md = fs::read_to_string(root.join("CLAW.md")).expect("read claw md"); + assert!(claw_md.contains("Languages: Rust.")); + assert!(claw_md.contains("cargo clippy --workspace --all-targets -- -D warnings")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn initialize_repo_is_idempotent_and_preserves_existing_files() { + let root = temp_dir(); + fs::create_dir_all(&root).expect("create root"); + fs::write(root.join("CLAW.md"), "custom guidance\n").expect("write existing claw md"); + fs::write(root.join(".gitignore"), 
".claw/settings.local.json\n").expect("write gitignore"); + + let first = initialize_repo(&root).expect("first init should succeed"); + assert!(first + .render() + .contains("CLAW.md skipped (already exists)")); + let second = initialize_repo(&root).expect("second init should succeed"); + let second_rendered = second.render(); + assert!(second_rendered.contains(".claw/ skipped (already exists)")); + assert!(second_rendered.contains(".claw.json skipped (already exists)")); + assert!(second_rendered.contains(".gitignore skipped (already exists)")); + assert!(second_rendered.contains("CLAW.md skipped (already exists)")); + assert_eq!( + fs::read_to_string(root.join("CLAW.md")).expect("read existing claw md"), + "custom guidance\n" + ); + let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore"); + assert_eq!(gitignore.matches(".claw/settings.local.json").count(), 1); + assert_eq!(gitignore.matches(".claw/sessions/").count(), 1); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn render_init_template_mentions_detected_python_and_nextjs_markers() { + let root = temp_dir(); + fs::create_dir_all(&root).expect("create root"); + fs::write(root.join("pyproject.toml"), "[project]\nname = \"demo\"\n") + .expect("write pyproject"); + fs::write( + root.join("package.json"), + r#"{"dependencies":{"next":"14.0.0","react":"18.0.0"},"devDependencies":{"typescript":"5.0.0"}}"#, + ) + .expect("write package json"); + + let rendered = render_init_claw_md(Path::new(&root)); + assert!(rendered.contains("Languages: Python, TypeScript.")); + assert!(rendered.contains("Frameworks/tooling markers: Next.js, React.")); + assert!(rendered.contains("pyproject.toml")); + assert!(rendered.contains("Next.js detected")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } +} diff --git a/crates/claw-cli/src/input.rs b/crates/claw-cli/src/input.rs new file mode 100644 index 0000000..a718cd7 --- /dev/null +++ b/crates/claw-cli/src/input.rs 
@@ -0,0 +1,1195 @@
use std::borrow::Cow;
use std::io::{self, IsTerminal, Write};

use crossterm::cursor::{MoveToColumn, MoveUp};
use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers};
use crossterm::queue;
use crossterm::terminal::{self, Clear, ClearType};

/// Result of one interactive read from the line editor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReadOutcome {
    Submit(String),
    Cancel,
    Exit,
}

/// Editing modes; `Plain` is the non-vim default, the rest mirror vim states.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EditorMode {
    Plain,
    Insert,
    Normal,
    Visual,
    Command,
}

impl EditorMode {
    /// Mode badge shown in the prompt, only when vim mode is enabled.
    fn indicator(self, vim_enabled: bool) -> Option<&'static str> {
        if !vim_enabled {
            return None;
        }

        Some(match self {
            Self::Plain => "PLAIN",
            Self::Insert => "INSERT",
            Self::Normal => "NORMAL",
            Self::Visual => "VISUAL",
            Self::Command => "COMMAND",
        })
    }
}

/// Last yanked text plus whether it was a linewise yank (`yy`/`dd`).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
struct YankBuffer {
    text: String,
    linewise: bool,
}

/// Mutable state for a single `read_line` invocation.
#[derive(Debug, Clone, PartialEq, Eq)]
struct EditSession {
    text: String,
    cursor: usize,
    mode: EditorMode,
    // NOTE(review): the generic arguments below were stripped in transit and
    // are reconstructed from usage: `pending_operator` holds `Some(ch)` for a
    // `char`, `visual_anchor` holds `Some(self.cursor)`, `history_index`
    // indexes a Vec, and `history_backup` stores a cloned text buffer.
    pending_operator: Option<char>,
    visual_anchor: Option<usize>,
    command_buffer: String,
    command_cursor: usize,
    history_index: Option<usize>,
    history_backup: Option<String>,
    rendered_cursor_row: usize,
    rendered_lines: usize,
}

impl EditSession {
    fn new(vim_enabled: bool) -> Self {
        Self {
            text: String::new(),
            cursor: 0,
            mode: if vim_enabled {
                EditorMode::Insert
            } else {
                EditorMode::Plain
            },
            pending_operator: None,
            visual_anchor: None,
            command_buffer: String::new(),
            command_cursor: 0,
            history_index: None,
            history_backup: None,
            rendered_cursor_row: 0,
            rendered_lines: 1,
        }
    }

    /// Text being edited: the `:` command buffer in Command mode, else input.
    fn active_text(&self) -> &str {
        if self.mode == EditorMode::Command {
            &self.command_buffer
        } else {
            &self.text
        }
    }

    fn current_len(&self) -> usize {
        self.active_text().len()
    }

    fn has_input(&self) -> bool {
        !self.active_text().is_empty()
    }

    fn current_line(&self) -> String {
        self.active_text().to_string()
    }

fn set_text_from_history(&mut self, entry: String) { + self.text = entry; + self.cursor = self.text.len(); + self.pending_operator = None; + self.visual_anchor = None; + if self.mode != EditorMode::Plain && self.mode != EditorMode::Insert { + self.mode = EditorMode::Normal; + } + } + + fn enter_insert_mode(&mut self) { + self.mode = EditorMode::Insert; + self.pending_operator = None; + self.visual_anchor = None; + } + + fn enter_normal_mode(&mut self) { + self.mode = EditorMode::Normal; + self.pending_operator = None; + self.visual_anchor = None; + } + + fn enter_visual_mode(&mut self) { + self.mode = EditorMode::Visual; + self.pending_operator = None; + self.visual_anchor = Some(self.cursor); + } + + fn enter_command_mode(&mut self) { + self.mode = EditorMode::Command; + self.pending_operator = None; + self.visual_anchor = None; + self.command_buffer.clear(); + self.command_buffer.push(':'); + self.command_cursor = self.command_buffer.len(); + } + + fn exit_command_mode(&mut self) { + self.command_buffer.clear(); + self.command_cursor = 0; + self.enter_normal_mode(); + } + + fn visible_buffer(&self) -> Cow<'_, str> { + if self.mode != EditorMode::Visual { + return Cow::Borrowed(self.active_text()); + } + + let Some(anchor) = self.visual_anchor else { + return Cow::Borrowed(self.active_text()); + }; + let Some((start, end)) = selection_bounds(&self.text, anchor, self.cursor) else { + return Cow::Borrowed(self.active_text()); + }; + + Cow::Owned(render_selected_text(&self.text, start, end)) + } + + fn prompt<'a>(&self, base_prompt: &'a str, vim_enabled: bool) -> Cow<'a, str> { + match self.mode.indicator(vim_enabled) { + Some(mode) => Cow::Owned(format!("[{mode}] {base_prompt}")), + None => Cow::Borrowed(base_prompt), + } + } + + fn clear_render(&self, out: &mut impl Write) -> io::Result<()> { + if self.rendered_cursor_row > 0 { + queue!(out, MoveUp(to_u16(self.rendered_cursor_row)?))?; + } + queue!(out, MoveToColumn(0), Clear(ClearType::FromCursorDown))?; + 
out.flush()
    }

    /// Redraws prompt + buffer, then repositions the terminal cursor to the
    /// logical cursor location, recording layout for the next clear.
    fn render(
        &mut self,
        out: &mut impl Write,
        base_prompt: &str,
        vim_enabled: bool,
    ) -> io::Result<()> {
        self.clear_render(out)?;

        let prompt = self.prompt(base_prompt, vim_enabled);
        let buffer = self.visible_buffer();
        write!(out, "{prompt}{buffer}")?;

        let (cursor_row, cursor_col, total_lines) = self.cursor_layout(prompt.as_ref());
        let rows_to_move_up = total_lines.saturating_sub(cursor_row + 1);
        if rows_to_move_up > 0 {
            queue!(out, MoveUp(to_u16(rows_to_move_up)?))?;
        }
        queue!(out, MoveToColumn(to_u16(cursor_col)?))?;
        out.flush()?;

        self.rendered_cursor_row = cursor_row;
        self.rendered_lines = total_lines;
        Ok(())
    }

    /// Final draw after submit/cancel: print the line and move to a new row.
    fn finalize_render(
        &self,
        out: &mut impl Write,
        base_prompt: &str,
        vim_enabled: bool,
    ) -> io::Result<()> {
        self.clear_render(out)?;
        let prompt = self.prompt(base_prompt, vim_enabled);
        let buffer = self.visible_buffer();
        write!(out, "{prompt}{buffer}")?;
        writeln!(out)
    }

    /// Computes (cursor row, cursor column, total rendered lines); the prompt
    /// width only offsets the column on the first rendered line.
    fn cursor_layout(&self, prompt: &str) -> (usize, usize, usize) {
        let active_text = self.active_text();
        let cursor = if self.mode == EditorMode::Command {
            self.command_cursor
        } else {
            self.cursor
        };

        let cursor_prefix = &active_text[..cursor];
        let cursor_row = cursor_prefix.bytes().filter(|byte| *byte == b'\n').count();
        let cursor_col = match cursor_prefix.rsplit_once('\n') {
            Some((_, suffix)) => suffix.chars().count(),
            None => prompt.chars().count() + cursor_prefix.chars().count(),
        };
        let total_lines = active_text.bytes().filter(|byte| *byte == b'\n').count() + 1;
        (cursor_row, cursor_col, total_lines)
    }
}

/// What the key handler asks the read loop to do next.
enum KeyAction {
    Continue,
    Submit(String),
    Cancel,
    Exit,
    ToggleVim,
}

/// Interactive line editor with history, slash-command completion and an
/// optional vim emulation layer.
pub struct LineEditor {
    prompt: String,
    // NOTE(review): element types reconstructed as `Vec<String>` /
    // `Option<CompletionState>` — the generic arguments were stripped in
    // transit; completions and history are compared against `String`s below.
    completions: Vec<String>,
    history: Vec<String>,
    yank_buffer: YankBuffer,
    vim_enabled: bool,
    completion_state: Option<CompletionState>,
}

/// Tab-completion cycle state for slash commands.
#[derive(Debug, Clone, PartialEq, Eq)]
struct CompletionState {
    prefix: String,
    matches: Vec<String>,
    next_index: usize,
}

impl LineEditor {
    /// Creates an editor with the given prompt and slash-command completions.
    // NOTE(review): `impl Into<String>` / `Vec<String>` reconstructed — the
    // generic arguments were stripped in transit; `prompt.into()` is stored
    // in a `String` field.
    #[must_use]
    pub fn new(prompt: impl Into<String>, completions: Vec<String>) -> Self {
        Self {
            prompt: prompt.into(),
            completions,
            history: Vec::new(),
            yank_buffer: YankBuffer::default(),
            vim_enabled: false,
            completion_state: None,
        }
    }

    /// Records a non-blank entry in the history ring.
    pub fn push_history(&mut self, entry: impl Into<String>) {
        let entry = entry.into();
        if entry.trim().is_empty() {
            return;
        }

        self.history.push(entry);
    }

    /// Reads one line interactively; falls back to plain stdin when either
    /// side of the terminal is not a TTY.
    // NOTE(review): return type reconstructed as `io::Result<ReadOutcome>`;
    // every return path produces a `ReadOutcome` variant.
    pub fn read_line(&mut self) -> io::Result<ReadOutcome> {
        if !io::stdin().is_terminal() || !io::stdout().is_terminal() {
            return self.read_line_fallback();
        }

        let _raw_mode = RawModeGuard::new()?;
        let mut stdout = io::stdout();
        let mut session = EditSession::new(self.vim_enabled);
        session.render(&mut stdout, &self.prompt, self.vim_enabled)?;

        loop {
            let Event::Key(key) = event::read()? else {
                continue;
            };
            if !matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat) {
                continue;
            }

            match self.handle_key_event(&mut session, key) {
                KeyAction::Continue => {
                    session.render(&mut stdout, &self.prompt, self.vim_enabled)?;
                }
                KeyAction::Submit(line) => {
                    session.finalize_render(&mut stdout, &self.prompt, self.vim_enabled)?;
                    return Ok(ReadOutcome::Submit(line));
                }
                KeyAction::Cancel => {
                    session.clear_render(&mut stdout)?;
                    writeln!(stdout)?;
                    return Ok(ReadOutcome::Cancel);
                }
                KeyAction::Exit => {
                    session.clear_render(&mut stdout)?;
                    writeln!(stdout)?;
                    return Ok(ReadOutcome::Exit);
                }
                KeyAction::ToggleVim => {
                    // `/vim` toggles modal editing and restarts the session.
                    session.clear_render(&mut stdout)?;
                    self.vim_enabled = !self.vim_enabled;
                    writeln!(
                        stdout,
                        "Vim mode {}.",
                        if self.vim_enabled {
                            "enabled"
                        } else {
                            "disabled"
                        }
                    )?;
                    session = EditSession::new(self.vim_enabled);
                    session.render(&mut stdout, &self.prompt, self.vim_enabled)?;
                }
            }
        }
    }

    /// Blocking stdin fallback used when not attached to a terminal.
    fn read_line_fallback(&mut self) -> io::Result<ReadOutcome> {
        loop {
            let mut stdout = io::stdout();
            write!(stdout, "{}", self.prompt)?;
            stdout.flush()?;

            let mut buffer =
String::new(); + let bytes_read = io::stdin().read_line(&mut buffer)?; + if bytes_read == 0 { + return Ok(ReadOutcome::Exit); + } + + while matches!(buffer.chars().last(), Some('\n' | '\r')) { + buffer.pop(); + } + + if self.handle_submission(&buffer) == Submission::ToggleVim { + self.vim_enabled = !self.vim_enabled; + writeln!( + stdout, + "Vim mode {}.", + if self.vim_enabled { + "enabled" + } else { + "disabled" + } + )?; + continue; + } + + return Ok(ReadOutcome::Submit(buffer)); + } + } + + fn handle_key_event(&mut self, session: &mut EditSession, key: KeyEvent) -> KeyAction { + if key.code != KeyCode::Tab { + self.completion_state = None; + } + + if key.modifiers.contains(KeyModifiers::CONTROL) { + match key.code { + KeyCode::Char('c') | KeyCode::Char('C') => { + return if session.has_input() { + KeyAction::Cancel + } else { + KeyAction::Exit + }; + } + KeyCode::Char('j') | KeyCode::Char('J') => { + if session.mode != EditorMode::Normal && session.mode != EditorMode::Visual { + self.insert_active_text(session, "\n"); + } + return KeyAction::Continue; + } + KeyCode::Char('d') | KeyCode::Char('D') => { + if session.current_len() == 0 { + return KeyAction::Exit; + } + self.delete_char_under_cursor(session); + return KeyAction::Continue; + } + _ => {} + } + } + + match key.code { + KeyCode::Enter if key.modifiers.contains(KeyModifiers::SHIFT) => { + if session.mode != EditorMode::Normal && session.mode != EditorMode::Visual { + self.insert_active_text(session, "\n"); + } + KeyAction::Continue + } + KeyCode::Enter => self.submit_or_toggle(session), + KeyCode::Esc => self.handle_escape(session), + KeyCode::Backspace => { + self.handle_backspace(session); + KeyAction::Continue + } + KeyCode::Delete => { + self.delete_char_under_cursor(session); + KeyAction::Continue + } + KeyCode::Left => { + self.move_left(session); + KeyAction::Continue + } + KeyCode::Right => { + self.move_right(session); + KeyAction::Continue + } + KeyCode::Up => { + self.history_up(session); + 
KeyAction::Continue + } + KeyCode::Down => { + self.history_down(session); + KeyAction::Continue + } + KeyCode::Home => { + self.move_line_start(session); + KeyAction::Continue + } + KeyCode::End => { + self.move_line_end(session); + KeyAction::Continue + } + KeyCode::Tab => { + self.complete_slash_command(session); + KeyAction::Continue + } + KeyCode::Char(ch) => { + self.handle_char(session, ch); + KeyAction::Continue + } + _ => KeyAction::Continue, + } + } + + fn handle_char(&mut self, session: &mut EditSession, ch: char) { + match session.mode { + EditorMode::Plain => self.insert_active_char(session, ch), + EditorMode::Insert => self.insert_active_char(session, ch), + EditorMode::Normal => self.handle_normal_char(session, ch), + EditorMode::Visual => self.handle_visual_char(session, ch), + EditorMode::Command => self.insert_active_char(session, ch), + } + } + + fn handle_normal_char(&mut self, session: &mut EditSession, ch: char) { + if let Some(operator) = session.pending_operator.take() { + match (operator, ch) { + ('d', 'd') => { + self.delete_current_line(session); + return; + } + ('y', 'y') => { + self.yank_current_line(session); + return; + } + _ => {} + } + } + + match ch { + 'h' => self.move_left(session), + 'j' => self.move_down(session), + 'k' => self.move_up(session), + 'l' => self.move_right(session), + 'd' | 'y' => session.pending_operator = Some(ch), + 'p' => self.paste_after(session), + 'i' => session.enter_insert_mode(), + 'v' => session.enter_visual_mode(), + ':' => session.enter_command_mode(), + _ => {} + } + } + + fn handle_visual_char(&mut self, session: &mut EditSession, ch: char) { + match ch { + 'h' => self.move_left(session), + 'j' => self.move_down(session), + 'k' => self.move_up(session), + 'l' => self.move_right(session), + 'v' => session.enter_normal_mode(), + _ => {} + } + } + + fn handle_escape(&mut self, session: &mut EditSession) -> KeyAction { + match session.mode { + EditorMode::Plain => KeyAction::Continue, + 
EditorMode::Insert => { + if session.cursor > 0 { + session.cursor = previous_boundary(&session.text, session.cursor); + } + session.enter_normal_mode(); + KeyAction::Continue + } + EditorMode::Normal => KeyAction::Continue, + EditorMode::Visual => { + session.enter_normal_mode(); + KeyAction::Continue + } + EditorMode::Command => { + session.exit_command_mode(); + KeyAction::Continue + } + } + } + + fn handle_backspace(&mut self, session: &mut EditSession) { + match session.mode { + EditorMode::Normal | EditorMode::Visual => self.move_left(session), + EditorMode::Command => { + if session.command_cursor <= 1 { + session.exit_command_mode(); + } else { + remove_previous_char(&mut session.command_buffer, &mut session.command_cursor); + } + } + EditorMode::Plain | EditorMode::Insert => { + remove_previous_char(&mut session.text, &mut session.cursor); + } + } + } + + fn submit_or_toggle(&mut self, session: &EditSession) -> KeyAction { + let line = session.current_line(); + match self.handle_submission(&line) { + Submission::Submit => KeyAction::Submit(line), + Submission::ToggleVim => KeyAction::ToggleVim, + } + } + + fn handle_submission(&mut self, line: &str) -> Submission { + if line.trim() == "/vim" { + Submission::ToggleVim + } else { + Submission::Submit + } + } + + fn insert_active_char(&mut self, session: &mut EditSession, ch: char) { + let mut buffer = [0; 4]; + self.insert_active_text(session, ch.encode_utf8(&mut buffer)); + } + + fn insert_active_text(&mut self, session: &mut EditSession, text: &str) { + if session.mode == EditorMode::Command { + session + .command_buffer + .insert_str(session.command_cursor, text); + session.command_cursor += text.len(); + } else { + session.text.insert_str(session.cursor, text); + session.cursor += text.len(); + } + } + + fn move_left(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = + previous_command_boundary(&session.command_buffer, session.command_cursor); + } 
else { + session.cursor = previous_boundary(&session.text, session.cursor); + } + } + + fn move_right(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = next_boundary(&session.command_buffer, session.command_cursor); + } else { + session.cursor = next_boundary(&session.text, session.cursor); + } + } + + fn move_line_start(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = 1; + } else { + session.cursor = line_start(&session.text, session.cursor); + } + } + + fn move_line_end(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = session.command_buffer.len(); + } else { + session.cursor = line_end(&session.text, session.cursor); + } + } + + fn move_up(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + return; + } + session.cursor = move_vertical(&session.text, session.cursor, -1); + } + + fn move_down(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + return; + } + session.cursor = move_vertical(&session.text, session.cursor, 1); + } + + fn delete_char_under_cursor(&self, session: &mut EditSession) { + match session.mode { + EditorMode::Command => { + if session.command_cursor < session.command_buffer.len() { + let end = next_boundary(&session.command_buffer, session.command_cursor); + session.command_buffer.drain(session.command_cursor..end); + } + } + _ => { + if session.cursor < session.text.len() { + let end = next_boundary(&session.text, session.cursor); + session.text.drain(session.cursor..end); + } + } + } + } + + fn delete_current_line(&mut self, session: &mut EditSession) { + let (line_start_idx, line_end_idx, delete_start_idx) = + current_line_delete_range(&session.text, session.cursor); + self.yank_buffer.text = session.text[line_start_idx..line_end_idx].to_string(); + self.yank_buffer.linewise = true; + 
session.text.drain(delete_start_idx..line_end_idx); + session.cursor = delete_start_idx.min(session.text.len()); + } + + fn yank_current_line(&mut self, session: &mut EditSession) { + let (line_start_idx, line_end_idx, _) = + current_line_delete_range(&session.text, session.cursor); + self.yank_buffer.text = session.text[line_start_idx..line_end_idx].to_string(); + self.yank_buffer.linewise = true; + } + + fn paste_after(&mut self, session: &mut EditSession) { + if self.yank_buffer.text.is_empty() { + return; + } + + if self.yank_buffer.linewise { + let line_end_idx = line_end(&session.text, session.cursor); + let insert_at = if line_end_idx < session.text.len() { + line_end_idx + 1 + } else { + session.text.len() + }; + let mut insertion = self.yank_buffer.text.clone(); + if insert_at == session.text.len() + && !session.text.is_empty() + && !session.text.ends_with('\n') + { + insertion.insert(0, '\n'); + } + if insert_at < session.text.len() && !insertion.ends_with('\n') { + insertion.push('\n'); + } + session.text.insert_str(insert_at, &insertion); + session.cursor = if insertion.starts_with('\n') { + insert_at + 1 + } else { + insert_at + }; + return; + } + + let insert_at = next_boundary(&session.text, session.cursor); + session.text.insert_str(insert_at, &self.yank_buffer.text); + session.cursor = insert_at + self.yank_buffer.text.len(); + } + + fn complete_slash_command(&mut self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + self.completion_state = None; + return; + } + if let Some(state) = self + .completion_state + .as_mut() + .filter(|_| session.cursor == session.text.len()) + .filter(|state| { + state + .matches + .iter() + .any(|candidate| candidate == &session.text) + }) + { + let candidate = state.matches[state.next_index % state.matches.len()].clone(); + state.next_index += 1; + session.text.replace_range(..session.cursor, &candidate); + session.cursor = candidate.len(); + return; + } + let Some(prefix) = 
slash_command_prefix(&session.text, session.cursor) else { + self.completion_state = None; + return; + }; + let matches = self + .completions + .iter() + .filter(|candidate| candidate.starts_with(prefix) && candidate.as_str() != prefix) + .cloned() + .collect::>(); + if matches.is_empty() { + self.completion_state = None; + return; + } + + let candidate = if let Some(state) = self + .completion_state + .as_mut() + .filter(|state| state.prefix == prefix && state.matches == matches) + { + let index = state.next_index % state.matches.len(); + state.next_index += 1; + state.matches[index].clone() + } else { + let candidate = matches[0].clone(); + self.completion_state = Some(CompletionState { + prefix: prefix.to_string(), + matches, + next_index: 1, + }); + candidate + }; + + session.text.replace_range(..session.cursor, &candidate); + session.cursor = candidate.len(); + } + + fn history_up(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command || self.history.is_empty() { + return; + } + + let next_index = match session.history_index { + Some(index) => index.saturating_sub(1), + None => { + session.history_backup = Some(session.text.clone()); + self.history.len() - 1 + } + }; + + session.history_index = Some(next_index); + session.set_text_from_history(self.history[next_index].clone()); + } + + fn history_down(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + return; + } + + let Some(index) = session.history_index else { + return; + }; + + if index + 1 < self.history.len() { + let next_index = index + 1; + session.history_index = Some(next_index); + session.set_text_from_history(self.history[next_index].clone()); + return; + } + + session.history_index = None; + let restored = session.history_backup.take().unwrap_or_default(); + session.set_text_from_history(restored); + if self.vim_enabled { + session.enter_insert_mode(); + } else { + session.mode = EditorMode::Plain; + } + } +} + +#[derive(Debug, Clone, Copy, 
PartialEq, Eq)] +enum Submission { + Submit, + ToggleVim, +} + +struct RawModeGuard; + +impl RawModeGuard { + fn new() -> io::Result { + terminal::enable_raw_mode().map_err(io::Error::other)?; + Ok(Self) + } +} + +impl Drop for RawModeGuard { + fn drop(&mut self) { + let _ = terminal::disable_raw_mode(); + } +} + +fn previous_boundary(text: &str, cursor: usize) -> usize { + if cursor == 0 { + return 0; + } + + text[..cursor] + .char_indices() + .next_back() + .map_or(0, |(index, _)| index) +} + +fn previous_command_boundary(text: &str, cursor: usize) -> usize { + previous_boundary(text, cursor).max(1) +} + +fn next_boundary(text: &str, cursor: usize) -> usize { + if cursor >= text.len() { + return text.len(); + } + + text[cursor..] + .chars() + .next() + .map_or(text.len(), |ch| cursor + ch.len_utf8()) +} + +fn remove_previous_char(text: &mut String, cursor: &mut usize) { + if *cursor == 0 { + return; + } + + let start = previous_boundary(text, *cursor); + text.drain(start..*cursor); + *cursor = start; +} + +fn line_start(text: &str, cursor: usize) -> usize { + text[..cursor].rfind('\n').map_or(0, |index| index + 1) +} + +fn line_end(text: &str, cursor: usize) -> usize { + text[cursor..] 
+ .find('\n') + .map_or(text.len(), |index| cursor + index) +} + +fn move_vertical(text: &str, cursor: usize, delta: isize) -> usize { + let starts = line_starts(text); + let current_row = text[..cursor].bytes().filter(|byte| *byte == b'\n').count(); + let current_start = starts[current_row]; + let current_col = text[current_start..cursor].chars().count(); + + let max_row = starts.len().saturating_sub(1) as isize; + let target_row = (current_row as isize + delta).clamp(0, max_row) as usize; + if target_row == current_row { + return cursor; + } + + let target_start = starts[target_row]; + let target_end = if target_row + 1 < starts.len() { + starts[target_row + 1] - 1 + } else { + text.len() + }; + byte_index_for_char_column(&text[target_start..target_end], current_col) + target_start +} + +fn line_starts(text: &str) -> Vec { + let mut starts = vec![0]; + for (index, ch) in text.char_indices() { + if ch == '\n' { + starts.push(index + 1); + } + } + starts +} + +fn byte_index_for_char_column(text: &str, column: usize) -> usize { + let mut current = 0; + for (index, _) in text.char_indices() { + if current == column { + return index; + } + current += 1; + } + text.len() +} + +fn current_line_delete_range(text: &str, cursor: usize) -> (usize, usize, usize) { + let line_start_idx = line_start(text, cursor); + let line_end_core = line_end(text, cursor); + let line_end_idx = if line_end_core < text.len() { + line_end_core + 1 + } else { + line_end_core + }; + let delete_start_idx = if line_end_idx == text.len() && line_start_idx > 0 { + line_start_idx - 1 + } else { + line_start_idx + }; + (line_start_idx, line_end_idx, delete_start_idx) +} + +fn selection_bounds(text: &str, anchor: usize, cursor: usize) -> Option<(usize, usize)> { + if text.is_empty() { + return None; + } + + if cursor >= anchor { + let end = next_boundary(text, cursor); + Some((anchor.min(text.len()), end.min(text.len()))) + } else { + let end = next_boundary(text, anchor); + 
Some((cursor.min(text.len()), end.min(text.len()))) + } +} + +fn render_selected_text(text: &str, start: usize, end: usize) -> String { + let mut rendered = String::new(); + let mut in_selection = false; + + for (index, ch) in text.char_indices() { + if !in_selection && index == start { + rendered.push_str("\x1b[7m"); + in_selection = true; + } + if in_selection && index == end { + rendered.push_str("\x1b[0m"); + in_selection = false; + } + rendered.push(ch); + } + + if in_selection { + rendered.push_str("\x1b[0m"); + } + + rendered +} + +fn slash_command_prefix(line: &str, pos: usize) -> Option<&str> { + if pos != line.len() { + return None; + } + + let prefix = &line[..pos]; + if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') { + return None; + } + + Some(prefix) +} + +fn to_u16(value: usize) -> io::Result { + u16::try_from(value).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "terminal position overflowed u16", + ) + }) +} + +#[cfg(test)] +mod tests { + use super::{ + selection_bounds, slash_command_prefix, EditSession, EditorMode, KeyAction, LineEditor, + }; + use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + #[test] + fn extracts_only_terminal_slash_command_prefixes() { + // given + let complete_prefix = slash_command_prefix("/he", 3); + let whitespace_prefix = slash_command_prefix("/help me", 5); + let plain_text_prefix = slash_command_prefix("hello", 5); + let mid_buffer_prefix = slash_command_prefix("/help", 2); + + // when + let result = ( + complete_prefix, + whitespace_prefix, + plain_text_prefix, + mid_buffer_prefix, + ); + + // then + assert_eq!(result, (Some("/he"), None, None, None)); + } + + #[test] + fn toggle_submission_flips_vim_mode() { + // given + let mut editor = LineEditor::new("> ", vec!["/help".to_string(), "/vim".to_string()]); + + // when + let first = editor.handle_submission("/vim"); + editor.vim_enabled = true; + let second = editor.handle_submission("/vim"); + + // then + 
assert!(matches!(first, super::Submission::ToggleVim)); + assert!(matches!(second, super::Submission::ToggleVim)); + } + + #[test] + fn normal_mode_supports_motion_and_insert_transition() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "hello".to_string(); + session.cursor = session.text.len(); + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'h'); + editor.handle_char(&mut session, 'i'); + editor.handle_char(&mut session, '!'); + + // then + assert_eq!(session.mode, EditorMode::Insert); + assert_eq!(session.text, "hel!lo"); + } + + #[test] + fn yy_and_p_paste_yanked_line_after_current_line() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "alpha\nbeta\ngamma".to_string(); + session.cursor = 0; + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'y'); + editor.handle_char(&mut session, 'y'); + editor.handle_char(&mut session, 'p'); + + // then + assert_eq!(session.text, "alpha\nalpha\nbeta\ngamma"); + } + + #[test] + fn dd_and_p_paste_deleted_line_after_current_line() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "alpha\nbeta\ngamma".to_string(); + session.cursor = 0; + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'j'); + editor.handle_char(&mut session, 'd'); + editor.handle_char(&mut session, 'd'); + editor.handle_char(&mut session, 'p'); + + // then + assert_eq!(session.text, "alpha\ngamma\nbeta\n"); + } + + #[test] + fn visual_mode_tracks_selection_with_motions() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = 
"alpha\nbeta".to_string(); + session.cursor = 0; + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'v'); + editor.handle_char(&mut session, 'j'); + editor.handle_char(&mut session, 'l'); + + // then + assert_eq!(session.mode, EditorMode::Visual); + assert_eq!( + selection_bounds( + &session.text, + session.visual_anchor.unwrap_or(0), + session.cursor + ), + Some((0, 8)) + ); + } + + #[test] + fn command_mode_submits_colon_prefixed_input() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "draft".to_string(); + session.cursor = session.text.len(); + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, ':'); + editor.handle_char(&mut session, 'q'); + editor.handle_char(&mut session, '!'); + let action = editor.submit_or_toggle(&session); + + // then + assert_eq!(session.mode, EditorMode::Command); + assert_eq!(session.command_buffer, ":q!"); + assert!(matches!(action, KeyAction::Submit(line) if line == ":q!")); + } + + #[test] + fn push_history_ignores_blank_entries() { + // given + let mut editor = LineEditor::new("> ", vec!["/help".to_string()]); + + // when + editor.push_history(" "); + editor.push_history("/help"); + + // then + assert_eq!(editor.history, vec!["/help".to_string()]); + } + + #[test] + fn tab_completes_matching_slash_commands() { + // given + let mut editor = LineEditor::new("> ", vec!["/help".to_string(), "/hello".to_string()]); + let mut session = EditSession::new(false); + session.text = "/he".to_string(); + session.cursor = session.text.len(); + + // when + editor.complete_slash_command(&mut session); + + // then + assert_eq!(session.text, "/help"); + assert_eq!(session.cursor, 5); + } + + #[test] + fn tab_cycles_between_matching_slash_commands() { + // given + let mut editor = LineEditor::new( + "> ", + vec!["/permissions".to_string(), "/plugin".to_string()], + 
); + let mut session = EditSession::new(false); + session.text = "/p".to_string(); + session.cursor = session.text.len(); + + // when + editor.complete_slash_command(&mut session); + let first = session.text.clone(); + session.cursor = session.text.len(); + editor.complete_slash_command(&mut session); + let second = session.text.clone(); + + // then + assert_eq!(first, "/permissions"); + assert_eq!(second, "/plugin"); + } + + #[test] + fn ctrl_c_cancels_when_input_exists() { + // given + let mut editor = LineEditor::new("> ", vec![]); + let mut session = EditSession::new(false); + session.text = "draft".to_string(); + session.cursor = session.text.len(); + + // when + let action = editor.handle_key_event( + &mut session, + KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL), + ); + + // then + assert!(matches!(action, KeyAction::Cancel)); + } +} diff --git a/crates/claw-cli/src/main.rs b/crates/claw-cli/src/main.rs new file mode 100644 index 0000000..2b7d6f1 --- /dev/null +++ b/crates/claw-cli/src/main.rs @@ -0,0 +1,5090 @@ +mod init; +mod input; +mod render; + +use std::collections::BTreeSet; +use std::env; +use std::fmt::Write as _; +use std::fs; +use std::io::{self, IsTerminal, Read, Write}; +use std::net::TcpListener; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::sync::mpsc::{self, RecvTimeoutError}; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use api::{ + resolve_startup_auth_source, AuthSource, ClawApiClient, ContentBlockDelta, InputContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, +}; + +use commands::{ + handle_agents_slash_command, handle_plugins_slash_command, handle_skills_slash_command, + render_slash_command_help, resume_supported_slash_commands, slash_command_specs, + suggest_slash_commands, SlashCommand, +}; +use 
compat_harness::{extract_manifest, UpstreamPaths}; +use init::initialize_repo; +use plugins::{PluginManager, PluginManagerConfig}; +use render::{MarkdownStreamState, Spinner, TerminalRenderer}; +use runtime::{ + clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt, + parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, + AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, + ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, + OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, + Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, +}; +use serde_json::json; +use tools::GlobalToolRegistry; + +const DEFAULT_MODEL: &str = "claude-opus-4-6"; +fn max_tokens_for_model(model: &str) -> u32 { + if model.contains("opus") { + 32_000 + } else { + 64_000 + } +} +const DEFAULT_DATE: &str = "2026-03-31"; +const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; +const VERSION: &str = env!("CARGO_PKG_VERSION"); +const BUILD_TARGET: Option<&str> = option_env!("TARGET"); +const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); +const INTERNAL_PROGRESS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3); + +type AllowedToolSet = BTreeSet; + +fn main() { + if let Err(error) = run() { + eprintln!("{}", render_cli_error(&error.to_string())); + std::process::exit(1); + } +} + +fn render_cli_error(problem: &str) -> String { + let mut lines = vec!["Error".to_string()]; + for (index, line) in problem.lines().enumerate() { + let label = if index == 0 { + " Problem " + } else { + " " + }; + lines.push(format!("{label}{line}")); + } + lines.push(" Help claw --help".to_string()); + lines.join("\n") +} + +fn run() -> Result<(), Box> { + let args: Vec = env::args().skip(1).collect(); + match parse_args(&args)? 
{ + CliAction::DumpManifests => dump_manifests(), + CliAction::BootstrapPlan => print_bootstrap_plan(), + CliAction::Agents { args } => LiveCli::print_agents(args.as_deref())?, + CliAction::Skills { args } => LiveCli::print_skills(args.as_deref())?, + CliAction::PrintSystemPrompt { cwd, date } => print_system_prompt(cwd, date), + CliAction::Version => print_version(), + CliAction::ResumeSession { + session_path, + commands, + } => resume_session(&session_path, &commands), + CliAction::Prompt { + prompt, + model, + output_format, + allowed_tools, + permission_mode, + } => LiveCli::new(model, true, allowed_tools, permission_mode)? + .run_turn_with_output(&prompt, output_format)?, + CliAction::Login => run_login()?, + CliAction::Logout => run_logout()?, + CliAction::Init => run_init()?, + CliAction::Repl { + model, + allowed_tools, + permission_mode, + } => run_repl(model, allowed_tools, permission_mode)?, + CliAction::Help => print_help(), + } + Ok(()) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum CliAction { + DumpManifests, + BootstrapPlan, + Agents { + args: Option, + }, + Skills { + args: Option, + }, + PrintSystemPrompt { + cwd: PathBuf, + date: String, + }, + Version, + ResumeSession { + session_path: PathBuf, + commands: Vec, + }, + Prompt { + prompt: String, + model: String, + output_format: CliOutputFormat, + allowed_tools: Option, + permission_mode: PermissionMode, + }, + Login, + Logout, + Init, + Repl { + model: String, + allowed_tools: Option, + permission_mode: PermissionMode, + }, + // prompt-mode formatting is only supported for non-interactive runs + Help, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CliOutputFormat { + Text, + Json, +} + +impl CliOutputFormat { + fn parse(value: &str) -> Result { + match value { + "text" => Ok(Self::Text), + "json" => Ok(Self::Json), + other => Err(format!( + "unsupported value for --output-format: {other} (expected text or json)" + )), + } + } +} + +#[allow(clippy::too_many_lines)] +fn 
parse_args(args: &[String]) -> Result { + let mut model = DEFAULT_MODEL.to_string(); + let mut output_format = CliOutputFormat::Text; + let mut permission_mode = default_permission_mode(); + let mut wants_version = false; + let mut allowed_tool_values = Vec::new(); + let mut rest = Vec::new(); + let mut index = 0; + + while index < args.len() { + match args[index].as_str() { + "--version" | "-V" => { + wants_version = true; + index += 1; + } + "--model" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --model".to_string())?; + model = resolve_model_alias(value).to_string(); + index += 2; + } + flag if flag.starts_with("--model=") => { + model = resolve_model_alias(&flag[8..]).to_string(); + index += 1; + } + "--output-format" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --output-format".to_string())?; + output_format = CliOutputFormat::parse(value)?; + index += 2; + } + "--permission-mode" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --permission-mode".to_string())?; + permission_mode = parse_permission_mode_arg(value)?; + index += 2; + } + flag if flag.starts_with("--output-format=") => { + output_format = CliOutputFormat::parse(&flag[16..])?; + index += 1; + } + flag if flag.starts_with("--permission-mode=") => { + permission_mode = parse_permission_mode_arg(&flag[18..])?; + index += 1; + } + "--dangerously-skip-permissions" => { + permission_mode = PermissionMode::DangerFullAccess; + index += 1; + } + "-p" => { + // Claw Code compat: -p "prompt" = one-shot prompt + let prompt = args[index + 1..].join(" "); + if prompt.trim().is_empty() { + return Err("-p requires a prompt string".to_string()); + } + return Ok(CliAction::Prompt { + prompt, + model: resolve_model_alias(&model).to_string(), + output_format, + allowed_tools: normalize_allowed_tools(&allowed_tool_values)?, + permission_mode, + }); + } + "--print" => { + // Claw Code compat: --print makes output 
non-interactive + output_format = CliOutputFormat::Text; + index += 1; + } + "--allowedTools" | "--allowed-tools" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --allowedTools".to_string())?; + allowed_tool_values.push(value.clone()); + index += 2; + } + flag if flag.starts_with("--allowedTools=") => { + allowed_tool_values.push(flag[15..].to_string()); + index += 1; + } + flag if flag.starts_with("--allowed-tools=") => { + allowed_tool_values.push(flag[16..].to_string()); + index += 1; + } + other => { + rest.push(other.to_string()); + index += 1; + } + } + } + + if wants_version { + return Ok(CliAction::Version); + } + + let allowed_tools = normalize_allowed_tools(&allowed_tool_values)?; + + if rest.is_empty() { + return Ok(CliAction::Repl { + model, + allowed_tools, + permission_mode, + }); + } + if matches!(rest.first().map(String::as_str), Some("--help" | "-h")) { + return Ok(CliAction::Help); + } + if rest.first().map(String::as_str) == Some("--resume") { + return parse_resume_args(&rest[1..]); + } + + match rest[0].as_str() { + "dump-manifests" => Ok(CliAction::DumpManifests), + "bootstrap-plan" => Ok(CliAction::BootstrapPlan), + "agents" => Ok(CliAction::Agents { + args: join_optional_args(&rest[1..]), + }), + "skills" => Ok(CliAction::Skills { + args: join_optional_args(&rest[1..]), + }), + "system-prompt" => parse_system_prompt_args(&rest[1..]), + "login" => Ok(CliAction::Login), + "logout" => Ok(CliAction::Logout), + "init" => Ok(CliAction::Init), + "prompt" => { + let prompt = rest[1..].join(" "); + if prompt.trim().is_empty() { + return Err("prompt subcommand requires a prompt string".to_string()); + } + Ok(CliAction::Prompt { + prompt, + model, + output_format, + allowed_tools, + permission_mode, + }) + } + other if other.starts_with('/') => parse_direct_slash_cli_action(&rest), + _other => Ok(CliAction::Prompt { + prompt: rest.join(" "), + model, + output_format, + allowed_tools, + permission_mode, + }), + } +} + +fn 
join_optional_args(args: &[String]) -> Option { + let joined = args.join(" "); + let trimmed = joined.trim(); + (!trimmed.is_empty()).then(|| trimmed.to_string()) +} + +fn parse_direct_slash_cli_action(rest: &[String]) -> Result { + let raw = rest.join(" "); + match SlashCommand::parse(&raw) { + Some(SlashCommand::Help) => Ok(CliAction::Help), + Some(SlashCommand::Agents { args }) => Ok(CliAction::Agents { args }), + Some(SlashCommand::Skills { args }) => Ok(CliAction::Skills { args }), + Some(command) => Err(format_direct_slash_command_error( + match &command { + SlashCommand::Unknown(name) => format!("/{name}"), + _ => rest[0].clone(), + } + .as_str(), + matches!(command, SlashCommand::Unknown(_)), + )), + None => Err(format!("unknown subcommand: {}", rest[0])), + } +} + +fn format_direct_slash_command_error(command: &str, is_unknown: bool) -> String { + let trimmed = command.trim().trim_start_matches('/'); + let mut lines = vec![ + "Direct slash command unavailable".to_string(), + format!(" Command /{trimmed}"), + ]; + if is_unknown { + append_slash_command_suggestions(&mut lines, trimmed); + } else { + lines.push(" Try Start `claw` to use interactive slash commands".to_string()); + lines.push( + " Tip Resume-safe commands also work with `claw --resume SESSION.json ...`" + .to_string(), + ); + } + lines.join("\n") +} + +fn resolve_model_alias(model: &str) -> &str { + match model { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => model, + } +} + +fn normalize_allowed_tools(values: &[String]) -> Result, String> { + current_tool_registry()?.normalize_allowed_tools(values) +} + +fn current_tool_registry() -> Result { + let cwd = env::current_dir().map_err(|error| error.to_string())?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load().map_err(|error| error.to_string())?; + let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let plugin_tools = 
plugin_manager + .aggregated_tools() + .map_err(|error| error.to_string())?; + GlobalToolRegistry::with_plugin_tools(plugin_tools) +} + +fn parse_permission_mode_arg(value: &str) -> Result { + normalize_permission_mode(value) + .ok_or_else(|| { + format!( + "unsupported permission mode '{value}'. Use read-only, workspace-write, or danger-full-access." + ) + }) + .map(permission_mode_from_label) +} + +fn permission_mode_from_label(mode: &str) -> PermissionMode { + match mode { + "read-only" => PermissionMode::ReadOnly, + "workspace-write" => PermissionMode::WorkspaceWrite, + "danger-full-access" => PermissionMode::DangerFullAccess, + other => panic!("unsupported permission mode label: {other}"), + } +} + +fn default_permission_mode() -> PermissionMode { + env::var("CLAW_PERMISSION_MODE") + .ok() + .as_deref() + .and_then(normalize_permission_mode) + .map_or(PermissionMode::DangerFullAccess, permission_mode_from_label) +} + +fn filter_tool_specs( + tool_registry: &GlobalToolRegistry, + allowed_tools: Option<&AllowedToolSet>, +) -> Vec { + tool_registry.definitions(allowed_tools) +} + +fn parse_system_prompt_args(args: &[String]) -> Result { + let mut cwd = env::current_dir().map_err(|error| error.to_string())?; + let mut date = DEFAULT_DATE.to_string(); + let mut index = 0; + + while index < args.len() { + match args[index].as_str() { + "--cwd" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --cwd".to_string())?; + cwd = PathBuf::from(value); + index += 2; + } + "--date" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --date".to_string())?; + date.clone_from(value); + index += 2; + } + other => return Err(format!("unknown system-prompt option: {other}")), + } + } + + Ok(CliAction::PrintSystemPrompt { cwd, date }) +} + +fn parse_resume_args(args: &[String]) -> Result { + let session_path = args + .first() + .ok_or_else(|| "missing session path for --resume".to_string()) + .map(PathBuf::from)?; + let commands 
= args[1..].to_vec(); + if commands + .iter() + .any(|command| !command.trim_start().starts_with('/')) + { + return Err("--resume trailing arguments must be slash commands".to_string()); + } + Ok(CliAction::ResumeSession { + session_path, + commands, + }) +} + +fn dump_manifests() { + let workspace_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../.."); + let paths = UpstreamPaths::from_workspace_dir(&workspace_dir); + match extract_manifest(&paths) { + Ok(manifest) => { + println!("commands: {}", manifest.commands.entries().len()); + println!("tools: {}", manifest.tools.entries().len()); + println!("bootstrap phases: {}", manifest.bootstrap.phases().len()); + } + Err(error) => { + eprintln!("failed to extract manifests: {error}"); + std::process::exit(1); + } + } +} + +fn print_bootstrap_plan() { + for phase in runtime::BootstrapPlan::claw_default().phases() { + println!("- {phase:?}"); + } +} + +fn default_oauth_config() -> OAuthConfig { + OAuthConfig { + client_id: String::from("9d1c250a-e61b-44d9-88ed-5944d1962f5e"), + authorize_url: String::from("https://platform.claw.dev/oauth/authorize"), + token_url: String::from("https://platform.claw.dev/v1/oauth/token"), + callback_port: None, + manual_redirect_url: None, + scopes: vec![ + String::from("user:profile"), + String::from("user:inference"), + String::from("user:sessions:claw_code"), + ], + } +} + +fn run_login() -> Result<(), Box> { + let cwd = env::current_dir()?; + let config = ConfigLoader::default_for(&cwd).load()?; + let default_oauth = default_oauth_config(); + let oauth = config.oauth().unwrap_or(&default_oauth); + let callback_port = oauth.callback_port.unwrap_or(DEFAULT_OAUTH_CALLBACK_PORT); + let redirect_uri = runtime::loopback_redirect_uri(callback_port); + let pkce = generate_pkce_pair()?; + let state = generate_state()?; + let authorize_url = + OAuthAuthorizationRequest::from_config(oauth, redirect_uri.clone(), state.clone(), &pkce) + .build_url(); + + println!("Starting Claw OAuth 
login..."); + println!("Listening for callback on {redirect_uri}"); + if let Err(error) = open_browser(&authorize_url) { + eprintln!("warning: failed to open browser automatically: {error}"); + println!("Open this URL manually:\n{authorize_url}"); + } + + let callback = wait_for_oauth_callback(callback_port)?; + if let Some(error) = callback.error { + let description = callback + .error_description + .unwrap_or_else(|| "authorization failed".to_string()); + return Err(io::Error::other(format!("{error}: {description}")).into()); + } + let code = callback.code.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "callback did not include code") + })?; + let returned_state = callback.state.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "callback did not include state") + })?; + if returned_state != state { + return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into()); + } + + let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); + let exchange_request = + OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri); + let runtime = tokio::runtime::Runtime::new()?; + let token_set = runtime.block_on(client.exchange_oauth_code(oauth, &exchange_request))?; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })?; + println!("Claw OAuth login complete."); + Ok(()) +} + +fn run_logout() -> Result<(), Box> { + clear_oauth_credentials()?; + println!("Claw OAuth credentials cleared."); + Ok(()) +} + +fn open_browser(url: &str) -> io::Result<()> { + let commands = if cfg!(target_os = "macos") { + vec![("open", vec![url])] + } else if cfg!(target_os = "windows") { + vec![("cmd", vec!["/C", "start", "", url])] + } else { + vec![("xdg-open", vec![url])] + }; + for (program, args) in commands { + match 
Command::new(program).args(args).spawn() { + Ok(_) => return Ok(()), + Err(error) if error.kind() == io::ErrorKind::NotFound => {} + Err(error) => return Err(error), + } + } + Err(io::Error::new( + io::ErrorKind::NotFound, + "no supported browser opener command found", + )) +} + +fn wait_for_oauth_callback( + port: u16, +) -> Result> { + let listener = TcpListener::bind(("127.0.0.1", port))?; + let (mut stream, _) = listener.accept()?; + let mut buffer = [0_u8; 4096]; + let bytes_read = stream.read(&mut buffer)?; + let request = String::from_utf8_lossy(&buffer[..bytes_read]); + let request_line = request.lines().next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "missing callback request line") + })?; + let target = request_line.split_whitespace().nth(1).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "missing callback request target", + ) + })?; + let callback = parse_oauth_callback_request_target(target) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let body = if callback.error.is_some() { + "Claw OAuth login failed. You can close this window." + } else { + "Claw OAuth login succeeded. You can close this window." 
+ }; + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: text/plain; charset=utf-8\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + Ok(callback) +} + +fn print_system_prompt(cwd: PathBuf, date: String) { + match load_system_prompt(cwd, date, env::consts::OS, "unknown") { + Ok(sections) => println!("{}", sections.join("\n\n")), + Err(error) => { + eprintln!("failed to build system prompt: {error}"); + std::process::exit(1); + } + } +} + +fn print_version() { + println!("{}", render_version_report()); +} + +fn resume_session(session_path: &Path, commands: &[String]) { + let session = match Session::load_from_path(session_path) { + Ok(session) => session, + Err(error) => { + eprintln!("failed to restore session: {error}"); + std::process::exit(1); + } + }; + + if commands.is_empty() { + println!( + "Restored session from {} ({} messages).", + session_path.display(), + session.messages.len() + ); + return; + } + + let mut session = session; + for raw_command in commands { + let Some(command) = SlashCommand::parse(raw_command) else { + eprintln!("unsupported resumed command: {raw_command}"); + std::process::exit(2); + }; + match run_resume_command(session_path, &session, &command) { + Ok(ResumeCommandOutcome { + session: next_session, + message, + }) => { + session = next_session; + if let Some(message) = message { + println!("{message}"); + } + } + Err(error) => { + eprintln!("{error}"); + std::process::exit(2); + } + } + } +} + +#[derive(Debug, Clone)] +struct ResumeCommandOutcome { + session: Session, + message: Option, +} + +#[derive(Debug, Clone)] +struct StatusContext { + cwd: PathBuf, + session_path: Option, + loaded_config_files: usize, + discovered_config_files: usize, + memory_file_count: usize, + project_root: Option, + git_branch: Option, +} + +#[derive(Debug, Clone, Copy)] +struct StatusUsage { + message_count: usize, + turns: u32, + latest: TokenUsage, + cumulative: 
TokenUsage, + estimated_tokens: usize, +} + +fn format_model_report(model: &str, message_count: usize, turns: u32) -> String { + format!( + "Model + Current {model} + Session {message_count} messages · {turns} turns + +Aliases + opus claude-opus-4-6 + sonnet claude-sonnet-4-6 + haiku claude-haiku-4-5-20251213 + +Next + /model Show the current model + /model Switch models for this REPL session" + ) +} + +fn format_model_switch_report(previous: &str, next: &str, message_count: usize) -> String { + format!( + "Model updated + Previous {previous} + Current {next} + Preserved {message_count} messages + Tip Existing conversation context stayed attached" + ) +} + +fn format_permissions_report(mode: &str) -> String { + let modes = [ + ("read-only", "Read/search tools only", mode == "read-only"), + ( + "workspace-write", + "Edit files inside the workspace", + mode == "workspace-write", + ), + ( + "danger-full-access", + "Unrestricted tool access", + mode == "danger-full-access", + ), + ] + .into_iter() + .map(|(name, description, is_current)| { + let marker = if is_current { + "● current" + } else { + "○ available" + }; + format!(" {name:<18} {marker:<11} {description}") + }) + .collect::>() + .join( + " +", + ); + + let effect = match mode { + "read-only" => "Only read/search tools can run automatically", + "workspace-write" => "Editing tools can modify files in the workspace", + "danger-full-access" => "All tools can run without additional sandbox limits", + _ => "Unknown permission mode", + }; + + format!( + "Permissions + Active mode {mode} + Effect {effect} + +Modes +{modes} + +Next + /permissions Show the current mode + /permissions Switch modes for subsequent tool calls" + ) +} + +fn format_permissions_switch_report(previous: &str, next: &str) -> String { + format!( + "Permissions updated + Previous mode {previous} + Active mode {next} + Applies to Subsequent tool calls in this REPL + Tip Run /permissions to review all available modes" + ) +} + +fn 
format_cost_report(usage: TokenUsage) -> String { + format!( + "Cost + Input tokens {} + Output tokens {} + Cache create {} + Cache read {} + Total tokens {} + +Next + /status See session + workspace context + /compact Trim local history if the session is getting large", + usage.input_tokens, + usage.output_tokens, + usage.cache_creation_input_tokens, + usage.cache_read_input_tokens, + usage.total_tokens(), + ) +} + +fn format_resume_report(session_path: &str, message_count: usize, turns: u32) -> String { + format!( + "Session resumed + Session file {session_path} + History {message_count} messages · {turns} turns + Next /status · /diff · /export" + ) +} + +fn format_compact_report(removed: usize, resulting_messages: usize, skipped: bool) -> String { + if skipped { + format!( + "Compact + Result skipped + Reason Session is already below the compaction threshold + Messages kept {resulting_messages}" + ) + } else { + format!( + "Compact + Result compacted + Messages removed {removed} + Messages kept {resulting_messages} + Tip Use /status to review the trimmed session" + ) + } +} + +fn parse_git_status_metadata(status: Option<&str>) -> (Option, Option) { + let Some(status) = status else { + return (None, None); + }; + let branch = status.lines().next().and_then(|line| { + line.strip_prefix("## ") + .map(|line| { + line.split(['.', ' ']) + .next() + .unwrap_or_default() + .to_string() + }) + .filter(|value| !value.is_empty()) + }); + let project_root = find_git_root().ok(); + (project_root, branch) +} + +fn find_git_root() -> Result> { + let output = std::process::Command::new("git") + .args(["rev-parse", "--show-toplevel"]) + .current_dir(env::current_dir()?) 
+ .output()?; + if !output.status.success() { + return Err("not a git repository".into()); + } + let path = String::from_utf8(output.stdout)?.trim().to_string(); + if path.is_empty() { + return Err("empty git root".into()); + } + Ok(PathBuf::from(path)) +} + +#[allow(clippy::too_many_lines)] +fn run_resume_command( + session_path: &Path, + session: &Session, + command: &SlashCommand, +) -> Result> { + match command { + SlashCommand::Help => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_repl_help()), + }), + SlashCommand::Compact => { + let result = runtime::compact_session( + session, + CompactionConfig { + max_estimated_tokens: 0, + ..CompactionConfig::default() + }, + ); + let removed = result.removed_message_count; + let kept = result.compacted_session.messages.len(); + let skipped = removed == 0; + result.compacted_session.save_to_path(session_path)?; + Ok(ResumeCommandOutcome { + session: result.compacted_session, + message: Some(format_compact_report(removed, kept, skipped)), + }) + } + SlashCommand::Clear { confirm } => { + if !confirm { + return Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some( + "clear: confirmation required; rerun with /clear --confirm".to_string(), + ), + }); + } + let cleared = Session::new(); + cleared.save_to_path(session_path)?; + Ok(ResumeCommandOutcome { + session: cleared, + message: Some(format!( + "Cleared resumed session file {}.", + session_path.display() + )), + }) + } + SlashCommand::Status => { + let tracker = UsageTracker::from_session(session); + let usage = tracker.cumulative_usage(); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format_status_report( + "restored-session", + StatusUsage { + message_count: session.messages.len(), + turns: tracker.turns(), + latest: tracker.current_turn_usage(), + cumulative: usage, + estimated_tokens: 0, + }, + default_permission_mode().as_str(), + &status_context(Some(session_path))?, + )), + }) + } + 
SlashCommand::Cost => { + let usage = UsageTracker::from_session(session).cumulative_usage(); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format_cost_report(usage)), + }) + } + SlashCommand::Config { section } => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_config_report(section.as_deref())?), + }), + SlashCommand::Memory => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_memory_report()?), + }), + SlashCommand::Init => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(init_claw_md()?), + }), + SlashCommand::Diff => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_diff_report()?), + }), + SlashCommand::Version => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_version_report()), + }), + SlashCommand::Export { path } => { + let export_path = resolve_export_path(path.as_deref(), session)?; + fs::write(&export_path, render_export_text(session))?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format!( + "Export\n Result wrote transcript\n File {}\n Messages {}", + export_path.display(), + session.messages.len(), + )), + }) + } + SlashCommand::Agents { args } => { + let cwd = env::current_dir()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(handle_agents_slash_command(args.as_deref(), &cwd)?), + }) + } + SlashCommand::Skills { args } => { + let cwd = env::current_dir()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(handle_skills_slash_command(args.as_deref(), &cwd)?), + }) + } + SlashCommand::Bughunter { .. } + | SlashCommand::Branch { .. } + | SlashCommand::Worktree { .. } + | SlashCommand::CommitPushPr { .. } + | SlashCommand::Commit + | SlashCommand::Pr { .. } + | SlashCommand::Issue { .. } + | SlashCommand::Ultraplan { .. } + | SlashCommand::Teleport { .. } + | SlashCommand::DebugToolCall + | SlashCommand::Resume { .. 
} + | SlashCommand::Model { .. } + | SlashCommand::Permissions { .. } + | SlashCommand::Session { .. } + | SlashCommand::Plugins { .. } + | SlashCommand::Unknown(_) => Err("unsupported resumed slash command".into()), + } +} + +fn run_repl( + model: String, + allowed_tools: Option, + permission_mode: PermissionMode, +) -> Result<(), Box> { + let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; + let mut editor = input::LineEditor::new("> ", slash_command_completion_candidates()); + println!("{}", cli.startup_banner()); + + loop { + match editor.read_line()? { + input::ReadOutcome::Submit(input) => { + let trimmed = input.trim(); + if trimmed.is_empty() { + continue; + } + if matches!(trimmed, "/exit" | "/quit") { + cli.persist_session()?; + break; + } + if let Some(command) = SlashCommand::parse(trimmed) { + if cli.handle_repl_command(command)? { + cli.persist_session()?; + } + continue; + } + editor.push_history(&input); + cli.run_turn(&input)?; + } + input::ReadOutcome::Cancel => {} + input::ReadOutcome::Exit => { + cli.persist_session()?; + break; + } + } + } + + Ok(()) +} + +#[derive(Debug, Clone)] +struct SessionHandle { + id: String, + path: PathBuf, +} + +#[derive(Debug, Clone)] +struct ManagedSessionSummary { + id: String, + path: PathBuf, + modified_epoch_secs: u64, + message_count: usize, +} + +struct LiveCli { + model: String, + allowed_tools: Option, + permission_mode: PermissionMode, + system_prompt: Vec, + runtime: ConversationRuntime, + session: SessionHandle, +} + +impl LiveCli { + fn new( + model: String, + enable_tools: bool, + allowed_tools: Option, + permission_mode: PermissionMode, + ) -> Result> { + let system_prompt = build_system_prompt()?; + let session = create_managed_session_handle()?; + let runtime = build_runtime( + Session::new(), + model.clone(), + system_prompt.clone(), + enable_tools, + true, + allowed_tools.clone(), + permission_mode, + None, + )?; + let cli = Self { + model, + allowed_tools, + permission_mode, 
+ system_prompt, + runtime, + session, + }; + cli.persist_session()?; + Ok(cli) + } + + fn startup_banner(&self) -> String { + let color = io::stdout().is_terminal(); + let cwd = env::current_dir().ok(); + let cwd_display = cwd.as_ref().map_or_else( + || "".to_string(), + |path| path.display().to_string(), + ); + let workspace_name = cwd + .as_ref() + .and_then(|path| path.file_name()) + .and_then(|name| name.to_str()) + .unwrap_or("workspace"); + let git_branch = status_context(Some(&self.session.path)) + .ok() + .and_then(|context| context.git_branch); + let workspace_summary = git_branch.as_deref().map_or_else( + || workspace_name.to_string(), + |branch| format!("{workspace_name} · {branch}"), + ); + let has_claw_md = cwd + .as_ref() + .is_some_and(|path| path.join("CLAW.md").is_file()); + let mut lines = vec![ + format!( + "{} {}", + if color { + "\x1b[1;38;5;45m🦞 Claw Code\x1b[0m" + } else { + "Claw Code" + }, + if color { + "\x1b[2m· ready\x1b[0m" + } else { + "· ready" + } + ), + format!(" Workspace {workspace_summary}"), + format!(" Directory {cwd_display}"), + format!(" Model {}", self.model), + format!(" Permissions {}", self.permission_mode.as_str()), + format!(" Session {}", self.session.id), + format!( + " Quick start {}", + if has_claw_md { + "/help · /status · ask for a task" + } else { + "/init · /help · /status" + } + ), + " Editor Tab completes slash commands · /vim toggles modal editing" + .to_string(), + " Multiline Shift+Enter or Ctrl+J inserts a newline".to_string(), + ]; + if !has_claw_md { + lines.push( + " First run /init scaffolds CLAW.md, .claw.json, and local session files" + .to_string(), + ); + } + lines.join("\n") + } + + fn run_turn(&mut self, input: &str) -> Result<(), Box> { + let mut spinner = Spinner::new(); + let mut stdout = io::stdout(); + spinner.tick( + "🦀 Thinking...", + TerminalRenderer::new().color_theme(), + &mut stdout, + )?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let result 
= self.runtime.run_turn(input, Some(&mut permission_prompter)); + match result { + Ok(_) => { + spinner.finish( + "✨ Done", + TerminalRenderer::new().color_theme(), + &mut stdout, + )?; + println!(); + self.persist_session()?; + Ok(()) + } + Err(error) => { + spinner.fail( + "❌ Request failed", + TerminalRenderer::new().color_theme(), + &mut stdout, + )?; + Err(Box::new(error)) + } + } + } + + fn run_turn_with_output( + &mut self, + input: &str, + output_format: CliOutputFormat, + ) -> Result<(), Box> { + match output_format { + CliOutputFormat::Text => self.run_turn(input), + CliOutputFormat::Json => self.run_prompt_json(input), + } + } + + fn run_prompt_json(&mut self, input: &str) -> Result<(), Box> { + let session = self.runtime.session().clone(); + let mut runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + false, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let summary = runtime.run_turn(input, Some(&mut permission_prompter))?; + self.runtime = runtime; + self.persist_session()?; + println!( + "{}", + json!({ + "message": final_assistant_text(&summary), + "model": self.model, + "iterations": summary.iterations, + "tool_uses": collect_tool_uses(&summary), + "tool_results": collect_tool_results(&summary), + "usage": { + "input_tokens": summary.usage.input_tokens, + "output_tokens": summary.usage.output_tokens, + "cache_creation_input_tokens": summary.usage.cache_creation_input_tokens, + "cache_read_input_tokens": summary.usage.cache_read_input_tokens, + } + }) + ); + Ok(()) + } + + fn handle_repl_command( + &mut self, + command: SlashCommand, + ) -> Result> { + Ok(match command { + SlashCommand::Help => { + println!("{}", render_repl_help()); + false + } + SlashCommand::Status => { + self.print_status(); + false + } + SlashCommand::Bughunter { scope } => { + self.run_bughunter(scope.as_deref())?; + false + } + 
SlashCommand::Commit => { + self.run_commit()?; + true + } + SlashCommand::Pr { context } => { + self.run_pr(context.as_deref())?; + false + } + SlashCommand::Issue { context } => { + self.run_issue(context.as_deref())?; + false + } + SlashCommand::Ultraplan { task } => { + self.run_ultraplan(task.as_deref())?; + false + } + SlashCommand::Teleport { target } => { + self.run_teleport(target.as_deref())?; + false + } + SlashCommand::DebugToolCall => { + self.run_debug_tool_call()?; + false + } + SlashCommand::Compact => { + self.compact()?; + false + } + SlashCommand::Model { model } => self.set_model(model)?, + SlashCommand::Permissions { mode } => self.set_permissions(mode)?, + SlashCommand::Clear { confirm } => self.clear_session(confirm)?, + SlashCommand::Cost => { + self.print_cost(); + false + } + SlashCommand::Resume { session_path } => self.resume_session(session_path)?, + SlashCommand::Config { section } => { + Self::print_config(section.as_deref())?; + false + } + SlashCommand::Memory => { + Self::print_memory()?; + false + } + SlashCommand::Init => { + run_init()?; + false + } + SlashCommand::Diff => { + Self::print_diff()?; + false + } + SlashCommand::Version => { + Self::print_version(); + false + } + SlashCommand::Export { path } => { + self.export_session(path.as_deref())?; + false + } + SlashCommand::Session { action, target } => { + self.handle_session_command(action.as_deref(), target.as_deref())? + } + SlashCommand::Plugins { action, target } => { + self.handle_plugins_command(action.as_deref(), target.as_deref())? + } + SlashCommand::Agents { args } => { + Self::print_agents(args.as_deref())?; + false + } + SlashCommand::Skills { args } => { + Self::print_skills(args.as_deref())?; + false + } + SlashCommand::Branch { .. } => { + eprintln!( + "{}", + render_mode_unavailable("branch", "git branch commands") + ); + false + } + SlashCommand::Worktree { .. 
} => { + eprintln!( + "{}", + render_mode_unavailable("worktree", "git worktree commands") + ); + false + } + SlashCommand::CommitPushPr { .. } => { + eprintln!( + "{}", + render_mode_unavailable("commit-push-pr", "commit + push + PR automation") + ); + false + } + SlashCommand::Unknown(name) => { + eprintln!("{}", render_unknown_repl_command(&name)); + false + } + }) + } + + fn persist_session(&self) -> Result<(), Box> { + self.runtime.session().save_to_path(&self.session.path)?; + Ok(()) + } + + fn print_status(&self) { + let cumulative = self.runtime.usage().cumulative_usage(); + let latest = self.runtime.usage().current_turn_usage(); + println!( + "{}", + format_status_report( + &self.model, + StatusUsage { + message_count: self.runtime.session().messages.len(), + turns: self.runtime.usage().turns(), + latest, + cumulative, + estimated_tokens: self.runtime.estimated_tokens(), + }, + self.permission_mode.as_str(), + &status_context(Some(&self.session.path)).expect("status context should load"), + ) + ); + } + + fn set_model(&mut self, model: Option) -> Result> { + let Some(model) = model else { + println!( + "{}", + format_model_report( + &self.model, + self.runtime.session().messages.len(), + self.runtime.usage().turns(), + ) + ); + return Ok(false); + }; + + let model = resolve_model_alias(&model).to_string(); + + if model == self.model { + println!( + "{}", + format_model_report( + &self.model, + self.runtime.session().messages.len(), + self.runtime.usage().turns(), + ) + ); + return Ok(false); + } + + let previous = self.model.clone(); + let session = self.runtime.session().clone(); + let message_count = session.messages.len(); + self.runtime = build_runtime( + session, + model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.model.clone_from(&model); + println!( + "{}", + format_model_switch_report(&previous, &model, message_count) + ); + Ok(true) + } + + fn set_permissions( + 
&mut self, + mode: Option, + ) -> Result> { + let Some(mode) = mode else { + println!( + "{}", + format_permissions_report(self.permission_mode.as_str()) + ); + return Ok(false); + }; + + let normalized = normalize_permission_mode(&mode).ok_or_else(|| { + format!( + "unsupported permission mode '{mode}'. Use read-only, workspace-write, or danger-full-access." + ) + })?; + + if normalized == self.permission_mode.as_str() { + println!("{}", format_permissions_report(normalized)); + return Ok(false); + } + + let previous = self.permission_mode.as_str().to_string(); + let session = self.runtime.session().clone(); + self.permission_mode = permission_mode_from_label(normalized); + self.runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + println!( + "{}", + format_permissions_switch_report(&previous, normalized) + ); + Ok(true) + } + + fn clear_session(&mut self, confirm: bool) -> Result> { + if !confirm { + println!( + "clear: confirmation required; run /clear --confirm to start a fresh session." 
+ ); + return Ok(false); + } + + self.session = create_managed_session_handle()?; + self.runtime = build_runtime( + Session::new(), + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + println!( + "Session cleared\n Mode fresh session\n Preserved model {}\n Permission mode {}\n Session {}", + self.model, + self.permission_mode.as_str(), + self.session.id, + ); + Ok(true) + } + + fn print_cost(&self) { + let cumulative = self.runtime.usage().cumulative_usage(); + println!("{}", format_cost_report(cumulative)); + } + + fn resume_session( + &mut self, + session_path: Option, + ) -> Result> { + let Some(session_ref) = session_path else { + println!("Usage: /resume "); + return Ok(false); + }; + + let handle = resolve_session_reference(&session_ref)?; + let session = Session::load_from_path(&handle.path)?; + let message_count = session.messages.len(); + self.runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.session = handle; + println!( + "{}", + format_resume_report( + &self.session.path.display().to_string(), + message_count, + self.runtime.usage().turns(), + ) + ); + Ok(true) + } + + fn print_config(section: Option<&str>) -> Result<(), Box> { + println!("{}", render_config_report(section)?); + Ok(()) + } + + fn print_memory() -> Result<(), Box> { + println!("{}", render_memory_report()?); + Ok(()) + } + + fn print_agents(args: Option<&str>) -> Result<(), Box> { + let cwd = env::current_dir()?; + println!("{}", handle_agents_slash_command(args, &cwd)?); + Ok(()) + } + + fn print_skills(args: Option<&str>) -> Result<(), Box> { + let cwd = env::current_dir()?; + println!("{}", handle_skills_slash_command(args, &cwd)?); + Ok(()) + } + + fn print_diff() -> Result<(), Box> { + println!("{}", render_diff_report()?); + Ok(()) + } + + fn print_version() { + 
println!("{}", render_version_report()); + } + + fn export_session( + &self, + requested_path: Option<&str>, + ) -> Result<(), Box> { + let export_path = resolve_export_path(requested_path, self.runtime.session())?; + fs::write(&export_path, render_export_text(self.runtime.session()))?; + println!( + "Export\n Result wrote transcript\n File {}\n Messages {}", + export_path.display(), + self.runtime.session().messages.len(), + ); + Ok(()) + } + + fn handle_session_command( + &mut self, + action: Option<&str>, + target: Option<&str>, + ) -> Result> { + match action { + None | Some("list") => { + println!("{}", render_session_list(&self.session.id)?); + Ok(false) + } + Some("switch") => { + let Some(target) = target else { + println!("Usage: /session switch "); + return Ok(false); + }; + let handle = resolve_session_reference(target)?; + let session = Session::load_from_path(&handle.path)?; + let message_count = session.messages.len(); + self.runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.session = handle; + println!( + "Session switched\n Active session {}\n File {}\n Messages {}", + self.session.id, + self.session.path.display(), + message_count, + ); + Ok(true) + } + Some(other) => { + println!("Unknown /session action '{other}'. 
Use /session list or /session switch ."); + Ok(false) + } + } + } + + fn handle_plugins_command( + &mut self, + action: Option<&str>, + target: Option<&str>, + ) -> Result> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load()?; + let mut manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let result = handle_plugins_slash_command(action, target, &mut manager)?; + println!("{}", result.message); + if result.reload_runtime { + self.reload_runtime_features()?; + } + Ok(false) + } + + fn reload_runtime_features(&mut self) -> Result<(), Box> { + self.runtime = build_runtime( + self.runtime.session().clone(), + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.persist_session() + } + + fn compact(&mut self) -> Result<(), Box> { + let result = self.runtime.compact(CompactionConfig::default()); + let removed = result.removed_message_count; + let kept = result.compacted_session.messages.len(); + let skipped = removed == 0; + self.runtime = build_runtime( + result.compacted_session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.persist_session()?; + println!("{}", format_compact_report(removed, kept, skipped)); + Ok(()) + } + + fn run_internal_prompt_text_with_progress( + &self, + prompt: &str, + enable_tools: bool, + progress: Option, + ) -> Result> { + let session = self.runtime.session().clone(); + let mut runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + enable_tools, + false, + self.allowed_tools.clone(), + self.permission_mode, + progress, + )?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let summary = runtime.run_turn(prompt, Some(&mut permission_prompter))?; + Ok(final_assistant_text(&summary).trim().to_string()) + } + + fn 
run_internal_prompt_text( + &self, + prompt: &str, + enable_tools: bool, + ) -> Result> { + self.run_internal_prompt_text_with_progress(prompt, enable_tools, None) + } + + fn run_bughunter(&self, scope: Option<&str>) -> Result<(), Box> { + let scope = scope.unwrap_or("the current repository"); + let prompt = format!( + "You are /bughunter. Inspect {scope} and identify the most likely bugs or correctness issues. Prioritize concrete findings with file paths, severity, and suggested fixes. Use tools if needed." + ); + println!("{}", self.run_internal_prompt_text(&prompt, true)?); + Ok(()) + } + + fn run_ultraplan(&self, task: Option<&str>) -> Result<(), Box> { + let task = task.unwrap_or("the current repo work"); + let prompt = format!( + "You are /ultraplan. Produce a deep multi-step execution plan for {task}. Include goals, risks, implementation sequence, verification steps, and rollback considerations. Use tools if needed." + ); + let mut progress = InternalPromptProgressRun::start_ultraplan(task); + match self.run_internal_prompt_text_with_progress(&prompt, true, Some(progress.reporter())) + { + Ok(plan) => { + progress.finish_success(); + println!("{plan}"); + Ok(()) + } + Err(error) => { + progress.finish_failure(&error.to_string()); + Err(error) + } + } + } + + #[allow(clippy::unused_self)] + fn run_teleport(&self, target: Option<&str>) -> Result<(), Box> { + let Some(target) = target.map(str::trim).filter(|value| !value.is_empty()) else { + println!("Usage: /teleport "); + return Ok(()); + }; + + println!("{}", render_teleport_report(target)?); + Ok(()) + } + + fn run_debug_tool_call(&self) -> Result<(), Box> { + println!("{}", render_last_tool_debug_report(self.runtime.session())?); + Ok(()) + } + + fn run_commit(&mut self) -> Result<(), Box> { + let status = git_output(&["status", "--short"])?; + if status.trim().is_empty() { + println!("Commit\n Result skipped\n Reason no workspace changes"); + return Ok(()); + } + + git_status_ok(&["add", "-A"])?; + let 
staged_stat = git_output(&["diff", "--cached", "--stat"])?; + let prompt = format!( + "Generate a git commit message in plain text Lore format only. Base it on this staged diff summary:\n\n{}\n\nRecent conversation context:\n{}", + truncate_for_prompt(&staged_stat, 8_000), + recent_user_context(self.runtime.session(), 6) + ); + let message = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); + if message.trim().is_empty() { + return Err("generated commit message was empty".into()); + } + + let path = write_temp_text_file("claw-commit-message.txt", &message)?; + let output = Command::new("git") + .args(["commit", "--file"]) + .arg(&path) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git commit failed: {stderr}").into()); + } + + println!( + "Commit\n Result created\n Message file {}\n\n{}", + path.display(), + message.trim() + ); + Ok(()) + } + + fn run_pr(&self, context: Option<&str>) -> Result<(), Box> { + let staged = git_output(&["diff", "--stat"])?; + let prompt = format!( + "Generate a pull request title and body from this conversation and diff summary. Output plain text in this format exactly:\nTITLE: \nBODY:\n<body markdown>\n\nContext hint: {}\n\nDiff summary:\n{}", + context.unwrap_or("none"), + truncate_for_prompt(&staged, 10_000) + ); + let draft = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); + let (title, body) = parse_titled_body(&draft) + .ok_or_else(|| "failed to parse generated PR title/body".to_string())?; + + if command_exists("gh") { + let body_path = write_temp_text_file("claw-pr-body.md", &body)?; + let output = Command::new("gh") + .args(["pr", "create", "--title", &title, "--body-file"]) + .arg(&body_path) + .current_dir(env::current_dir()?) 
+ .output()?; + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + println!( + "PR\n Result created\n Title {title}\n URL {}", + if stdout.is_empty() { "<unknown>" } else { &stdout } + ); + return Ok(()); + } + } + + println!("PR draft\n Title {title}\n\n{body}"); + Ok(()) + } + + fn run_issue(&self, context: Option<&str>) -> Result<(), Box<dyn std::error::Error>> { + let prompt = format!( + "Generate a GitHub issue title and body from this conversation. Output plain text in this format exactly:\nTITLE: <title>\nBODY:\n<body markdown>\n\nContext hint: {}\n\nConversation context:\n{}", + context.unwrap_or("none"), + truncate_for_prompt(&recent_user_context(self.runtime.session(), 10), 10_000) + ); + let draft = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); + let (title, body) = parse_titled_body(&draft) + .ok_or_else(|| "failed to parse generated issue title/body".to_string())?; + + if command_exists("gh") { + let body_path = write_temp_text_file("claw-issue-body.md", &body)?; + let output = Command::new("gh") + .args(["issue", "create", "--title", &title, "--body-file"]) + .arg(&body_path) + .current_dir(env::current_dir()?) 
+ .output()?; + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + println!( + "Issue\n Result created\n Title {title}\n URL {}", + if stdout.is_empty() { "<unknown>" } else { &stdout } + ); + return Ok(()); + } + } + + println!("Issue draft\n Title {title}\n\n{body}"); + Ok(()) + } +} + +fn sessions_dir() -> Result<PathBuf, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let path = cwd.join(".claw").join("sessions"); + fs::create_dir_all(&path)?; + Ok(path) +} + +fn create_managed_session_handle() -> Result<SessionHandle, Box<dyn std::error::Error>> { + let id = generate_session_id(); + let path = sessions_dir()?.join(format!("{id}.json")); + Ok(SessionHandle { id, path }) +} + +fn generate_session_id() -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_millis()) + .unwrap_or_default(); + format!("session-{millis}") +} + +fn resolve_session_reference(reference: &str) -> Result<SessionHandle, Box<dyn std::error::Error>> { + let direct = PathBuf::from(reference); + let path = if direct.exists() { + direct + } else { + sessions_dir()?.join(format!("{reference}.json")) + }; + if !path.exists() { + return Err(format!("session not found: {reference}").into()); + } + let id = path + .file_stem() + .and_then(|value| value.to_str()) + .unwrap_or(reference) + .to_string(); + Ok(SessionHandle { id, path }) +} + +fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::error::Error>> { + let mut sessions = Vec::new(); + for entry in fs::read_dir(sessions_dir()?)? 
{ + let entry = entry?; + let path = entry.path(); + if path.extension().and_then(|ext| ext.to_str()) != Some("json") { + continue; + } + let metadata = entry.metadata()?; + let modified_epoch_secs = metadata + .modified() + .ok() + .and_then(|time| time.duration_since(UNIX_EPOCH).ok()) + .map(|duration| duration.as_secs()) + .unwrap_or_default(); + let message_count = Session::load_from_path(&path) + .map(|session| session.messages.len()) + .unwrap_or_default(); + let id = path + .file_stem() + .and_then(|value| value.to_str()) + .unwrap_or("unknown") + .to_string(); + sessions.push(ManagedSessionSummary { + id, + path, + modified_epoch_secs, + message_count, + }); + } + sessions.sort_by(|left, right| right.modified_epoch_secs.cmp(&left.modified_epoch_secs)); + Ok(sessions) +} + +fn format_relative_timestamp(epoch_secs: u64) -> String { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(epoch_secs); + let elapsed = now.saturating_sub(epoch_secs); + match elapsed { + 0..=59 => format!("{elapsed}s ago"), + 60..=3_599 => format!("{}m ago", elapsed / 60), + 3_600..=86_399 => format!("{}h ago", elapsed / 3_600), + _ => format!("{}d ago", elapsed / 86_400), + } +} + +fn render_session_list(active_session_id: &str) -> Result<String, Box<dyn std::error::Error>> { + let sessions = list_managed_sessions()?; + let mut lines = vec![ + "Sessions".to_string(), + format!(" Directory {}", sessions_dir()?.display()), + ]; + if sessions.is_empty() { + lines.push(" No managed sessions saved yet.".to_string()); + return Ok(lines.join("\n")); + } + for session in sessions { + let marker = if session.id == active_session_id { + "● current" + } else { + "○ saved" + }; + lines.push(format!( + " {id:<20} {marker:<10} {msgs:>3} msgs · updated {modified}", + id = session.id, + msgs = session.message_count, + modified = format_relative_timestamp(session.modified_epoch_secs), + )); + lines.push(format!(" {}", 
session.path.display())); + } + Ok(lines.join("\n")) +} + +fn render_repl_help() -> String { + [ + "Interactive REPL".to_string(), + " Quick start Ask a task in plain English or use one of the core commands below." + .to_string(), + " Core commands /help · /status · /model · /permissions · /compact".to_string(), + " Exit /exit or /quit".to_string(), + " Vim mode /vim toggles modal editing".to_string(), + " History Up/Down recalls previous prompts".to_string(), + " Completion Tab cycles slash command matches".to_string(), + " Cancel Ctrl-C clears input (or exits on an empty prompt)".to_string(), + " Multiline Shift+Enter or Ctrl+J inserts a newline".to_string(), + String::new(), + render_slash_command_help(), + ] + .join( + " +", + ) +} + +fn append_slash_command_suggestions(lines: &mut Vec<String>, name: &str) { + let suggestions = suggest_slash_commands(name, 3); + if suggestions.is_empty() { + lines.push(" Try /help shows the full slash command map".to_string()); + return; + } + + lines.push(" Try /help shows the full slash command map".to_string()); + lines.push("Suggestions".to_string()); + lines.extend( + suggestions + .into_iter() + .map(|suggestion| format!(" {suggestion}")), + ); +} + +fn render_unknown_repl_command(name: &str) -> String { + let mut lines = vec![ + "Unknown slash command".to_string(), + format!(" Command /{name}"), + ]; + append_repl_command_suggestions(&mut lines, name); + lines.join("\n") +} + +fn append_repl_command_suggestions(lines: &mut Vec<String>, name: &str) { + let suggestions = suggest_repl_commands(name); + if suggestions.is_empty() { + lines.push(" Try /help shows the full slash command map".to_string()); + return; + } + + lines.push(" Try /help shows the full slash command map".to_string()); + lines.push("Suggestions".to_string()); + lines.extend( + suggestions + .into_iter() + .map(|suggestion| format!(" {suggestion}")), + ); +} + +fn render_mode_unavailable(command: &str, label: &str) -> String { + [ + "Command unavailable 
in this REPL mode".to_string(), + format!(" Command /{command}"), + format!(" Feature {label}"), + " Tip Use /help to find currently wired REPL commands".to_string(), + ] + .join("\n") +} + +fn status_context( + session_path: Option<&Path>, +) -> Result<StatusContext, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let discovered_config_files = loader.discover().len(); + let runtime_config = loader.load()?; + let project_context = ProjectContext::discover_with_git(&cwd, DEFAULT_DATE)?; + let (project_root, git_branch) = + parse_git_status_metadata(project_context.git_status.as_deref()); + Ok(StatusContext { + cwd, + session_path: session_path.map(Path::to_path_buf), + loaded_config_files: runtime_config.loaded_entries().len(), + discovered_config_files, + memory_file_count: project_context.instruction_files.len(), + project_root, + git_branch, + }) +} + +fn format_status_report( + model: &str, + usage: StatusUsage, + permission_mode: &str, + context: &StatusContext, +) -> String { + [ + format!( + "Session + Model {model} + Permissions {permission_mode} + Activity {} messages · {} turns + Tokens est {} · latest {} · total {}", + usage.message_count, + usage.turns, + usage.estimated_tokens, + usage.latest.total_tokens(), + usage.cumulative.total_tokens(), + ), + format!( + "Usage + Cumulative input {} + Cumulative output {} + Cache create {} + Cache read {}", + usage.cumulative.input_tokens, + usage.cumulative.output_tokens, + usage.cumulative.cache_creation_input_tokens, + usage.cumulative.cache_read_input_tokens, + ), + format!( + "Workspace + Folder {} + Project root {} + Git branch {} + Session file {} + Config files loaded {}/{} + Memory files {} + +Next + /help Browse commands + /session list Inspect saved sessions + /diff Review current workspace changes", + context.cwd.display(), + context + .project_root + .as_ref() + .map_or_else(|| "unknown".to_string(), |path| path.display().to_string()), + 
context.git_branch.as_deref().unwrap_or("unknown"), + context.session_path.as_ref().map_or_else( + || "live-repl".to_string(), + |path| path.display().to_string() + ), + context.loaded_config_files, + context.discovered_config_files, + context.memory_file_count, + ), + ] + .join( + " + +", + ) +} + +fn render_config_report(section: Option<&str>) -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let discovered = loader.discover(); + let runtime_config = loader.load()?; + + let mut lines = vec![ + format!( + "Config + Working directory {} + Loaded files {} + Merged keys {}", + cwd.display(), + runtime_config.loaded_entries().len(), + runtime_config.merged().len() + ), + "Discovered files".to_string(), + ]; + for entry in discovered { + let source = match entry.source { + ConfigSource::User => "user", + ConfigSource::Project => "project", + ConfigSource::Local => "local", + }; + let status = if runtime_config + .loaded_entries() + .iter() + .any(|loaded_entry| loaded_entry.path == entry.path) + { + "loaded" + } else { + "missing" + }; + lines.push(format!( + " {source:<7} {status:<7} {}", + entry.path.display() + )); + } + + if let Some(section) = section { + lines.push(format!("Merged section: {section}")); + let value = match section { + "env" => runtime_config.get("env"), + "hooks" => runtime_config.get("hooks"), + "model" => runtime_config.get("model"), + "plugins" => runtime_config + .get("plugins") + .or_else(|| runtime_config.get("enabledPlugins")), + other => { + lines.push(format!( + " Unsupported config section '{other}'. Use env, hooks, model, or plugins." 
+ )); + return Ok(lines.join( + " +", + )); + } + }; + lines.push(format!( + " {}", + match value { + Some(value) => value.render(), + None => "<unset>".to_string(), + } + )); + return Ok(lines.join( + " +", + )); + } + + lines.push("Merged JSON".to_string()); + lines.push(format!(" {}", runtime_config.as_json().render())); + Ok(lines.join( + " +", + )) +} + +fn render_memory_report() -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let project_context = ProjectContext::discover(&cwd, DEFAULT_DATE)?; + let mut lines = vec![format!( + "Memory + Working directory {} + Instruction files {}", + cwd.display(), + project_context.instruction_files.len() + )]; + if project_context.instruction_files.is_empty() { + lines.push("Discovered files".to_string()); + lines.push( + " No CLAW instruction files discovered in the current directory ancestry.".to_string(), + ); + } else { + lines.push("Discovered files".to_string()); + for (index, file) in project_context.instruction_files.iter().enumerate() { + let preview = file.content.lines().next().unwrap_or("").trim(); + let preview = if preview.is_empty() { + "<empty>" + } else { + preview + }; + lines.push(format!(" {}. 
{}", index + 1, file.path.display(),)); + lines.push(format!( + " lines={} preview={}", + file.content.lines().count(), + preview + )); + } + } + Ok(lines.join( + " +", + )) +} + +fn init_claw_md() -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + Ok(initialize_repo(&cwd)?.render()) +} + +fn run_init() -> Result<(), Box<dyn std::error::Error>> { + println!("{}", init_claw_md()?); + Ok(()) +} + +fn normalize_permission_mode(mode: &str) -> Option<&'static str> { + match mode.trim() { + "read-only" => Some("read-only"), + "workspace-write" => Some("workspace-write"), + "danger-full-access" => Some("danger-full-access"), + _ => None, + } +} + +fn render_diff_report() -> Result<String, Box<dyn std::error::Error>> { + let output = std::process::Command::new("git") + .args(["diff", "--", ":(exclude).omx"]) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git diff failed: {stderr}").into()); + } + let diff = String::from_utf8(output.stdout)?; + if diff.trim().is_empty() { + return Ok( + "Diff\n Result clean working tree\n Detail no current changes" + .to_string(), + ); + } + Ok(format!("Diff\n\n{}", diff.trim_end())) +} + +fn render_teleport_report(target: &str) -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + + let file_list = Command::new("rg") + .args(["--files"]) + .current_dir(&cwd) + .output()?; + let file_matches = if file_list.status.success() { + String::from_utf8(file_list.stdout)? 
+ .lines() + .filter(|line| line.contains(target)) + .take(10) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + } else { + Vec::new() + }; + + let content_output = Command::new("rg") + .args(["-n", "-S", "--color", "never", target, "."]) + .current_dir(&cwd) + .output()?; + + let mut lines = vec![format!("Teleport\n Target {target}")]; + if !file_matches.is_empty() { + lines.push(String::new()); + lines.push("File matches".to_string()); + lines.extend(file_matches.into_iter().map(|path| format!(" {path}"))); + } + + if content_output.status.success() { + let matches = String::from_utf8(content_output.stdout)?; + if !matches.trim().is_empty() { + lines.push(String::new()); + lines.push("Content matches".to_string()); + lines.push(truncate_for_prompt(&matches, 4_000)); + } + } + + if lines.len() == 1 { + lines.push(" Result no matches found".to_string()); + } + + Ok(lines.join("\n")) +} + +fn render_last_tool_debug_report(session: &Session) -> Result<String, Box<dyn std::error::Error>> { + let last_tool_use = session + .messages + .iter() + .rev() + .find_map(|message| { + message.blocks.iter().rev().find_map(|block| match block { + ContentBlock::ToolUse { id, name, input } => { + Some((id.clone(), name.clone(), input.clone())) + } + _ => None, + }) + }) + .ok_or_else(|| "no prior tool call found in session".to_string())?; + + let tool_result = session.messages.iter().rev().find_map(|message| { + message.blocks.iter().rev().find_map(|block| match block { + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } if tool_use_id == &last_tool_use.0 => { + Some((tool_name.clone(), output.clone(), *is_error)) + } + _ => None, + }) + }); + + let mut lines = vec![ + "Debug tool call".to_string(), + format!(" Tool id {}", last_tool_use.0), + format!(" Tool name {}", last_tool_use.1), + " Input".to_string(), + indent_block(&last_tool_use.2, 4), + ]; + + match tool_result { + Some((tool_name, output, is_error)) => { + lines.push(" 
Result".to_string());
            lines.push(format!(" name {tool_name}"));
            lines.push(format!(
                " status {}",
                if is_error { "error" } else { "ok" }
            ));
            // Tool output is indented under the status line for readability.
            lines.push(indent_block(&output, 4));
        }
        // A ToolUse without a matching ToolResult means the call never completed
        // (or the result was pruned from the session).
        None => lines.push(" Result missing tool result".to_string()),
    }

    Ok(lines.join("\n"))
}

/// Indents every line of `value` by `spaces` spaces.
///
/// Blank trailing newlines are not preserved: `str::lines` drops a final
/// empty segment, so the result never ends with a newline.
fn indent_block(value: &str, spaces: usize) -> String {
    let indent = " ".repeat(spaces);
    value
        .lines()
        .map(|line| format!("{indent}{line}"))
        .collect::<Vec<_>>()
        .join("\n")
}

/// Runs `git <args>` in the current working directory and returns its stdout.
///
/// # Errors
/// Returns an error if the process cannot be spawned, if git exits with a
/// non-zero status (stderr is folded into the message), or if stdout is not
/// valid UTF-8.
fn git_output(args: &[&str]) -> Result<String, Box<dyn std::error::Error>> {
    let output = Command::new("git")
        .args(args)
        .current_dir(env::current_dir()?)
        .output()?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
        return Err(format!("git {} failed: {stderr}", args.join(" ")).into());
    }
    Ok(String::from_utf8(output.stdout)?)
}

/// Runs `git <args>` for its side effect only, discarding stdout.
///
/// Same spawn/exit-status error handling as [`git_output`]; use this for
/// commands like `git add`/`git checkout` where output is irrelevant.
fn git_status_ok(args: &[&str]) -> Result<(), Box<dyn std::error::Error>> {
    let output = Command::new("git")
        .args(args)
        .current_dir(env::current_dir()?)
.output()?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
        return Err(format!("git {} failed: {stderr}", args.join(" ")).into());
    }
    Ok(())
}

/// Returns true if `name` resolves to an executable on PATH.
///
/// NOTE(review): probes via the external `which` binary, which exists on
/// Unix-like systems but not on stock Windows (`where` there) — confirm the
/// supported platforms before relying on this cross-platform.
fn command_exists(name: &str) -> bool {
    Command::new("which")
        .arg(name)
        .output()
        .map(|output| output.status.success())
        .unwrap_or(false)
}

/// Writes `contents` to `<temp_dir>/<filename>` and returns the path.
///
/// NOTE(review): the filename is fixed, so concurrent invocations (or two
/// CLI processes) sharing a temp dir will overwrite each other's file —
/// verify this is acceptable for the `gh pr/issue --body-file` flows that
/// call it.
fn write_temp_text_file(
    filename: &str,
    contents: &str,
) -> Result<PathBuf, Box<dyn std::error::Error>> {
    let path = env::temp_dir().join(filename);
    fs::write(&path, contents)?;
    Ok(path)
}

/// Collects up to `limit` of the most recent user messages from the session
/// and renders them as a numbered list in chronological order.
///
/// Only the first `Text` block of each user message is used. The double
/// `.rev()` takes the *newest* `limit` messages, then restores oldest-first
/// ordering for the numbered output.
fn recent_user_context(session: &Session, limit: usize) -> String {
    let requests = session
        .messages
        .iter()
        .filter(|message| message.role == MessageRole::User)
        .filter_map(|message| {
            message.blocks.iter().find_map(|block| match block {
                ContentBlock::Text { text } => Some(text.trim().to_string()),
                _ => None,
            })
        })
        .rev()
        .take(limit)
        .collect::<Vec<_>>();

    if requests.is_empty() {
        "<no prior user messages>".to_string()
    } else {
        requests
            .into_iter()
            .rev()
            .enumerate()
            .map(|(index, text)| format!("{}. 
{}", index + 1, text))
            .collect::<Vec<_>>()
            .join("\n")
    }
}

/// Trims `value` and caps it at `limit` characters (Unicode scalar values,
/// not bytes), appending an ellipsis marker when truncation occurred.
///
/// NOTE(review): the under-limit branch fully trims, while the truncated
/// branch only trims the tail — leading whitespace survives truncation.
/// Looks intentional for prompt context but worth confirming.
fn truncate_for_prompt(value: &str, limit: usize) -> String {
    if value.chars().count() <= limit {
        value.trim().to_string()
    } else {
        let truncated = value.chars().take(limit).collect::<String>();
        format!("{}\n…[truncated]", truncated.trim_end())
    }
}

/// Normalizes model-generated text: strips surrounding whitespace and code
/// fences (backticks), and converts CRLF line endings to LF.
fn sanitize_generated_message(value: &str) -> String {
    value.trim().trim_matches('`').trim().replace("\r\n", "\n")
}

/// Parses the `TITLE: …` / `BODY: …` template that the PR/issue generation
/// prompts ask the model to emit. Returns `None` when either marker is
/// missing.
///
/// NOTE(review): `find("BODY:")` takes the *first* occurrence anywhere in
/// the text, so a literal "BODY:" inside the title line would mis-split —
/// assumes the model follows the template faithfully.
fn parse_titled_body(value: &str) -> Option<(String, String)> {
    let normalized = sanitize_generated_message(value);
    let title = normalized
        .lines()
        .find_map(|line| line.strip_prefix("TITLE:").map(str::trim))?;
    let body_start = normalized.find("BODY:")?;
    let body = normalized[body_start + "BODY:".len()..].trim();
    Some((title.to_string(), body.to_string()))
}

/// Renders the `--version` report from build-time constants; `GIT_SHA` and
/// `BUILD_TARGET` are optional and fall back to "unknown".
fn render_version_report() -> String {
    let git_sha = GIT_SHA.unwrap_or("unknown");
    let target = BUILD_TARGET.unwrap_or("unknown");
    format!(
        "Claw Code\n Version {VERSION}\n Git SHA {git_sha}\n Target {target}\n Build date {DEFAULT_DATE}\n\nSupport\n Help claw --help\n REPL /help"
    )
}

/// Serializes the whole session as a human-readable markdown transcript:
/// one `##` heading per message, with tool calls/results rendered inline
/// in bracketed single-line form.
fn render_export_text(session: &Session) -> String {
    let mut lines = vec!["# Conversation Export".to_string(), String::new()];
    for (index, message) in session.messages.iter().enumerate() {
        let role = match message.role {
            MessageRole::System => "system",
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::Tool => "tool",
        };
        lines.push(format!("## {}. 
{role}", index + 1)); + for block in &message.blocks { + match block { + ContentBlock::Text { text } => lines.push(text.clone()), + ContentBlock::ToolUse { id, name, input } => { + lines.push(format!("[tool_use id={id} name={name}] {input}")); + } + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } => { + lines.push(format!( + "[tool_result id={tool_use_id} name={tool_name} error={is_error}] {output}" + )); + } + } + } + lines.push(String::new()); + } + lines.join("\n") +} + +fn default_export_filename(session: &Session) -> String { + let stem = session + .messages + .iter() + .find_map(|message| match message.role { + MessageRole::User => message.blocks.iter().find_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }), + _ => None, + }) + .map_or("conversation", |text| { + text.lines().next().unwrap_or("conversation") + }) + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() { + ch.to_ascii_lowercase() + } else { + '-' + } + }) + .collect::<String>() + .split('-') + .filter(|part| !part.is_empty()) + .take(8) + .collect::<Vec<_>>() + .join("-"); + let fallback = if stem.is_empty() { + "conversation" + } else { + &stem + }; + format!("{fallback}.txt") +} + +fn resolve_export_path( + requested_path: Option<&str>, + session: &Session, +) -> Result<PathBuf, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let file_name = + requested_path.map_or_else(|| default_export_filename(session), ToOwned::to_owned); + let final_name = if Path::new(&file_name) + .extension() + .is_some_and(|ext| ext.eq_ignore_ascii_case("txt")) + { + file_name + } else { + format!("{file_name}.txt") + }; + Ok(cwd.join(final_name)) +} + +fn build_system_prompt() -> Result<Vec<String>, Box<dyn std::error::Error>> { + Ok(load_system_prompt( + env::current_dir()?, + DEFAULT_DATE, + env::consts::OS, + "unknown", + )?) 
}

/// Loads the runtime feature configuration and constructs the global tool
/// registry, including any tools contributed by enabled plugins.
///
/// Discovery is rooted at the current working directory; plugin tool
/// aggregation failures and registry construction failures both propagate.
fn build_runtime_plugin_state(
) -> Result<(runtime::RuntimeFeatureConfig, GlobalToolRegistry), Box<dyn std::error::Error>> {
    let cwd = env::current_dir()?;
    let loader = ConfigLoader::default_for(&cwd);
    let runtime_config = loader.load()?;
    let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config);
    let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_manager.aggregated_tools()?)?;
    Ok((runtime_config.feature_config().clone(), tool_registry))
}

/// Builds a [`PluginManager`] from the plugin section of the runtime config.
///
/// All configured paths (external dirs, install root, registry, bundled
/// root) are resolved through [`resolve_plugin_path`], so relative entries
/// behave consistently regardless of where they appear in the config.
fn build_plugin_manager(
    cwd: &Path,
    loader: &ConfigLoader,
    runtime_config: &runtime::RuntimeConfig,
) -> PluginManager {
    let plugin_settings = runtime_config.plugins();
    let mut plugin_config = PluginManagerConfig::new(loader.config_home().to_path_buf());
    plugin_config.enabled_plugins = plugin_settings.enabled_plugins().clone();
    plugin_config.external_dirs = plugin_settings
        .external_directories()
        .iter()
        .map(|path| resolve_plugin_path(cwd, loader.config_home(), path))
        .collect();
    plugin_config.install_root = plugin_settings
        .install_root()
        .map(|path| resolve_plugin_path(cwd, loader.config_home(), path));
    plugin_config.registry_path = plugin_settings
        .registry_path()
        .map(|path| resolve_plugin_path(cwd, loader.config_home(), path));
    plugin_config.bundled_root = plugin_settings
        .bundled_root()
        .map(|path| resolve_plugin_path(cwd, loader.config_home(), path));
    PluginManager::new(plugin_config)
}

/// Resolves a plugin path from config:
/// - absolute paths are used as-is;
/// - paths starting with `.` are relative to the project working directory;
/// - everything else is relative to the user's config home.
fn resolve_plugin_path(cwd: &Path, config_home: &Path, value: &str) -> PathBuf {
    let path = PathBuf::from(value);
    if path.is_absolute() {
        path
    } else if value.starts_with('.') {
        cwd.join(path)
    } else {
        config_home.join(path)
    }
}

/// Mutable snapshot of an in-flight internal prompt (e.g. ultraplan) used to
/// format progress lines. Cloned under the state mutex, then rendered
/// outside it.
#[derive(Debug, Clone, PartialEq, Eq)]
struct InternalPromptProgressState {
    // Fixed human-readable label for the command (e.g. "Ultraplan").
    command_label: &'static str,
    // Short description of the user task being planned.
    task_label: String,
    // Monotonic phase counter; 0 means "pending" in the rendered line.
    step: usize,
    // Current phase description ("running <tool>", "drafting final plan", …).
    phase: String,
    // Optional extra context appended to the status line when non-empty.
    detail: Option<String>,
    // Set once the first final-text delta arrives so later text deltas
    // don't re-advance the step counter.
    saw_final_text: bool,
}

/// Kind of progress line being emitted for an internal prompt run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InternalPromptProgressEvent
{ + Started, + Update, + Heartbeat, + Complete, + Failed, +} + +#[derive(Debug)] +struct InternalPromptProgressShared { + state: Mutex<InternalPromptProgressState>, + output_lock: Mutex<()>, + started_at: Instant, +} + +#[derive(Debug, Clone)] +struct InternalPromptProgressReporter { + shared: Arc<InternalPromptProgressShared>, +} + +#[derive(Debug)] +struct InternalPromptProgressRun { + reporter: InternalPromptProgressReporter, + heartbeat_stop: Option<mpsc::Sender<()>>, + heartbeat_handle: Option<thread::JoinHandle<()>>, +} + +impl InternalPromptProgressReporter { + fn ultraplan(task: &str) -> Self { + Self { + shared: Arc::new(InternalPromptProgressShared { + state: Mutex::new(InternalPromptProgressState { + command_label: "Ultraplan", + task_label: task.to_string(), + step: 0, + phase: "planning started".to_string(), + detail: Some(format!("task: {task}")), + saw_final_text: false, + }), + output_lock: Mutex::new(()), + started_at: Instant::now(), + }), + } + } + + fn emit(&self, event: InternalPromptProgressEvent, error: Option<&str>) { + let snapshot = self.snapshot(); + let line = format_internal_prompt_progress_line(event, &snapshot, self.elapsed(), error); + self.write_line(&line); + } + + fn mark_model_phase(&self) { + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + state.step += 1; + state.phase = if state.step == 1 { + "analyzing request".to_string() + } else { + "reviewing findings".to_string() + }; + state.detail = Some(format!("task: {}", state.task_label)); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn mark_tool_phase(&self, name: &str, input: &str) { + let detail = describe_tool_progress(name, input); + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + state.step += 1; + 
state.phase = format!("running {name}"); + state.detail = Some(detail); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn mark_text_phase(&self, text: &str) { + let trimmed = text.trim(); + if trimmed.is_empty() { + return; + } + let detail = truncate_for_summary(first_visible_line(trimmed), 120); + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + if state.saw_final_text { + return; + } + state.saw_final_text = true; + state.step += 1; + state.phase = "drafting final plan".to_string(); + state.detail = (!detail.is_empty()).then_some(detail); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn emit_heartbeat(&self) { + let snapshot = self.snapshot(); + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Heartbeat, + &snapshot, + self.elapsed(), + None, + )); + } + + fn snapshot(&self) -> InternalPromptProgressState { + self.shared + .state + .lock() + .expect("internal prompt progress state poisoned") + .clone() + } + + fn elapsed(&self) -> Duration { + self.shared.started_at.elapsed() + } + + fn write_line(&self, line: &str) { + let _guard = self + .shared + .output_lock + .lock() + .expect("internal prompt progress output lock poisoned"); + let mut stdout = io::stdout(); + let _ = writeln!(stdout, "{line}"); + let _ = stdout.flush(); + } +} + +impl InternalPromptProgressRun { + fn start_ultraplan(task: &str) -> Self { + let reporter = InternalPromptProgressReporter::ultraplan(task); + reporter.emit(InternalPromptProgressEvent::Started, None); + + let (heartbeat_stop, heartbeat_rx) = mpsc::channel(); + let heartbeat_reporter = reporter.clone(); + let heartbeat_handle = thread::spawn(move || loop { + match 
heartbeat_rx.recv_timeout(INTERNAL_PROGRESS_HEARTBEAT_INTERVAL) { + Ok(()) | Err(RecvTimeoutError::Disconnected) => break, + Err(RecvTimeoutError::Timeout) => heartbeat_reporter.emit_heartbeat(), + } + }); + + Self { + reporter, + heartbeat_stop: Some(heartbeat_stop), + heartbeat_handle: Some(heartbeat_handle), + } + } + + fn reporter(&self) -> InternalPromptProgressReporter { + self.reporter.clone() + } + + fn finish_success(&mut self) { + self.stop_heartbeat(); + self.reporter + .emit(InternalPromptProgressEvent::Complete, None); + } + + fn finish_failure(&mut self, error: &str) { + self.stop_heartbeat(); + self.reporter + .emit(InternalPromptProgressEvent::Failed, Some(error)); + } + + fn stop_heartbeat(&mut self) { + if let Some(sender) = self.heartbeat_stop.take() { + let _ = sender.send(()); + } + if let Some(handle) = self.heartbeat_handle.take() { + let _ = handle.join(); + } + } +} + +impl Drop for InternalPromptProgressRun { + fn drop(&mut self) { + self.stop_heartbeat(); + } +} + +fn format_internal_prompt_progress_line( + event: InternalPromptProgressEvent, + snapshot: &InternalPromptProgressState, + elapsed: Duration, + error: Option<&str>, +) -> String { + let elapsed_seconds = elapsed.as_secs(); + let step_label = if snapshot.step == 0 { + "current step pending".to_string() + } else { + format!("current step {}", snapshot.step) + }; + let mut status_bits = vec![step_label, format!("phase {}", snapshot.phase)]; + if let Some(detail) = snapshot + .detail + .as_deref() + .filter(|detail| !detail.is_empty()) + { + status_bits.push(detail.to_string()); + } + let status = status_bits.join(" · "); + match event { + InternalPromptProgressEvent::Started => { + format!( + "🧭 {} status · planning started · {status}", + snapshot.command_label + ) + } + InternalPromptProgressEvent::Update => { + format!("… {} status · {status}", snapshot.command_label) + } + InternalPromptProgressEvent::Heartbeat => format!( + "… {} heartbeat · {elapsed_seconds}s elapsed · 
{status}", + snapshot.command_label + ), + InternalPromptProgressEvent::Complete => format!( + "✔ {} status · completed · {elapsed_seconds}s elapsed · {} steps total", + snapshot.command_label, snapshot.step + ), + InternalPromptProgressEvent::Failed => format!( + "✘ {} status · failed · {elapsed_seconds}s elapsed · {}", + snapshot.command_label, + error.unwrap_or("unknown error") + ), + } +} + +fn describe_tool_progress(name: &str, input: &str) -> String { + let parsed: serde_json::Value = + serde_json::from_str(input).unwrap_or(serde_json::Value::String(input.to_string())); + match name { + "bash" | "Bash" => { + let command = parsed + .get("command") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + if command.is_empty() { + "running shell command".to_string() + } else { + format!("command {}", truncate_for_summary(command.trim(), 100)) + } + } + "read_file" | "Read" => format!("reading {}", extract_tool_path(&parsed)), + "write_file" | "Write" => format!("writing {}", extract_tool_path(&parsed)), + "edit_file" | "Edit" => format!("editing {}", extract_tool_path(&parsed)), + "glob_search" | "Glob" => { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("glob `{pattern}` in {scope}") + } + "grep_search" | "Grep" => { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("grep `{pattern}` in {scope}") + } + "web_search" | "WebSearch" => parsed + .get("query") + .and_then(|value| value.as_str()) + .map_or_else( + || "running web search".to_string(), + |query| format!("query {}", truncate_for_summary(query, 100)), + ), + _ => { + let summary = summarize_tool_payload(input); + if summary.is_empty() { + format!("running {name}") + } else { + format!("{name}: {summary}") + 
} + } + } +} + +#[allow(clippy::needless_pass_by_value)] +#[allow(clippy::too_many_arguments)] +fn build_runtime( + session: Session, + model: String, + system_prompt: Vec<String>, + enable_tools: bool, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + permission_mode: PermissionMode, + progress_reporter: Option<InternalPromptProgressReporter>, +) -> Result<ConversationRuntime<DefaultRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>> +{ + let (feature_config, tool_registry) = build_runtime_plugin_state()?; + Ok(ConversationRuntime::new_with_features( + session, + DefaultRuntimeClient::new( + model, + enable_tools, + emit_output, + allowed_tools.clone(), + tool_registry.clone(), + progress_reporter, + )?, + CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()), + permission_policy(permission_mode, &tool_registry), + system_prompt, + feature_config, + )) +} + +struct CliPermissionPrompter { + current_mode: PermissionMode, +} + +impl CliPermissionPrompter { + fn new(current_mode: PermissionMode) -> Self { + Self { current_mode } + } +} + +impl runtime::PermissionPrompter for CliPermissionPrompter { + fn decide( + &mut self, + request: &runtime::PermissionRequest, + ) -> runtime::PermissionPromptDecision { + println!(); + println!("Permission approval required"); + println!(" Tool {}", request.tool_name); + println!(" Current mode {}", self.current_mode.as_str()); + println!(" Required mode {}", request.required_mode.as_str()); + println!(" Input {}", request.input); + print!("Approve this tool call? 
[y/N]: "); + let _ = io::stdout().flush(); + + let mut response = String::new(); + match io::stdin().read_line(&mut response) { + Ok(_) => { + let normalized = response.trim().to_ascii_lowercase(); + if matches!(normalized.as_str(), "y" | "yes") { + runtime::PermissionPromptDecision::Allow + } else { + runtime::PermissionPromptDecision::Deny { + reason: format!( + "tool '{}' denied by user approval prompt", + request.tool_name + ), + } + } + } + Err(error) => runtime::PermissionPromptDecision::Deny { + reason: format!("permission approval failed: {error}"), + }, + } + } +} + +struct DefaultRuntimeClient { + runtime: tokio::runtime::Runtime, + client: ClawApiClient, + model: String, + enable_tools: bool, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + tool_registry: GlobalToolRegistry, + progress_reporter: Option<InternalPromptProgressReporter>, +} + +impl DefaultRuntimeClient { + fn new( + model: String, + enable_tools: bool, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + tool_registry: GlobalToolRegistry, + progress_reporter: Option<InternalPromptProgressReporter>, + ) -> Result<Self, Box<dyn std::error::Error>> { + Ok(Self { + runtime: tokio::runtime::Runtime::new()?, + client: ClawApiClient::from_auth(resolve_cli_auth_source()?) + .with_base_url(api::read_base_url()), + model, + enable_tools, + emit_output, + allowed_tools, + tool_registry, + progress_reporter, + }) + } +} + +fn resolve_cli_auth_source() -> Result<AuthSource, Box<dyn std::error::Error>> { + Ok(resolve_startup_auth_source(|| { + let cwd = env::current_dir().map_err(api::ApiError::from)?; + let config = ConfigLoader::default_for(&cwd).load().map_err(|error| { + api::ApiError::Auth(format!("failed to load runtime OAuth config: {error}")) + })?; + Ok(config.oauth().cloned()) + })?) 
+} + +impl ApiClient for DefaultRuntimeClient { + #[allow(clippy::too_many_lines)] + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_model_phase(); + } + let message_request = MessageRequest { + model: self.model.clone(), + max_tokens: max_tokens_for_model(&self.model), + messages: convert_messages(&request.messages), + system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), + tools: self + .enable_tools + .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), + tool_choice: self.enable_tools.then_some(ToolChoice::Auto), + stream: true, + }; + + self.runtime.block_on(async { + let mut stream = self + .client + .stream_message(&message_request) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + let mut stdout = io::stdout(); + let mut sink = io::sink(); + let out: &mut dyn Write = if self.emit_output { + &mut stdout + } else { + &mut sink + }; + let renderer = TerminalRenderer::new(); + let mut markdown_stream = MarkdownStreamState::default(); + let mut events = Vec::new(); + let mut pending_tool: Option<(String, String, String)> = None; + let mut saw_stop = false; + + while let Some(event) = stream + .next_event() + .await + .map_err(|error| RuntimeError::new(error.to_string()))? 
+ { + match event { + ApiStreamEvent::MessageStart(start) => { + for block in start.message.content { + push_output_block(block, out, &mut events, &mut pending_tool, true)?; + } + } + ApiStreamEvent::ContentBlockStart(start) => { + push_output_block( + start.content_block, + out, + &mut events, + &mut pending_tool, + true, + )?; + } + ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { + ContentBlockDelta::TextDelta { text } => { + if !text.is_empty() { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_text_phase(&text); + } + if let Some(rendered) = markdown_stream.push(&renderer, &text) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + events.push(AssistantEvent::TextDelta(text)); + } + } + ContentBlockDelta::InputJsonDelta { partial_json } => { + if let Some((_, _, input)) = &mut pending_tool { + input.push_str(&partial_json); + } + } + ContentBlockDelta::ThinkingDelta { .. } + | ContentBlockDelta::SignatureDelta { .. 
} => {} + }, + ApiStreamEvent::ContentBlockStop(_) => { + if let Some(rendered) = markdown_stream.flush(&renderer) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + if let Some((id, name, input)) = pending_tool.take() { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_tool_phase(&name, &input); + } + // Display tool call now that input is fully accumulated + writeln!(out, "\n{}", format_tool_call_start(&name, &input)) + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + ApiStreamEvent::MessageDelta(delta) => { + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: delta.usage.input_tokens, + output_tokens: delta.usage.output_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + })); + } + ApiStreamEvent::MessageStop(_) => { + saw_stop = true; + if let Some(rendered) = markdown_stream.flush(&renderer) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + events.push(AssistantEvent::MessageStop); + } + } + } + + if !saw_stop + && events.iter().any(|event| { + matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) + || matches!(event, AssistantEvent::ToolUse { .. 
}) + }) + { + events.push(AssistantEvent::MessageStop); + } + + if events + .iter() + .any(|event| matches!(event, AssistantEvent::MessageStop)) + { + return Ok(events); + } + + let response = self + .client + .send_message(&MessageRequest { + stream: false, + ..message_request.clone() + }) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + response_to_events(response, out) + }) + } +} + +fn final_assistant_text(summary: &runtime::TurnSummary) -> String { + summary + .assistant_messages + .last() + .map(|message| { + message + .blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::<Vec<_>>() + .join("") + }) + .unwrap_or_default() +} + +fn collect_tool_uses(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> { + summary + .assistant_messages + .iter() + .flat_map(|message| message.blocks.iter()) + .filter_map(|block| match block { + ContentBlock::ToolUse { id, name, input } => Some(json!({ + "id": id, + "name": name, + "input": input, + })), + _ => None, + }) + .collect() +} + +fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> { + summary + .tool_results + .iter() + .flat_map(|message| message.blocks.iter()) + .filter_map(|block| match block { + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } => Some(json!({ + "tool_use_id": tool_use_id, + "tool_name": tool_name, + "output": output, + "is_error": is_error, + })), + _ => None, + }) + .collect() +} + +fn slash_command_completion_candidates() -> Vec<String> { + let mut candidates = slash_command_specs() + .iter() + .flat_map(|spec| { + std::iter::once(spec.name) + .chain(spec.aliases.iter().copied()) + .map(|name| format!("/{name}")) + .collect::<Vec<_>>() + }) + .collect::<Vec<_>>(); + candidates.extend([ + String::from("/vim"), + String::from("/exit"), + String::from("/quit"), + ]); + candidates.sort(); + candidates.dedup(); + candidates +} + +fn 
/// Levenshtein distance between two strings, measured in Unicode scalar
/// values (characters, not bytes).
///
/// Uses the single-row dynamic-programming formulation: one working row plus
/// a saved diagonal cell, so memory stays O(len(right)) regardless of input
/// length. Equal inputs and empty inputs are answered without running the DP.
fn edit_distance(left: &str, right: &str) -> usize {
    if left == right {
        return 0;
    }
    let target: Vec<char> = right.chars().collect();
    if left.is_empty() {
        return target.len();
    }
    if target.is_empty() {
        return left.chars().count();
    }

    // row[j] holds the distance between the processed prefix of `left`
    // and the first j characters of `right`.
    let mut row: Vec<usize> = (0..=target.len()).collect();
    for (row_index, source_char) in left.chars().enumerate() {
        // `diagonal` is the value row[j] had on the previous iteration,
        // i.e. the cell up-and-left of the one being computed.
        let mut diagonal = row[0];
        row[0] = row_index + 1;
        for (col_index, &target_char) in target.iter().enumerate() {
            let substitution = diagonal + usize::from(source_char != target_char);
            let insertion = row[col_index] + 1;
            let deletion = row[col_index + 1] + 1;
            diagonal = row[col_index + 1];
            row[col_index + 1] = substitution.min(insertion).min(deletion);
        }
    }

    row[target.len()]
}
"write_file" | "Write" => { + let path = extract_tool_path(&parsed); + let lines = parsed + .get("content") + .and_then(|value| value.as_str()) + .map_or(0, |content| content.lines().count()); + format!("\x1b[1;32m✏️ Writing {path}\x1b[0m \x1b[2m({lines} lines)\x1b[0m") + } + "edit_file" | "Edit" => { + let path = extract_tool_path(&parsed); + let old_value = parsed + .get("old_string") + .or_else(|| parsed.get("oldString")) + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let new_value = parsed + .get("new_string") + .or_else(|| parsed.get("newString")) + .and_then(|value| value.as_str()) + .unwrap_or_default(); + format!( + "\x1b[1;33m📝 Editing {path}\x1b[0m{}", + format_patch_preview(old_value, new_value) + .map(|preview| format!("\n{preview}")) + .unwrap_or_default() + ) + } + "glob_search" | "Glob" => format_search_start("🔎 Glob", &parsed), + "grep_search" | "Grep" => format_search_start("🔎 Grep", &parsed), + "web_search" | "WebSearch" => parsed + .get("query") + .and_then(|value| value.as_str()) + .unwrap_or("?") + .to_string(), + _ => summarize_tool_payload(input), + }; + + let border = "─".repeat(name.len() + 8); + format!( + "\x1b[38;5;245m╭─ \x1b[1;36m{name}\x1b[0;38;5;245m ─╮\x1b[0m\n\x1b[38;5;245m│\x1b[0m {detail}\n\x1b[38;5;245m╰{border}╯\x1b[0m" + ) +} + +fn format_tool_result(name: &str, output: &str, is_error: bool) -> String { + let icon = if is_error { + "\x1b[1;31m✗\x1b[0m" + } else { + "\x1b[1;32m✓\x1b[0m" + }; + if is_error { + let summary = truncate_for_summary(output.trim(), 160); + return if summary.is_empty() { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m") + } else { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m\n\x1b[38;5;203m{summary}\x1b[0m") + }; + } + + let parsed: serde_json::Value = + serde_json::from_str(output).unwrap_or(serde_json::Value::String(output.to_string())); + match name { + "bash" | "Bash" => format_bash_result(icon, &parsed), + "read_file" | "Read" => format_read_result(icon, &parsed), + "write_file" | 
/// Returns the first line of `text` containing non-whitespace content,
/// falling back to the full input when every line is blank (or the input
/// has no lines at all).
fn first_visible_line(text: &str) -> &str {
    for line in text.lines() {
        if !line.trim().is_empty() {
            return line;
        }
    }
    // All lines were blank: hand back the whole input unchanged.
    text
}
format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { + let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; + if let Some(task_id) = parsed + .get("backgroundTaskId") + .and_then(|value| value.as_str()) + { + write!(&mut lines[0], " backgrounded ({task_id})").expect("write to string"); + } else if let Some(status) = parsed + .get("returnCodeInterpretation") + .and_then(|value| value.as_str()) + .filter(|status| !status.is_empty()) + { + write!(&mut lines[0], " {status}").expect("write to string"); + } + + if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { + if !stdout.trim().is_empty() { + lines.push(truncate_output_for_display( + stdout, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + )); + } + } + if let Some(stderr) = parsed.get("stderr").and_then(|value| value.as_str()) { + if !stderr.trim().is_empty() { + lines.push(format!( + "\x1b[38;5;203m{}\x1b[0m", + truncate_output_for_display( + stderr, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ) + )); + } + } + + lines.join("\n\n") +} + +fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { + let file = parsed.get("file").unwrap_or(parsed); + let path = extract_tool_path(file); + let start_line = file + .get("startLine") + .and_then(serde_json::Value::as_u64) + .unwrap_or(1); + let num_lines = file + .get("numLines") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let total_lines = file + .get("totalLines") + .and_then(serde_json::Value::as_u64) + .unwrap_or(num_lines); + let content = file + .get("content") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let end_line = start_line.saturating_add(num_lines.saturating_sub(1)); + + format!( + "{icon} \x1b[2m📄 Read {path} (lines {}-{} of {})\x1b[0m\n{}", + start_line, + end_line.max(start_line), + total_lines, + truncate_output_for_display(content, READ_DISPLAY_MAX_LINES, READ_DISPLAY_MAX_CHARS) + ) +} + +fn 
format_write_result(icon: &str, parsed: &serde_json::Value) -> String { + let path = extract_tool_path(parsed); + let kind = parsed + .get("type") + .and_then(|value| value.as_str()) + .unwrap_or("write"); + let line_count = parsed + .get("content") + .and_then(|value| value.as_str()) + .map_or(0, |content| content.lines().count()); + format!( + "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", + if kind == "create" { "Wrote" } else { "Updated" }, + ) +} + +fn format_structured_patch_preview(parsed: &serde_json::Value) -> Option<String> { + let hunks = parsed.get("structuredPatch")?.as_array()?; + let mut preview = Vec::new(); + for hunk in hunks.iter().take(2) { + let lines = hunk.get("lines")?.as_array()?; + for line in lines.iter().filter_map(|value| value.as_str()).take(6) { + match line.chars().next() { + Some('+') => preview.push(format!("\x1b[38;5;70m{line}\x1b[0m")), + Some('-') => preview.push(format!("\x1b[38;5;203m{line}\x1b[0m")), + _ => preview.push(line.to_string()), + } + } + } + if preview.is_empty() { + None + } else { + Some(preview.join("\n")) + } +} + +fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { + let path = extract_tool_path(parsed); + let suffix = if parsed + .get("replaceAll") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + { + " (replace all)" + } else { + "" + }; + let preview = format_structured_patch_preview(parsed).or_else(|| { + let old_value = parsed + .get("oldString") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let new_value = parsed + .get("newString") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + format_patch_preview(old_value, new_value) + }); + + match preview { + Some(preview) => format!("{icon} \x1b[1;33m📝 Edited {path}{suffix}\x1b[0m\n{preview}"), + None => format!("{icon} \x1b[1;33m📝 Edited {path}{suffix}\x1b[0m"), + } +} + +fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { + let num_files = parsed + 
.get("numFiles") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let filenames = parsed + .get("filenames") + .and_then(|value| value.as_array()) + .map(|files| { + files + .iter() + .filter_map(|value| value.as_str()) + .take(8) + .collect::<Vec<_>>() + .join("\n") + }) + .unwrap_or_default(); + if filenames.is_empty() { + format!("{icon} \x1b[38;5;245mglob_search\x1b[0m matched {num_files} files") + } else { + format!("{icon} \x1b[38;5;245mglob_search\x1b[0m matched {num_files} files\n{filenames}") + } +} + +fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { + let num_matches = parsed + .get("numMatches") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let num_files = parsed + .get("numFiles") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let content = parsed + .get("content") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let filenames = parsed + .get("filenames") + .and_then(|value| value.as_array()) + .map(|files| { + files + .iter() + .filter_map(|value| value.as_str()) + .take(8) + .collect::<Vec<_>>() + .join("\n") + }) + .unwrap_or_default(); + let summary = format!( + "{icon} \x1b[38;5;245mgrep_search\x1b[0m {num_matches} matches across {num_files} files" + ); + if !content.trim().is_empty() { + format!( + "{summary}\n{}", + truncate_output_for_display( + content, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ) + ) + } else if !filenames.is_empty() { + format!("{summary}\n{filenames}") + } else { + summary + } +} + +fn format_generic_tool_result(icon: &str, name: &str, parsed: &serde_json::Value) -> String { + let rendered_output = match parsed { + serde_json::Value::String(text) => text.clone(), + serde_json::Value::Null => String::new(), + serde_json::Value::Object(_) | serde_json::Value::Array(_) => { + serde_json::to_string_pretty(parsed).unwrap_or_else(|_| parsed.to_string()) + } + _ => parsed.to_string(), + }; + let preview = truncate_output_for_display( + 
/// Truncates `value` to at most `limit` characters, appending an ellipsis
/// when anything was cut. Counts Unicode scalar values rather than bytes,
/// so multi-byte UTF-8 input is never split mid-character.
fn truncate_for_summary(value: &str, limit: usize) -> String {
    match value.char_indices().nth(limit) {
        // A character exists at index `limit`, so the input exceeds the
        // budget: keep the byte-aligned prefix and mark the cut.
        Some((byte_offset, _)) => format!("{}…", &value[..byte_offset]),
        // Fewer than `limit + 1` characters: nothing to cut.
        None => value.to_string(),
    }
}
Vec<AssistantEvent>, + pending_tool: &mut Option<(String, String, String)>, + streaming_tool_input: bool, +) -> Result<(), RuntimeError> { + match block { + OutputContentBlock::Text { text } => { + if !text.is_empty() { + let rendered = TerminalRenderer::new().markdown_to_ansi(&text); + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + events.push(AssistantEvent::TextDelta(text)); + } + } + OutputContentBlock::ToolUse { id, name, input } => { + // During streaming, the initial content_block_start has an empty input ({}). + // The real input arrives via input_json_delta events. In + // non-streaming responses, preserve a legitimate empty object. + let initial_input = if streaming_tool_input + && input.is_object() + && input.as_object().is_some_and(serde_json::Map::is_empty) + { + String::new() + } else { + input.to_string() + }; + *pending_tool = Some((id, name, initial_input)); + } + OutputContentBlock::Thinking { .. } | OutputContentBlock::RedactedThinking { .. 
} => {} + } + Ok(()) +} + +fn response_to_events( + response: MessageResponse, + out: &mut (impl Write + ?Sized), +) -> Result<Vec<AssistantEvent>, RuntimeError> { + let mut events = Vec::new(); + let mut pending_tool = None; + + for block in response.content { + push_output_block(block, out, &mut events, &mut pending_tool, false)?; + if let Some((id, name, input)) = pending_tool.take() { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: response.usage.input_tokens, + output_tokens: response.usage.output_tokens, + cache_creation_input_tokens: response.usage.cache_creation_input_tokens, + cache_read_input_tokens: response.usage.cache_read_input_tokens, + })); + events.push(AssistantEvent::MessageStop); + Ok(events) +} + +struct CliToolExecutor { + renderer: TerminalRenderer, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + tool_registry: GlobalToolRegistry, +} + +impl CliToolExecutor { + fn new( + allowed_tools: Option<AllowedToolSet>, + emit_output: bool, + tool_registry: GlobalToolRegistry, + ) -> Self { + Self { + renderer: TerminalRenderer::new(), + emit_output, + allowed_tools, + tool_registry, + } + } +} + +impl ToolExecutor for CliToolExecutor { + fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> { + if self + .allowed_tools + .as_ref() + .is_some_and(|allowed| !allowed.contains(tool_name)) + { + return Err(ToolError::new(format!( + "tool `{tool_name}` is not enabled by the current --allowedTools setting" + ))); + } + let value = serde_json::from_str(input) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + match self.tool_registry.execute(tool_name, &value) { + Ok(output) => { + if self.emit_output { + let markdown = format_tool_result(tool_name, &output, false); + self.renderer + .stream_markdown(&markdown, &mut io::stdout()) + .map_err(|error| ToolError::new(error.to_string()))?; + } + 
Ok(output) + } + Err(error) => { + if self.emit_output { + let markdown = format_tool_result(tool_name, &error, true); + self.renderer + .stream_markdown(&markdown, &mut io::stdout()) + .map_err(|stream_error| ToolError::new(stream_error.to_string()))?; + } + Err(ToolError::new(error)) + } + } + } +} + +fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy { + tool_registry.permission_specs(None).into_iter().fold( + PermissionPolicy::new(mode), + |policy, (name, required_permission)| { + policy.with_tool_requirement(name, required_permission) + }, + ) +} + +fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> { + messages + .iter() + .filter_map(|message| { + let role = match message.role { + MessageRole::System | MessageRole::User | MessageRole::Tool => "user", + MessageRole::Assistant => "assistant", + }; + let content = message + .blocks + .iter() + .map(|block| match block { + ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() }, + ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: serde_json::from_str(input) + .unwrap_or_else(|_| serde_json::json!({ "raw": input })), + }, + ContentBlock::ToolResult { + tool_use_id, + output, + is_error, + .. + } => InputContentBlock::ToolResult { + tool_use_id: tool_use_id.clone(), + content: vec![ToolResultContentBlock::Text { + text: output.clone(), + }], + is_error: *is_error, + }, + }) + .collect::<Vec<_>>(); + (!content.is_empty()).then(|| InputMessage { + role: role.to_string(), + content, + }) + }) + .collect() +} + +fn print_help_to(out: &mut impl Write) -> io::Result<()> { + writeln!(out, "Claw Code CLI v{VERSION}")?; + writeln!( + out, + " Interactive coding assistant for the current workspace." 
+ )?; + writeln!(out)?; + writeln!(out, "Quick start")?; + writeln!( + out, + " claw Start the interactive REPL" + )?; + writeln!( + out, + " claw \"summarize this repo\" Run one prompt and exit" + )?; + writeln!( + out, + " claw prompt \"explain src/main.rs\" Explicit one-shot prompt" + )?; + writeln!( + out, + " claw --resume SESSION.json /status Inspect a saved session" + )?; + writeln!(out)?; + writeln!(out, "Interactive essentials")?; + writeln!( + out, + " /help Browse the full slash command map" + )?; + writeln!( + out, + " /status Inspect session + workspace state" + )?; + writeln!( + out, + " /model <name> Switch models mid-session" + )?; + writeln!( + out, + " /permissions <mode> Adjust tool access" + )?; + writeln!( + out, + " Tab Complete slash commands" + )?; + writeln!( + out, + " /vim Toggle modal editing" + )?; + writeln!( + out, + " Shift+Enter / Ctrl+J Insert a newline" + )?; + writeln!(out)?; + writeln!(out, "Commands")?; + writeln!( + out, + " claw dump-manifests Read upstream TS sources and print extracted counts" + )?; + writeln!( + out, + " claw bootstrap-plan Print the bootstrap phase skeleton" + )?; + writeln!( + out, + " claw agents List configured agents" + )?; + writeln!( + out, + " claw skills List installed skills" + )?; + writeln!(out, " claw system-prompt [--cwd PATH] [--date YYYY-MM-DD]")?; + writeln!( + out, + " claw login Start the OAuth login flow" + )?; + writeln!( + out, + " claw logout Clear saved OAuth credentials" + )?; + writeln!( + out, + " claw init Scaffold CLAW.md + local files" + )?; + writeln!(out)?; + writeln!(out, "Flags")?; + writeln!( + out, + " --model MODEL Override the active model" + )?; + writeln!( + out, + " --output-format FORMAT Non-interactive output: text or json" + )?; + writeln!( + out, + " --permission-mode MODE Set read-only, workspace-write, or danger-full-access" + )?; + writeln!( + out, + " --dangerously-skip-permissions Skip all permission checks" + )?; + writeln!( + out, + " --allowedTools TOOLS 
Restrict enabled tools (repeatable; comma-separated aliases supported)" + )?; + writeln!( + out, + " --version, -V Print version and build information" + )?; + writeln!(out)?; + writeln!(out, "Slash command reference")?; + writeln!(out, "{}", render_slash_command_help())?; + writeln!(out)?; + let resume_commands = resume_supported_slash_commands() + .into_iter() + .map(|spec| match spec.argument_hint { + Some(argument_hint) => format!("/{} {}", spec.name, argument_hint), + None => format!("/{}", spec.name), + }) + .collect::<Vec<_>>() + .join(", "); + writeln!(out, "Resume-safe commands: {resume_commands}")?; + writeln!(out, "Examples")?; + writeln!(out, " claw --model opus \"summarize this repo\"")?; + writeln!( + out, + " claw --output-format json prompt \"explain src/main.rs\"" + )?; + writeln!( + out, + " claw --allowedTools read,glob \"summarize Cargo.toml\"" + )?; + writeln!( + out, + " claw --resume session.json /status /diff /export notes.txt" + )?; + writeln!(out, " claw agents")?; + writeln!(out, " claw /skills")?; + writeln!(out, " claw login")?; + writeln!(out, " claw init")?; + Ok(()) +} + +fn print_help() { + let _ = print_help_to(&mut io::stdout()); +} + +#[cfg(test)] +mod tests { + use super::{ + describe_tool_progress, filter_tool_specs, format_compact_report, format_cost_report, + format_internal_prompt_progress_line, format_model_report, format_model_switch_report, + format_permissions_report, format_permissions_switch_report, format_resume_report, + format_status_report, format_tool_call_start, format_tool_result, + normalize_permission_mode, parse_args, parse_git_status_metadata, permission_policy, + print_help_to, push_output_block, render_config_report, render_memory_report, + render_repl_help, render_unknown_repl_command, resolve_model_alias, response_to_events, + resume_supported_slash_commands, slash_command_completion_candidates, status_context, + CliAction, CliOutputFormat, InternalPromptProgressEvent, InternalPromptProgressState, + 
SlashCommand, StatusUsage, DEFAULT_MODEL, + }; + use api::{MessageResponse, OutputContentBlock, Usage}; + use plugins::{PluginTool, PluginToolDefinition, PluginToolPermission}; + use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode}; + use serde_json::json; + use std::path::PathBuf; + use std::time::Duration; + use tools::GlobalToolRegistry; + + fn registry_with_plugin_tool() -> GlobalToolRegistry { + GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( + "plugin-demo@external", + "plugin-demo", + PluginToolDefinition { + name: "plugin_echo".to_string(), + description: Some("Echo plugin payload".to_string()), + input_schema: json!({ + "type": "object", + "properties": { + "message": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + }), + }, + "echo".to_string(), + Vec::new(), + PluginToolPermission::WorkspaceWrite, + None, + )]) + .expect("plugin tool registry should build") + } + + #[test] + fn defaults_to_repl_when_no_args() { + assert_eq!( + parse_args(&[]).expect("args should parse"), + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn parses_prompt_subcommand() { + let args = vec![ + "prompt".to_string(), + "hello".to_string(), + "world".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Prompt { + prompt: "hello world".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn parses_bare_prompt_and_json_output_flag() { + let args = vec![ + "--output-format=json".to_string(), + "--model".to_string(), + "custom-opus".to_string(), + "explain".to_string(), + "this".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Prompt { + prompt: "explain 
this".to_string(), + model: "custom-opus".to_string(), + output_format: CliOutputFormat::Json, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn resolves_model_aliases_in_args() { + let args = vec![ + "--model".to_string(), + "opus".to_string(), + "explain".to_string(), + "this".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Prompt { + prompt: "explain this".to_string(), + model: "claude-opus-4-6".to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn resolves_known_model_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6"); + assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213"); + assert_eq!(resolve_model_alias("custom-opus"), "custom-opus"); + } + + #[test] + fn parses_version_flags_without_initializing_prompt_mode() { + assert_eq!( + parse_args(&["--version".to_string()]).expect("args should parse"), + CliAction::Version + ); + assert_eq!( + parse_args(&["-V".to_string()]).expect("args should parse"), + CliAction::Version + ); + } + + #[test] + fn parses_permission_mode_flag() { + let args = vec!["--permission-mode=read-only".to_string()]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: None, + permission_mode: PermissionMode::ReadOnly, + } + ); + } + + #[test] + fn parses_allowed_tools_flags_with_aliases_and_lists() { + let args = vec![ + "--allowedTools".to_string(), + "read,glob".to_string(), + "--allowed-tools=write_file".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: Some( + ["glob_search", "read_file", "write_file"] + .into_iter() + .map(str::to_string) + 
.collect() + ), + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn rejects_unknown_allowed_tools() { + let error = parse_args(&["--allowedTools".to_string(), "teleport".to_string()]) + .expect_err("tool should be rejected"); + assert!(error.contains("unsupported tool in --allowedTools: teleport")); + } + + #[test] + fn parses_system_prompt_options() { + let args = vec![ + "system-prompt".to_string(), + "--cwd".to_string(), + "/tmp/project".to_string(), + "--date".to_string(), + "2026-04-01".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::PrintSystemPrompt { + cwd: PathBuf::from("/tmp/project"), + date: "2026-04-01".to_string(), + } + ); + } + + #[test] + fn parses_login_and_logout_subcommands() { + assert_eq!( + parse_args(&["login".to_string()]).expect("login should parse"), + CliAction::Login + ); + assert_eq!( + parse_args(&["logout".to_string()]).expect("logout should parse"), + CliAction::Logout + ); + assert_eq!( + parse_args(&["init".to_string()]).expect("init should parse"), + CliAction::Init + ); + assert_eq!( + parse_args(&["agents".to_string()]).expect("agents should parse"), + CliAction::Agents { args: None } + ); + assert_eq!( + parse_args(&["skills".to_string()]).expect("skills should parse"), + CliAction::Skills { args: None } + ); + assert_eq!( + parse_args(&["agents".to_string(), "--help".to_string()]) + .expect("agents help should parse"), + CliAction::Agents { + args: Some("--help".to_string()) + } + ); + } + + #[test] + fn parses_direct_agents_and_skills_slash_commands() { + assert_eq!( + parse_args(&["/agents".to_string()]).expect("/agents should parse"), + CliAction::Agents { args: None } + ); + assert_eq!( + parse_args(&["/skills".to_string()]).expect("/skills should parse"), + CliAction::Skills { args: None } + ); + assert_eq!( + parse_args(&["/skills".to_string(), "help".to_string()]) + .expect("/skills help should parse"), + CliAction::Skills { + args: 
Some("help".to_string()) + } + ); + let error = parse_args(&["/status".to_string()]) + .expect_err("/status should remain REPL-only when invoked directly"); + assert!(error.contains("Direct slash command unavailable")); + assert!(error.contains("/status")); + } + + #[test] + fn parses_resume_flag_with_slash_command() { + let args = vec![ + "--resume".to_string(), + "session.json".to_string(), + "/compact".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("session.json"), + commands: vec!["/compact".to_string()], + } + ); + } + + #[test] + fn parses_resume_flag_with_multiple_slash_commands() { + let args = vec![ + "--resume".to_string(), + "session.json".to_string(), + "/status".to_string(), + "/compact".to_string(), + "/cost".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("session.json"), + commands: vec![ + "/status".to_string(), + "/compact".to_string(), + "/cost".to_string(), + ], + } + ); + } + + #[test] + fn filtered_tool_specs_respect_allowlist() { + let allowed = ["read_file", "grep_search"] + .into_iter() + .map(str::to_string) + .collect(); + let filtered = filter_tool_specs(&GlobalToolRegistry::builtin(), Some(&allowed)); + let names = filtered + .into_iter() + .map(|spec| spec.name) + .collect::<Vec<_>>(); + assert_eq!(names, vec!["read_file", "grep_search"]); + } + + #[test] + fn filtered_tool_specs_include_plugin_tools() { + let filtered = filter_tool_specs(®istry_with_plugin_tool(), None); + let names = filtered + .into_iter() + .map(|definition| definition.name) + .collect::<Vec<_>>(); + assert!(names.contains(&"bash".to_string())); + assert!(names.contains(&"plugin_echo".to_string())); + } + + #[test] + fn permission_policy_uses_plugin_tool_permissions() { + let policy = permission_policy(PermissionMode::ReadOnly, ®istry_with_plugin_tool()); + let required = 
policy.required_mode_for("plugin_echo"); + assert_eq!(required, PermissionMode::WorkspaceWrite); + } + + #[test] + fn shared_help_uses_resume_annotation_copy() { + let help = commands::render_slash_command_help(); + assert!(help.contains("Slash commands")); + assert!(help.contains("Tab completes commands inside the REPL.")); + assert!(help.contains("available via claw --resume SESSION.json")); + } + + #[test] + fn repl_help_includes_shared_commands_and_exit() { + let help = render_repl_help(); + assert!(help.contains("Interactive REPL")); + assert!(help.contains("/help")); + assert!(help.contains("/status")); + assert!(help.contains("/model [model]")); + assert!(help.contains("/permissions [read-only|workspace-write|danger-full-access]")); + assert!(help.contains("/clear [--confirm]")); + assert!(help.contains("/cost")); + assert!(help.contains("/resume <session-path>")); + assert!(help.contains("/config [env|hooks|model|plugins]")); + assert!(help.contains("/memory")); + assert!(help.contains("/init")); + assert!(help.contains("/diff")); + assert!(help.contains("/version")); + assert!(help.contains("/export [file]")); + assert!(help.contains("/session [list|switch <session-id>]")); + assert!(help.contains( + "/plugin [list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]" + )); + assert!(help.contains("aliases: /plugins, /marketplace")); + assert!(help.contains("/agents")); + assert!(help.contains("/skills")); + assert!(help.contains("/exit")); + assert!(help.contains("Tab cycles slash command matches")); + } + + #[test] + fn completion_candidates_include_repl_only_exit_commands() { + let candidates = slash_command_completion_candidates(); + assert!(candidates.contains(&"/help".to_string())); + assert!(candidates.contains(&"/vim".to_string())); + assert!(candidates.contains(&"/exit".to_string())); + assert!(candidates.contains(&"/quit".to_string())); + } + + #[test] + fn unknown_repl_command_suggestions_include_repl_shortcuts() { + let 
rendered = render_unknown_repl_command("exi"); + assert!(rendered.contains("Unknown slash command")); + assert!(rendered.contains("/exit")); + assert!(rendered.contains("/help")); + } + + #[test] + fn resume_supported_command_list_matches_expected_surface() { + let names = resume_supported_slash_commands() + .into_iter() + .map(|spec| spec.name) + .collect::<Vec<_>>(); + assert_eq!( + names, + vec![ + "help", "status", "compact", "clear", "cost", "config", "memory", "init", "diff", + "version", "export", "agents", "skills", + ] + ); + } + + #[test] + fn resume_report_uses_sectioned_layout() { + let report = format_resume_report("session.json", 14, 6); + assert!(report.contains("Session resumed")); + assert!(report.contains("Session file session.json")); + assert!(report.contains("History 14 messages · 6 turns")); + assert!(report.contains("/status · /diff · /export")); + } + + #[test] + fn compact_report_uses_structured_output() { + let compacted = format_compact_report(8, 5, false); + assert!(compacted.contains("Compact")); + assert!(compacted.contains("Result compacted")); + assert!(compacted.contains("Messages removed 8")); + assert!(compacted.contains("Use /status")); + let skipped = format_compact_report(0, 3, true); + assert!(skipped.contains("Result skipped")); + } + + #[test] + fn cost_report_uses_sectioned_layout() { + let report = format_cost_report(runtime::TokenUsage { + input_tokens: 20, + output_tokens: 8, + cache_creation_input_tokens: 3, + cache_read_input_tokens: 1, + }); + assert!(report.contains("Cost")); + assert!(report.contains("Input tokens 20")); + assert!(report.contains("Output tokens 8")); + assert!(report.contains("Cache create 3")); + assert!(report.contains("Cache read 1")); + assert!(report.contains("Total tokens 32")); + assert!(report.contains("/compact")); + } + + #[test] + fn permissions_report_uses_sectioned_layout() { + let report = format_permissions_report("workspace-write"); + assert!(report.contains("Permissions")); + 
assert!(report.contains("Active mode workspace-write")); + assert!(report.contains("Effect Editing tools can modify files in the workspace")); + assert!(report.contains("Modes")); + assert!(report.contains("read-only ○ available Read/search tools only")); + assert!(report.contains("workspace-write ● current Edit files inside the workspace")); + assert!(report.contains("danger-full-access ○ available Unrestricted tool access")); + } + + #[test] + fn permissions_switch_report_is_structured() { + let report = format_permissions_switch_report("read-only", "workspace-write"); + assert!(report.contains("Permissions updated")); + assert!(report.contains("Previous mode read-only")); + assert!(report.contains("Active mode workspace-write")); + assert!(report.contains("Applies to Subsequent tool calls in this REPL")); + } + + #[test] + fn init_help_mentions_direct_subcommand() { + let mut help = Vec::new(); + print_help_to(&mut help).expect("help should render"); + let help = String::from_utf8(help).expect("help should be utf8"); + assert!(help.contains("claw init")); + assert!(help.contains("claw agents")); + assert!(help.contains("claw skills")); + assert!(help.contains("claw /skills")); + } + + #[test] + fn model_report_uses_sectioned_layout() { + let report = format_model_report("sonnet", 12, 4); + assert!(report.contains("Model")); + assert!(report.contains("Current sonnet")); + assert!(report.contains("Session 12 messages · 4 turns")); + assert!(report.contains("Aliases")); + assert!(report.contains("/model <name> Switch models for this REPL session")); + } + + #[test] + fn model_switch_report_preserves_context_summary() { + let report = format_model_switch_report("sonnet", "opus", 9); + assert!(report.contains("Model updated")); + assert!(report.contains("Previous sonnet")); + assert!(report.contains("Current opus")); + assert!(report.contains("Preserved 9 messages")); + } + + #[test] + fn status_line_reports_model_and_token_totals() { + let status = 
format_status_report( + "sonnet", + StatusUsage { + message_count: 7, + turns: 3, + latest: runtime::TokenUsage { + input_tokens: 5, + output_tokens: 4, + cache_creation_input_tokens: 1, + cache_read_input_tokens: 0, + }, + cumulative: runtime::TokenUsage { + input_tokens: 20, + output_tokens: 8, + cache_creation_input_tokens: 2, + cache_read_input_tokens: 1, + }, + estimated_tokens: 128, + }, + "workspace-write", + &super::StatusContext { + cwd: PathBuf::from("/tmp/project"), + session_path: Some(PathBuf::from("session.json")), + loaded_config_files: 2, + discovered_config_files: 3, + memory_file_count: 4, + project_root: Some(PathBuf::from("/tmp")), + git_branch: Some("main".to_string()), + }, + ); + assert!(status.contains("Session")); + assert!(status.contains("Model sonnet")); + assert!(status.contains("Permissions workspace-write")); + assert!(status.contains("Activity 7 messages · 3 turns")); + assert!(status.contains("Tokens est 128 · latest 10 · total 31")); + assert!(status.contains("Folder /tmp/project")); + assert!(status.contains("Project root /tmp")); + assert!(status.contains("Git branch main")); + assert!(status.contains("Session file session.json")); + assert!(status.contains("Config files loaded 2/3")); + assert!(status.contains("Memory files 4")); + assert!(status.contains("/session list")); + } + + #[test] + fn config_report_supports_section_views() { + let report = render_config_report(Some("env")).expect("config report should render"); + assert!(report.contains("Merged section: env")); + let plugins_report = + render_config_report(Some("plugins")).expect("plugins config report should render"); + assert!(plugins_report.contains("Merged section: plugins")); + } + + #[test] + fn memory_report_uses_sectioned_layout() { + let report = render_memory_report().expect("memory report should render"); + assert!(report.contains("Memory")); + assert!(report.contains("Working directory")); + assert!(report.contains("Instruction files")); + 
assert!(report.contains("Discovered files")); + } + + #[test] + fn config_report_uses_sectioned_layout() { + let report = render_config_report(None).expect("config report should render"); + assert!(report.contains("Config")); + assert!(report.contains("Discovered files")); + assert!(report.contains("Merged JSON")); + } + + #[test] + fn parses_git_status_metadata() { + let (root, branch) = parse_git_status_metadata(Some( + "## rcc/cli...origin/rcc/cli + M src/main.rs", + )); + assert_eq!(branch.as_deref(), Some("rcc/cli")); + let _ = root; + } + + #[test] + fn status_context_reads_real_workspace_metadata() { + let context = status_context(None).expect("status context should load"); + assert!(context.cwd.is_absolute()); + assert_eq!(context.discovered_config_files, 5); + assert!(context.loaded_config_files <= context.discovered_config_files); + } + + #[test] + fn normalizes_supported_permission_modes() { + assert_eq!(normalize_permission_mode("read-only"), Some("read-only")); + assert_eq!( + normalize_permission_mode("workspace-write"), + Some("workspace-write") + ); + assert_eq!( + normalize_permission_mode("danger-full-access"), + Some("danger-full-access") + ); + assert_eq!(normalize_permission_mode("unknown"), None); + } + + #[test] + fn clear_command_requires_explicit_confirmation_flag() { + assert_eq!( + SlashCommand::parse("/clear"), + Some(SlashCommand::Clear { confirm: false }) + ); + assert_eq!( + SlashCommand::parse("/clear --confirm"), + Some(SlashCommand::Clear { confirm: true }) + ); + } + + #[test] + fn parses_resume_and_config_slash_commands() { + assert_eq!( + SlashCommand::parse("/resume saved-session.json"), + Some(SlashCommand::Resume { + session_path: Some("saved-session.json".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/clear --confirm"), + Some(SlashCommand::Clear { confirm: true }) + ); + assert_eq!( + SlashCommand::parse("/config"), + Some(SlashCommand::Config { section: None }) + ); + assert_eq!( + 
SlashCommand::parse("/config env"), + Some(SlashCommand::Config { + section: Some("env".to_string()) + }) + ); + assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory)); + assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init)); + } + + #[test] + fn init_template_mentions_detected_rust_workspace() { + let rendered = crate::init::render_init_claw_md(std::path::Path::new(".")); + assert!(rendered.contains("# CLAW.md")); + assert!(rendered.contains("cargo clippy --workspace --all-targets -- -D warnings")); + } + + #[test] + fn converts_tool_roundtrip_messages() { + let messages = vec![ + ConversationMessage::user_text("hello"), + ConversationMessage::assistant(vec![ContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "bash".to_string(), + input: "{\"command\":\"pwd\"}".to_string(), + }]), + ConversationMessage { + role: MessageRole::Tool, + blocks: vec![ContentBlock::ToolResult { + tool_use_id: "tool-1".to_string(), + tool_name: "bash".to_string(), + output: "ok".to_string(), + is_error: false, + }], + usage: None, + }, + ]; + + let converted = super::convert_messages(&messages); + assert_eq!(converted.len(), 3); + assert_eq!(converted[1].role, "assistant"); + assert_eq!(converted[2].role, "user"); + } + #[test] + fn repl_help_mentions_history_completion_and_multiline() { + let help = render_repl_help(); + assert!(help.contains("Up/Down")); + assert!(help.contains("Tab cycles")); + assert!(help.contains("Shift+Enter or Ctrl+J")); + } + + #[test] + fn tool_rendering_helpers_compact_output() { + let start = format_tool_call_start("read_file", r#"{"path":"src/main.rs"}"#); + assert!(start.contains("read_file")); + assert!(start.contains("src/main.rs")); + + let done = format_tool_result( + "read_file", + r#"{"file":{"filePath":"src/main.rs","content":"hello","numLines":1,"startLine":1,"totalLines":1}}"#, + false, + ); + assert!(done.contains("📄 Read src/main.rs")); + assert!(done.contains("hello")); + } + + #[test] + fn 
tool_rendering_truncates_large_read_output_for_display_only() { + let content = (0..200) + .map(|index| format!("line {index:03}")) + .collect::<Vec<_>>() + .join("\n"); + let output = json!({ + "file": { + "filePath": "src/main.rs", + "content": content, + "numLines": 200, + "startLine": 1, + "totalLines": 200 + } + }) + .to_string(); + + let rendered = format_tool_result("read_file", &output, false); + + assert!(rendered.contains("line 000")); + assert!(rendered.contains("line 079")); + assert!(!rendered.contains("line 199")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("line 199")); + } + + #[test] + fn tool_rendering_truncates_large_bash_output_for_display_only() { + let stdout = (0..120) + .map(|index| format!("stdout {index:03}")) + .collect::<Vec<_>>() + .join("\n"); + let output = json!({ + "stdout": stdout, + "stderr": "", + "returnCodeInterpretation": "completed successfully" + }) + .to_string(); + + let rendered = format_tool_result("bash", &output, false); + + assert!(rendered.contains("stdout 000")); + assert!(rendered.contains("stdout 059")); + assert!(!rendered.contains("stdout 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("stdout 119")); + } + + #[test] + fn tool_rendering_truncates_generic_long_output_for_display_only() { + let items = (0..120) + .map(|index| format!("payload {index:03}")) + .collect::<Vec<_>>(); + let output = json!({ + "summary": "plugin payload", + "items": items, + }) + .to_string(); + + let rendered = format_tool_result("plugin_echo", &output, false); + + assert!(rendered.contains("plugin_echo")); + assert!(rendered.contains("payload 000")); + assert!(rendered.contains("payload 040")); + assert!(!rendered.contains("payload 080")); + assert!(!rendered.contains("payload 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("payload 119")); + } + + #[test] + fn 
tool_rendering_truncates_raw_generic_output_for_display_only() { + let output = (0..120) + .map(|index| format!("raw {index:03}")) + .collect::<Vec<_>>() + .join("\n"); + + let rendered = format_tool_result("plugin_echo", &output, false); + + assert!(rendered.contains("plugin_echo")); + assert!(rendered.contains("raw 000")); + assert!(rendered.contains("raw 059")); + assert!(!rendered.contains("raw 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("raw 119")); + } + + #[test] + fn ultraplan_progress_lines_include_phase_step_and_elapsed_status() { + let snapshot = InternalPromptProgressState { + command_label: "Ultraplan", + task_label: "ship plugin progress".to_string(), + step: 3, + phase: "running read_file".to_string(), + detail: Some("reading rust/crates/claw-cli/src/main.rs".to_string()), + saw_final_text: false, + }; + + let started = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Started, + &snapshot, + Duration::from_secs(0), + None, + ); + let heartbeat = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Heartbeat, + &snapshot, + Duration::from_secs(9), + None, + ); + let completed = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Complete, + &snapshot, + Duration::from_secs(12), + None, + ); + let failed = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Failed, + &snapshot, + Duration::from_secs(12), + Some("network timeout"), + ); + + assert!(started.contains("planning started")); + assert!(started.contains("current step 3")); + assert!(heartbeat.contains("heartbeat")); + assert!(heartbeat.contains("9s elapsed")); + assert!(heartbeat.contains("phase running read_file")); + assert!(completed.contains("completed")); + assert!(completed.contains("3 steps total")); + assert!(failed.contains("failed")); + assert!(failed.contains("network timeout")); + } + + #[test] + fn describe_tool_progress_summarizes_known_tools() { + 
assert_eq!( + describe_tool_progress("read_file", r#"{"path":"src/main.rs"}"#), + "reading src/main.rs" + ); + assert!( + describe_tool_progress("bash", r#"{"command":"cargo test -p claw-cli"}"#) + .contains("cargo test -p claw-cli") + ); + assert_eq!( + describe_tool_progress("grep_search", r#"{"pattern":"ultraplan","path":"rust"}"#), + "grep `ultraplan` in rust" + ); + } + + #[test] + fn push_output_block_renders_markdown_text() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tool = None; + + push_output_block( + OutputContentBlock::Text { + text: "# Heading".to_string(), + }, + &mut out, + &mut events, + &mut pending_tool, + false, + ) + .expect("text block should render"); + + let rendered = String::from_utf8(out).expect("utf8"); + assert!(rendered.contains("Heading")); + assert!(rendered.contains('\u{1b}')); + } + + #[test] + fn push_output_block_skips_empty_object_prefix_for_tool_streams() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tool = None; + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + &mut out, + &mut events, + &mut pending_tool, + true, + ) + .expect("tool block should accumulate"); + + assert!(events.is_empty()); + assert_eq!( + pending_tool, + Some(("tool-1".to_string(), "read_file".to_string(), String::new(),)) + ); + } + + #[test] + fn response_to_events_preserves_empty_object_json_input_outside_streaming() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-1".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }], + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + 
cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::ToolUse { name, input, .. } + if name == "read_file" && input == "{}" + )); + } + + #[test] + fn response_to_events_preserves_non_empty_json_input_outside_streaming() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-2".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "read_file".to_string(), + input: json!({ "path": "rust/Cargo.toml" }), + }], + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::ToolUse { name, input, .. 
} + if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}" + )); + } + + #[test] + fn response_to_events_ignores_thinking_blocks() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-3".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![ + OutputContentBlock::Thinking { + thinking: "step 1".to_string(), + signature: Some("sig_123".to_string()), + }, + OutputContentBlock::Text { + text: "Final answer".to_string(), + }, + ], + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::TextDelta(text) if text == "Final answer" + )); + assert!(!String::from_utf8(out).expect("utf8").contains("step 1")); + } +} diff --git a/crates/claw-cli/src/render.rs b/crates/claw-cli/src/render.rs new file mode 100644 index 0000000..01751fd --- /dev/null +++ b/crates/claw-cli/src/render.rs @@ -0,0 +1,797 @@ +use std::fmt::Write as FmtWrite; +use std::io::{self, Write}; + +use crossterm::cursor::{MoveToColumn, RestorePosition, SavePosition}; +use crossterm::style::{Color, Print, ResetColor, SetForegroundColor, Stylize}; +use crossterm::terminal::{Clear, ClearType}; +use crossterm::{execute, queue}; +use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd}; +use syntect::easy::HighlightLines; +use syntect::highlighting::{Theme, ThemeSet}; +use syntect::parsing::SyntaxSet; +use syntect::util::{as_24_bit_terminal_escaped, LinesWithEndings}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ColorTheme { + heading: Color, + emphasis: Color, + strong: Color, + inline_code: Color, + link: Color, + quote: Color, + table_border: Color, + code_block_border: 
Color, + spinner_active: Color, + spinner_done: Color, + spinner_failed: Color, +} + +impl Default for ColorTheme { + fn default() -> Self { + Self { + heading: Color::Cyan, + emphasis: Color::Magenta, + strong: Color::Yellow, + inline_code: Color::Green, + link: Color::Blue, + quote: Color::DarkGrey, + table_border: Color::DarkCyan, + code_block_border: Color::DarkGrey, + spinner_active: Color::Blue, + spinner_done: Color::Green, + spinner_failed: Color::Red, + } + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Spinner { + frame_index: usize, +} + +impl Spinner { + const FRAMES: [&str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; + + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn tick( + &mut self, + label: &str, + theme: &ColorTheme, + out: &mut impl Write, + ) -> io::Result<()> { + let frame = Self::FRAMES[self.frame_index % Self::FRAMES.len()]; + self.frame_index += 1; + queue!( + out, + SavePosition, + MoveToColumn(0), + Clear(ClearType::CurrentLine), + SetForegroundColor(theme.spinner_active), + Print(format!("{frame} {label}")), + ResetColor, + RestorePosition + )?; + out.flush() + } + + pub fn finish( + &mut self, + label: &str, + theme: &ColorTheme, + out: &mut impl Write, + ) -> io::Result<()> { + self.frame_index = 0; + execute!( + out, + MoveToColumn(0), + Clear(ClearType::CurrentLine), + SetForegroundColor(theme.spinner_done), + Print(format!("✔ {label}\n")), + ResetColor + )?; + out.flush() + } + + pub fn fail( + &mut self, + label: &str, + theme: &ColorTheme, + out: &mut impl Write, + ) -> io::Result<()> { + self.frame_index = 0; + execute!( + out, + MoveToColumn(0), + Clear(ClearType::CurrentLine), + SetForegroundColor(theme.spinner_failed), + Print(format!("✘ {label}\n")), + ResetColor + )?; + out.flush() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ListKind { + Unordered, + Ordered { next_index: u64 }, +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct TableState { + 
headers: Vec<String>, + rows: Vec<Vec<String>>, + current_row: Vec<String>, + current_cell: String, + in_head: bool, +} + +impl TableState { + fn push_cell(&mut self) { + let cell = self.current_cell.trim().to_string(); + self.current_row.push(cell); + self.current_cell.clear(); + } + + fn finish_row(&mut self) { + if self.current_row.is_empty() { + return; + } + let row = std::mem::take(&mut self.current_row); + if self.in_head { + self.headers = row; + } else { + self.rows.push(row); + } + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct RenderState { + emphasis: usize, + strong: usize, + heading_level: Option<u8>, + quote: usize, + list_stack: Vec<ListKind>, + link_stack: Vec<LinkState>, + table: Option<TableState>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct LinkState { + destination: String, + text: String, +} + +impl RenderState { + fn style_text(&self, text: &str, theme: &ColorTheme) -> String { + let mut style = text.stylize(); + + if matches!(self.heading_level, Some(1 | 2)) || self.strong > 0 { + style = style.bold(); + } + if self.emphasis > 0 { + style = style.italic(); + } + + if let Some(level) = self.heading_level { + style = match level { + 1 => style.with(theme.heading), + 2 => style.white(), + 3 => style.with(Color::Blue), + _ => style.with(Color::Grey), + }; + } else if self.strong > 0 { + style = style.with(theme.strong); + } else if self.emphasis > 0 { + style = style.with(theme.emphasis); + } + + if self.quote > 0 { + style = style.with(theme.quote); + } + + format!("{style}") + } + + fn append_raw(&mut self, output: &mut String, text: &str) { + if let Some(link) = self.link_stack.last_mut() { + link.text.push_str(text); + } else if let Some(table) = self.table.as_mut() { + table.current_cell.push_str(text); + } else { + output.push_str(text); + } + } + + fn append_styled(&mut self, output: &mut String, text: &str, theme: &ColorTheme) { + let styled = self.style_text(text, theme); + self.append_raw(output, &styled); + } 
+} + +#[derive(Debug)] +pub struct TerminalRenderer { + syntax_set: SyntaxSet, + syntax_theme: Theme, + color_theme: ColorTheme, +} + +impl Default for TerminalRenderer { + fn default() -> Self { + let syntax_set = SyntaxSet::load_defaults_newlines(); + let syntax_theme = ThemeSet::load_defaults() + .themes + .remove("base16-ocean.dark") + .unwrap_or_default(); + Self { + syntax_set, + syntax_theme, + color_theme: ColorTheme::default(), + } + } +} + +impl TerminalRenderer { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn color_theme(&self) -> &ColorTheme { + &self.color_theme + } + + #[must_use] + pub fn render_markdown(&self, markdown: &str) -> String { + let mut output = String::new(); + let mut state = RenderState::default(); + let mut code_language = String::new(); + let mut code_buffer = String::new(); + let mut in_code_block = false; + + for event in Parser::new_ext(markdown, Options::all()) { + self.render_event( + event, + &mut state, + &mut output, + &mut code_buffer, + &mut code_language, + &mut in_code_block, + ); + } + + output.trim_end().to_string() + } + + #[must_use] + pub fn markdown_to_ansi(&self, markdown: &str) -> String { + self.render_markdown(markdown) + } + + #[allow(clippy::too_many_lines)] + fn render_event( + &self, + event: Event<'_>, + state: &mut RenderState, + output: &mut String, + code_buffer: &mut String, + code_language: &mut String, + in_code_block: &mut bool, + ) { + match event { + Event::Start(Tag::Heading { level, .. 
}) => { + self.start_heading(state, level as u8, output); + } + Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), + Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), + Event::End(TagEnd::BlockQuote(..)) => { + state.quote = state.quote.saturating_sub(1); + output.push('\n'); + } + Event::End(TagEnd::Heading(..)) => { + state.heading_level = None; + output.push_str("\n\n"); + } + Event::End(TagEnd::Item) | Event::SoftBreak | Event::HardBreak => { + state.append_raw(output, "\n"); + } + Event::Start(Tag::List(first_item)) => { + let kind = match first_item { + Some(index) => ListKind::Ordered { next_index: index }, + None => ListKind::Unordered, + }; + state.list_stack.push(kind); + } + Event::End(TagEnd::List(..)) => { + state.list_stack.pop(); + output.push('\n'); + } + Event::Start(Tag::Item) => Self::start_item(state, output), + Event::Start(Tag::CodeBlock(kind)) => { + *in_code_block = true; + *code_language = match kind { + CodeBlockKind::Indented => String::from("text"), + CodeBlockKind::Fenced(lang) => lang.to_string(), + }; + code_buffer.clear(); + self.start_code_block(code_language, output); + } + Event::End(TagEnd::CodeBlock) => { + self.finish_code_block(code_buffer, code_language, output); + *in_code_block = false; + code_language.clear(); + code_buffer.clear(); + } + Event::Start(Tag::Emphasis) => state.emphasis += 1, + Event::End(TagEnd::Emphasis) => state.emphasis = state.emphasis.saturating_sub(1), + Event::Start(Tag::Strong) => state.strong += 1, + Event::End(TagEnd::Strong) => state.strong = state.strong.saturating_sub(1), + Event::Code(code) => { + let rendered = + format!("{}", format!("`{code}`").with(self.color_theme.inline_code)); + state.append_raw(output, &rendered); + } + Event::Rule => output.push_str("---\n"), + Event::Text(text) => { + self.push_text(text.as_ref(), state, output, code_buffer, *in_code_block); + } + Event::Html(html) | Event::InlineHtml(html) => { + state.append_raw(output, &html); + } + 
Event::FootnoteReference(reference) => { + state.append_raw(output, &format!("[{reference}]")); + } + Event::TaskListMarker(done) => { + state.append_raw(output, if done { "[x] " } else { "[ ] " }); + } + Event::InlineMath(math) | Event::DisplayMath(math) => { + state.append_raw(output, &math); + } + Event::Start(Tag::Link { dest_url, .. }) => { + state.link_stack.push(LinkState { + destination: dest_url.to_string(), + text: String::new(), + }); + } + Event::End(TagEnd::Link) => { + if let Some(link) = state.link_stack.pop() { + let label = if link.text.is_empty() { + link.destination.clone() + } else { + link.text + }; + let rendered = format!( + "{}", + format!("[{label}]({})", link.destination) + .underlined() + .with(self.color_theme.link) + ); + state.append_raw(output, &rendered); + } + } + Event::Start(Tag::Image { dest_url, .. }) => { + let rendered = format!( + "{}", + format!("[image:{dest_url}]").with(self.color_theme.link) + ); + state.append_raw(output, &rendered); + } + Event::Start(Tag::Table(..)) => state.table = Some(TableState::default()), + Event::End(TagEnd::Table) => { + if let Some(table) = state.table.take() { + output.push_str(&self.render_table(&table)); + output.push_str("\n\n"); + } + } + Event::Start(Tag::TableHead) => { + if let Some(table) = state.table.as_mut() { + table.in_head = true; + } + } + Event::End(TagEnd::TableHead) => { + if let Some(table) = state.table.as_mut() { + table.finish_row(); + table.in_head = false; + } + } + Event::Start(Tag::TableRow) => { + if let Some(table) = state.table.as_mut() { + table.current_row.clear(); + table.current_cell.clear(); + } + } + Event::End(TagEnd::TableRow) => { + if let Some(table) = state.table.as_mut() { + table.finish_row(); + } + } + Event::Start(Tag::TableCell) => { + if let Some(table) = state.table.as_mut() { + table.current_cell.clear(); + } + } + Event::End(TagEnd::TableCell) => { + if let Some(table) = state.table.as_mut() { + table.push_cell(); + } + } + 
Event::Start(Tag::Paragraph | Tag::MetadataBlock(..) | _) + | Event::End(TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {} + } + } + + #[allow(clippy::unused_self)] + fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { + state.heading_level = Some(level); + if !output.is_empty() { + output.push('\n'); + } + } + + fn start_quote(&self, state: &mut RenderState, output: &mut String) { + state.quote += 1; + let _ = write!(output, "{}", "│ ".with(self.color_theme.quote)); + } + + fn start_item(state: &mut RenderState, output: &mut String) { + let depth = state.list_stack.len().saturating_sub(1); + output.push_str(&" ".repeat(depth)); + + let marker = match state.list_stack.last_mut() { + Some(ListKind::Ordered { next_index }) => { + let value = *next_index; + *next_index += 1; + format!("{value}. ") + } + _ => "• ".to_string(), + }; + output.push_str(&marker); + } + + fn start_code_block(&self, code_language: &str, output: &mut String) { + let label = if code_language.is_empty() { + "code".to_string() + } else { + code_language.to_string() + }; + let _ = writeln!( + output, + "{}", + format!("╭─ {label}") + .bold() + .with(self.color_theme.code_block_border) + ); + } + + fn finish_code_block(&self, code_buffer: &str, code_language: &str, output: &mut String) { + output.push_str(&self.highlight_code(code_buffer, code_language)); + let _ = write!( + output, + "{}", + "╰─".bold().with(self.color_theme.code_block_border) + ); + output.push_str("\n\n"); + } + + fn push_text( + &self, + text: &str, + state: &mut RenderState, + output: &mut String, + code_buffer: &mut String, + in_code_block: bool, + ) { + if in_code_block { + code_buffer.push_str(text); + } else { + state.append_styled(output, text, &self.color_theme); + } + } + + fn render_table(&self, table: &TableState) -> String { + let mut rows = Vec::new(); + if !table.headers.is_empty() { + rows.push(table.headers.clone()); + } + rows.extend(table.rows.iter().cloned()); + + if 
rows.is_empty() { + return String::new(); + } + + let column_count = rows.iter().map(Vec::len).max().unwrap_or(0); + let widths = (0..column_count) + .map(|column| { + rows.iter() + .filter_map(|row| row.get(column)) + .map(|cell| visible_width(cell)) + .max() + .unwrap_or(0) + }) + .collect::<Vec<_>>(); + + let border = format!("{}", "│".with(self.color_theme.table_border)); + let separator = widths + .iter() + .map(|width| "─".repeat(*width + 2)) + .collect::<Vec<_>>() + .join(&format!("{}", "┼".with(self.color_theme.table_border))); + let separator = format!("{border}{separator}{border}"); + + let mut output = String::new(); + if !table.headers.is_empty() { + output.push_str(&self.render_table_row(&table.headers, &widths, true)); + output.push('\n'); + output.push_str(&separator); + if !table.rows.is_empty() { + output.push('\n'); + } + } + + for (index, row) in table.rows.iter().enumerate() { + output.push_str(&self.render_table_row(row, &widths, false)); + if index + 1 < table.rows.len() { + output.push('\n'); + } + } + + output + } + + fn render_table_row(&self, row: &[String], widths: &[usize], is_header: bool) -> String { + let border = format!("{}", "│".with(self.color_theme.table_border)); + let mut line = String::new(); + line.push_str(&border); + + for (index, width) in widths.iter().enumerate() { + let cell = row.get(index).map_or("", String::as_str); + line.push(' '); + if is_header { + let _ = write!(line, "{}", cell.bold().with(self.color_theme.heading)); + } else { + line.push_str(cell); + } + let padding = width.saturating_sub(visible_width(cell)); + line.push_str(&" ".repeat(padding + 1)); + line.push_str(&border); + } + + line + } + + #[must_use] + pub fn highlight_code(&self, code: &str, language: &str) -> String { + let syntax = self + .syntax_set + .find_syntax_by_token(language) + .unwrap_or_else(|| self.syntax_set.find_syntax_plain_text()); + let mut syntax_highlighter = HighlightLines::new(syntax, &self.syntax_theme); + let mut 
colored_output = String::new(); + + for line in LinesWithEndings::from(code) { + match syntax_highlighter.highlight_line(line, &self.syntax_set) { + Ok(ranges) => { + let escaped = as_24_bit_terminal_escaped(&ranges[..], false); + colored_output.push_str(&apply_code_block_background(&escaped)); + } + Err(_) => colored_output.push_str(&apply_code_block_background(line)), + } + } + + colored_output + } + + pub fn stream_markdown(&self, markdown: &str, out: &mut impl Write) -> io::Result<()> { + let rendered_markdown = self.markdown_to_ansi(markdown); + write!(out, "{rendered_markdown}")?; + if !rendered_markdown.ends_with('\n') { + writeln!(out)?; + } + out.flush() + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct MarkdownStreamState { + pending: String, +} + +impl MarkdownStreamState { + #[must_use] + pub fn push(&mut self, renderer: &TerminalRenderer, delta: &str) -> Option<String> { + self.pending.push_str(delta); + let split = find_stream_safe_boundary(&self.pending)?; + let ready = self.pending[..split].to_string(); + self.pending.drain(..split); + Some(renderer.markdown_to_ansi(&ready)) + } + + #[must_use] + pub fn flush(&mut self, renderer: &TerminalRenderer) -> Option<String> { + if self.pending.trim().is_empty() { + self.pending.clear(); + None + } else { + let pending = std::mem::take(&mut self.pending); + Some(renderer.markdown_to_ansi(&pending)) + } + } +} + +fn apply_code_block_background(line: &str) -> String { + let trimmed = line.trim_end_matches('\n'); + let trailing_newline = if trimmed.len() == line.len() { + "" + } else { + "\n" + }; + let with_background = trimmed.replace("\u{1b}[0m", "\u{1b}[0;48;5;236m"); + format!("\u{1b}[48;5;236m{with_background}\u{1b}[0m{trailing_newline}") +} + +fn find_stream_safe_boundary(markdown: &str) -> Option<usize> { + let mut in_fence = false; + let mut last_boundary = None; + + for (offset, line) in markdown.split_inclusive('\n').scan(0usize, |cursor, line| { + let start = *cursor; + *cursor 
+= line.len(); + Some((start, line)) + }) { + let trimmed = line.trim_start(); + if trimmed.starts_with("```") || trimmed.starts_with("~~~") { + in_fence = !in_fence; + if !in_fence { + last_boundary = Some(offset + line.len()); + } + continue; + } + + if in_fence { + continue; + } + + if trimmed.is_empty() { + last_boundary = Some(offset + line.len()); + } + } + + last_boundary +} + +fn visible_width(input: &str) -> usize { + strip_ansi(input).chars().count() +} + +fn strip_ansi(input: &str) -> String { + let mut output = String::new(); + let mut chars = input.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '\u{1b}' { + if chars.peek() == Some(&'[') { + chars.next(); + for next in chars.by_ref() { + if next.is_ascii_alphabetic() { + break; + } + } + } + } else { + output.push(ch); + } + } + + output +} + +#[cfg(test)] +mod tests { + use super::{strip_ansi, MarkdownStreamState, Spinner, TerminalRenderer}; + + #[test] + fn renders_markdown_with_styling_and_lists() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = terminal_renderer + .render_markdown("# Heading\n\nThis is **bold** and *italic*.\n\n- item\n\n`code`"); + + assert!(markdown_output.contains("Heading")); + assert!(markdown_output.contains("• item")); + assert!(markdown_output.contains("code")); + assert!(markdown_output.contains('\u{1b}')); + } + + #[test] + fn renders_links_as_colored_markdown_labels() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.render_markdown("See [Claw](https://example.com/docs) now."); + let plain_text = strip_ansi(&markdown_output); + + assert!(plain_text.contains("[Claw](https://example.com/docs)")); + assert!(markdown_output.contains('\u{1b}')); + } + + #[test] + fn highlights_fenced_code_blocks() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.markdown_to_ansi("```rust\nfn hi() { println!(\"hi\"); }\n```"); + let plain_text = 
strip_ansi(&markdown_output); + + assert!(plain_text.contains("╭─ rust")); + assert!(plain_text.contains("fn hi")); + assert!(markdown_output.contains('\u{1b}')); + assert!(markdown_output.contains("[48;5;236m")); + } + + #[test] + fn renders_ordered_and_nested_lists() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.render_markdown("1. first\n2. second\n - nested\n - child"); + let plain_text = strip_ansi(&markdown_output); + + assert!(plain_text.contains("1. first")); + assert!(plain_text.contains("2. second")); + assert!(plain_text.contains(" • nested")); + assert!(plain_text.contains(" • child")); + } + + #[test] + fn renders_tables_with_alignment() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = terminal_renderer + .render_markdown("| Name | Value |\n| ---- | ----- |\n| alpha | 1 |\n| beta | 22 |"); + let plain_text = strip_ansi(&markdown_output); + let lines = plain_text.lines().collect::<Vec<_>>(); + + assert_eq!(lines[0], "│ Name │ Value │"); + assert_eq!(lines[1], "│───────┼───────│"); + assert_eq!(lines[2], "│ alpha │ 1 │"); + assert_eq!(lines[3], "│ beta │ 22 │"); + assert!(markdown_output.contains('\u{1b}')); + } + + #[test] + fn streaming_state_waits_for_complete_blocks() { + let renderer = TerminalRenderer::new(); + let mut state = MarkdownStreamState::default(); + + assert_eq!(state.push(&renderer, "# Heading"), None); + let flushed = state + .push(&renderer, "\n\nParagraph\n\n") + .expect("completed block"); + let plain_text = strip_ansi(&flushed); + assert!(plain_text.contains("Heading")); + assert!(plain_text.contains("Paragraph")); + + assert_eq!(state.push(&renderer, "```rust\nfn main() {}\n"), None); + let code = state + .push(&renderer, "```\n") + .expect("closed code fence flushes"); + assert!(strip_ansi(&code).contains("fn main()")); + } + + #[test] + fn spinner_advances_frames() { + let terminal_renderer = TerminalRenderer::new(); + let mut spinner = 
Spinner::new(); + let mut out = Vec::new(); + spinner + .tick("Working", terminal_renderer.color_theme(), &mut out) + .expect("tick succeeds"); + spinner + .tick("Working", terminal_renderer.color_theme(), &mut out) + .expect("tick succeeds"); + + let output = String::from_utf8_lossy(&out); + assert!(output.contains("Working")); + } +} diff --git a/crates/commands/Cargo.toml b/crates/commands/Cargo.toml index d465bff..2263f7a 100644 --- a/crates/commands/Cargo.toml +++ b/crates/commands/Cargo.toml @@ -9,4 +9,6 @@ publish.workspace = true workspace = true [dependencies] +plugins = { path = "../plugins" } runtime = { path = "../runtime" } +serde_json.workspace = true diff --git a/crates/commands/src/lib.rs b/crates/commands/src/lib.rs index b396bb0..da7f1a4 100644 --- a/crates/commands/src/lib.rs +++ b/crates/commands/src/lib.rs @@ -1,3 +1,12 @@ +use std::collections::BTreeMap; +use std::env; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::time::{SystemTime, UNIX_EPOCH}; + +use plugins::{PluginError, PluginManager, PluginSummary}; use runtime::{compact_session, CompactionConfig, Session}; #[derive(Debug, Clone, PartialEq, Eq)] @@ -30,104 +39,263 @@ impl CommandRegistry { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SlashCommandCategory { + Core, + Workspace, + Session, + Git, + Automation, +} + +impl SlashCommandCategory { + const fn title(self) -> &'static str { + match self { + Self::Core => "Core flow", + Self::Workspace => "Workspace & memory", + Self::Session => "Sessions & output", + Self::Git => "Git & GitHub", + Self::Automation => "Automation & discovery", + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SlashCommandSpec { pub name: &'static str, + pub aliases: &'static [&'static str], pub summary: &'static str, pub argument_hint: Option<&'static str>, pub resume_supported: bool, + pub category: SlashCommandCategory, } const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ 
SlashCommandSpec { name: "help", + aliases: &[], summary: "Show available slash commands", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "status", + aliases: &[], summary: "Show current session status", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "compact", + aliases: &[], summary: "Compact local session history", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "model", + aliases: &[], summary: "Show or switch the active model", argument_hint: Some("[model]"), resume_supported: false, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "permissions", + aliases: &[], summary: "Show or switch the active permission mode", argument_hint: Some("[read-only|workspace-write|danger-full-access]"), resume_supported: false, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "clear", + aliases: &[], summary: "Start a fresh local session", argument_hint: Some("[--confirm]"), resume_supported: true, + category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "cost", + aliases: &[], summary: "Show cumulative token usage for this session", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "resume", + aliases: &[], summary: "Load a saved session into the REPL", argument_hint: Some("<session-path>"), resume_supported: false, + category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "config", - summary: "Inspect Claude config files or merged sections", - argument_hint: Some("[env|hooks|model]"), + aliases: &[], + summary: "Inspect Claw config files or merged sections", + argument_hint: Some("[env|hooks|model|plugins]"), resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "memory", - summary: "Inspect loaded Claude 
instruction memory files", + aliases: &[], + summary: "Inspect loaded Claw instruction memory files", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "init", - summary: "Create a starter CLAUDE.md for this repo", + aliases: &[], + summary: "Create a starter CLAW.md for this repo", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "diff", + aliases: &[], summary: "Show git diff for current workspace changes", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "version", + aliases: &[], summary: "Show CLI version and build information", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, + }, + SlashCommandSpec { + name: "bughunter", + aliases: &[], + summary: "Inspect the codebase for likely bugs", + argument_hint: Some("[scope]"), + resume_supported: false, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "branch", + aliases: &[], + summary: "List, create, or switch git branches", + argument_hint: Some("[list|create <name>|switch <name>]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "worktree", + aliases: &[], + summary: "List, add, remove, or prune git worktrees", + argument_hint: Some("[list|add <path> [branch]|remove <path>|prune]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "commit", + aliases: &[], + summary: "Generate a commit message and create a git commit", + argument_hint: None, + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "commit-push-pr", + aliases: &[], + summary: "Commit workspace changes, push the branch, and open a PR", + argument_hint: Some("[context]"), + resume_supported: false, + category: 
SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "pr", + aliases: &[], + summary: "Draft or create a pull request from the conversation", + argument_hint: Some("[context]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "issue", + aliases: &[], + summary: "Draft or create a GitHub issue from the conversation", + argument_hint: Some("[context]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "ultraplan", + aliases: &[], + summary: "Run a deep planning prompt with multi-step reasoning", + argument_hint: Some("[task]"), + resume_supported: false, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "teleport", + aliases: &[], + summary: "Jump to a file or symbol by searching the workspace", + argument_hint: Some("<symbol-or-path>"), + resume_supported: false, + category: SlashCommandCategory::Workspace, + }, + SlashCommandSpec { + name: "debug-tool-call", + aliases: &[], + summary: "Replay the last tool call with debug details", + argument_hint: None, + resume_supported: false, + category: SlashCommandCategory::Automation, }, SlashCommandSpec { name: "export", + aliases: &[], summary: "Export the current conversation to a file", argument_hint: Some("[file]"), resume_supported: true, + category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "session", + aliases: &[], summary: "List or switch managed local sessions", argument_hint: Some("[list|switch <session-id>]"), resume_supported: false, + category: SlashCommandCategory::Session, + }, + SlashCommandSpec { + name: "plugin", + aliases: &["plugins", "marketplace"], + summary: "Manage Claw Code plugins", + argument_hint: Some( + "[list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]", + ), + resume_supported: false, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "agents", + aliases: &[], + summary: "List 
configured agents", + argument_hint: None, + resume_supported: true, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "skills", + aliases: &[], + summary: "List available skills", + argument_hint: None, + resume_supported: true, + category: SlashCommandCategory::Automation, }, ]; @@ -136,6 +304,35 @@ pub enum SlashCommand { Help, Status, Compact, + Branch { + action: Option<String>, + target: Option<String>, + }, + Bughunter { + scope: Option<String>, + }, + Worktree { + action: Option<String>, + path: Option<String>, + branch: Option<String>, + }, + Commit, + CommitPushPr { + context: Option<String>, + }, + Pr { + context: Option<String>, + }, + Issue { + context: Option<String>, + }, + Ultraplan { + task: Option<String>, + }, + Teleport { + target: Option<String>, + }, + DebugToolCall, Model { model: Option<String>, }, @@ -163,6 +360,16 @@ pub enum SlashCommand { action: Option<String>, target: Option<String>, }, + Plugins { + action: Option<String>, + target: Option<String>, + }, + Agents { + args: Option<String>, + }, + Skills { + args: Option<String>, + }, Unknown(String), } @@ -180,6 +387,35 @@ impl SlashCommand { "help" => Self::Help, "status" => Self::Status, "compact" => Self::Compact, + "branch" => Self::Branch { + action: parts.next().map(ToOwned::to_owned), + target: parts.next().map(ToOwned::to_owned), + }, + "bughunter" => Self::Bughunter { + scope: remainder_after_command(trimmed, command), + }, + "worktree" => Self::Worktree { + action: parts.next().map(ToOwned::to_owned), + path: parts.next().map(ToOwned::to_owned), + branch: parts.next().map(ToOwned::to_owned), + }, + "commit" => Self::Commit, + "commit-push-pr" => Self::CommitPushPr { + context: remainder_after_command(trimmed, command), + }, + "pr" => Self::Pr { + context: remainder_after_command(trimmed, command), + }, + "issue" => Self::Issue { + context: remainder_after_command(trimmed, command), + }, + "ultraplan" => Self::Ultraplan { + task: 
remainder_after_command(trimmed, command), + }, + "teleport" => Self::Teleport { + target: remainder_after_command(trimmed, command), + }, + "debug-tool-call" => Self::DebugToolCall, "model" => Self::Model { model: parts.next().map(ToOwned::to_owned), }, @@ -207,11 +443,33 @@ impl SlashCommand { action: parts.next().map(ToOwned::to_owned), target: parts.next().map(ToOwned::to_owned), }, + "plugin" | "plugins" | "marketplace" => Self::Plugins { + action: parts.next().map(ToOwned::to_owned), + target: { + let remainder = parts.collect::<Vec<_>>().join(" "); + (!remainder.is_empty()).then_some(remainder) + }, + }, + "agents" => Self::Agents { + args: remainder_after_command(trimmed, command), + }, + "skills" => Self::Skills { + args: remainder_after_command(trimmed, command), + }, other => Self::Unknown(other.to_string()), }) } } +fn remainder_after_command(input: &str, command: &str) -> Option<String> { + input + .trim() + .strip_prefix(&format!("/{command}")) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) +} + #[must_use] pub fn slash_command_specs() -> &'static [SlashCommandSpec] { SLASH_COMMAND_SPECS @@ -229,29 +487,1250 @@ pub fn resume_supported_slash_commands() -> Vec<&'static SlashCommandSpec> { pub fn render_slash_command_help() -> String { let mut lines = vec![ "Slash commands".to_string(), - " [resume] means the command also works with --resume SESSION.json".to_string(), + " Tab completes commands inside the REPL.".to_string(), + " [resume] = also available via claw --resume SESSION.json".to_string(), ]; - for spec in slash_command_specs() { - let name = match spec.argument_hint { - Some(argument_hint) => format!("/{} {}", spec.name, argument_hint), - None => format!("/{}", spec.name), - }; - let resume = if spec.resume_supported { - " [resume]" - } else { - "" - }; - lines.push(format!(" {name:<20} {}{}", spec.summary, resume)); + + for category in [ + SlashCommandCategory::Core, + SlashCommandCategory::Workspace, + 
SlashCommandCategory::Session, + SlashCommandCategory::Git, + SlashCommandCategory::Automation, + ] { + lines.push(String::new()); + lines.push(category.title().to_string()); + lines.extend( + slash_command_specs() + .iter() + .filter(|spec| spec.category == category) + .map(render_slash_command_entry), + ); } + lines.join("\n") } +fn render_slash_command_entry(spec: &SlashCommandSpec) -> String { + let alias_suffix = if spec.aliases.is_empty() { + String::new() + } else { + format!( + " (aliases: {})", + spec.aliases + .iter() + .map(|alias| format!("/{alias}")) + .collect::<Vec<_>>() + .join(", ") + ) + }; + let resume = if spec.resume_supported { + " [resume]" + } else { + "" + }; + format!( + " {name:<46} {}{alias_suffix}{resume}", + spec.summary, + name = render_slash_command_name(spec), + ) +} + +fn render_slash_command_name(spec: &SlashCommandSpec) -> String { + match spec.argument_hint { + Some(argument_hint) => format!("/{} {}", spec.name, argument_hint), + None => format!("/{}", spec.name), + } +} + +fn levenshtein_distance(left: &str, right: &str) -> usize { + if left == right { + return 0; + } + if left.is_empty() { + return right.chars().count(); + } + if right.is_empty() { + return left.chars().count(); + } + + let right_chars = right.chars().collect::<Vec<_>>(); + let mut previous = (0..=right_chars.len()).collect::<Vec<_>>(); + let mut current = vec![0; right_chars.len() + 1]; + + for (left_index, left_char) in left.chars().enumerate() { + current[0] = left_index + 1; + for (right_index, right_char) in right_chars.iter().enumerate() { + let cost = usize::from(left_char != *right_char); + current[right_index + 1] = (previous[right_index + 1] + 1) + .min(current[right_index] + 1) + .min(previous[right_index] + cost); + } + std::mem::swap(&mut previous, &mut current); + } + + previous[right_chars.len()] +} + +#[must_use] +pub fn suggest_slash_commands(input: &str, limit: usize) -> Vec<String> { + let normalized = 
input.trim().trim_start_matches('/').to_ascii_lowercase(); + if normalized.is_empty() || limit == 0 { + return Vec::new(); + } + + let mut ranked = slash_command_specs() + .iter() + .filter_map(|spec| { + let score = std::iter::once(spec.name) + .chain(spec.aliases.iter().copied()) + .map(str::to_ascii_lowercase) + .filter_map(|alias| { + if alias == normalized { + Some((0_usize, alias.len())) + } else if alias.starts_with(&normalized) { + Some((1, alias.len())) + } else if alias.contains(&normalized) { + Some((2, alias.len())) + } else { + let distance = levenshtein_distance(&alias, &normalized); + (distance <= 2).then_some((3 + distance, alias.len())) + } + }) + .min(); + + score.map(|(bucket, len)| (bucket, len, render_slash_command_name(spec))) + }) + .collect::<Vec<_>>(); + + ranked.sort_by(|left, right| left.cmp(right)); + ranked.dedup_by(|left, right| left.2 == right.2); + ranked + .into_iter() + .take(limit) + .map(|(_, _, display)| display) + .collect() +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct SlashCommandResult { pub message: String, pub session: Session, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginsCommandResult { + pub message: String, + pub reload_runtime: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum DefinitionSource { + ProjectCodex, + ProjectClaw, + UserCodexHome, + UserCodex, + UserClaw, +} + +impl DefinitionSource { + fn label(self) -> &'static str { + match self { + Self::ProjectCodex => "Project (.codex)", + Self::ProjectClaw => "Project (.claw)", + Self::UserCodexHome => "User ($CODEX_HOME)", + Self::UserCodex => "User (~/.codex)", + Self::UserClaw => "User (~/.claw)", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct AgentSummary { + name: String, + description: Option<String>, + model: Option<String>, + reasoning_effort: Option<String>, + source: DefinitionSource, + shadowed_by: Option<DefinitionSource>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SkillSummary 
{ + name: String, + description: Option<String>, + source: DefinitionSource, + shadowed_by: Option<DefinitionSource>, + origin: SkillOrigin, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SkillOrigin { + SkillsDir, + LegacyCommandsDir, +} + +impl SkillOrigin { + fn detail_label(self) -> Option<&'static str> { + match self { + Self::SkillsDir => None, + Self::LegacyCommandsDir => Some("legacy /commands"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SkillRoot { + source: DefinitionSource, + path: PathBuf, + origin: SkillOrigin, +} + +#[allow(clippy::too_many_lines)] +pub fn handle_plugins_slash_command( + action: Option<&str>, + target: Option<&str>, + manager: &mut PluginManager, +) -> Result<PluginsCommandResult, PluginError> { + match action { + None | Some("list") => Ok(PluginsCommandResult { + message: render_plugins_report(&manager.list_installed_plugins()?), + reload_runtime: false, + }), + Some("install") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins install <path>".to_string(), + reload_runtime: false, + }); + }; + let install = manager.install(target)?; + let plugin = manager + .list_installed_plugins()? 
+ .into_iter() + .find(|plugin| plugin.metadata.id == install.plugin_id); + Ok(PluginsCommandResult { + message: render_plugin_install_report(&install.plugin_id, plugin.as_ref()), + reload_runtime: true, + }) + } + Some("enable") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins enable <name>".to_string(), + reload_runtime: false, + }); + }; + let plugin = resolve_plugin_target(manager, target)?; + manager.enable(&plugin.metadata.id)?; + Ok(PluginsCommandResult { + message: format!( + "Plugins\n Result enabled {}\n Name {}\n Version {}\n Status enabled", + plugin.metadata.id, plugin.metadata.name, plugin.metadata.version + ), + reload_runtime: true, + }) + } + Some("disable") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins disable <name>".to_string(), + reload_runtime: false, + }); + }; + let plugin = resolve_plugin_target(manager, target)?; + manager.disable(&plugin.metadata.id)?; + Ok(PluginsCommandResult { + message: format!( + "Plugins\n Result disabled {}\n Name {}\n Version {}\n Status disabled", + plugin.metadata.id, plugin.metadata.name, plugin.metadata.version + ), + reload_runtime: true, + }) + } + Some("uninstall") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins uninstall <plugin-id>".to_string(), + reload_runtime: false, + }); + }; + manager.uninstall(target)?; + Ok(PluginsCommandResult { + message: format!("Plugins\n Result uninstalled {target}"), + reload_runtime: true, + }) + } + Some("update") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins update <plugin-id>".to_string(), + reload_runtime: false, + }); + }; + let update = manager.update(target)?; + let plugin = manager + .list_installed_plugins()? 
+ .into_iter() + .find(|plugin| plugin.metadata.id == update.plugin_id); + Ok(PluginsCommandResult { + message: format!( + "Plugins\n Result updated {}\n Name {}\n Old version {}\n New version {}\n Status {}", + update.plugin_id, + plugin + .as_ref() + .map_or_else(|| update.plugin_id.clone(), |plugin| plugin.metadata.name.clone()), + update.old_version, + update.new_version, + plugin + .as_ref() + .map_or("unknown", |plugin| if plugin.enabled { "enabled" } else { "disabled" }), + ), + reload_runtime: true, + }) + } + Some(other) => Ok(PluginsCommandResult { + message: format!( + "Unknown /plugins action '{other}'. Use list, install, enable, disable, uninstall, or update." + ), + reload_runtime: false, + }), + } +} + +pub fn handle_agents_slash_command(args: Option<&str>, cwd: &Path) -> std::io::Result<String> { + match normalize_optional_args(args) { + None | Some("list") => { + let roots = discover_definition_roots(cwd, "agents"); + let agents = load_agents_from_roots(&roots)?; + Ok(render_agents_report(&agents)) + } + Some("-h" | "--help" | "help") => Ok(render_agents_usage(None)), + Some(args) => Ok(render_agents_usage(Some(args))), + } +} + +pub fn handle_skills_slash_command(args: Option<&str>, cwd: &Path) -> std::io::Result<String> { + match normalize_optional_args(args) { + None | Some("list") => { + let roots = discover_skill_roots(cwd); + let skills = load_skills_from_roots(&roots)?; + Ok(render_skills_report(&skills)) + } + Some("-h" | "--help" | "help") => Ok(render_skills_usage(None)), + Some(args) => Ok(render_skills_usage(Some(args))), + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommitPushPrRequest { + pub commit_message: Option<String>, + pub pr_title: String, + pub pr_body: String, + pub branch_name_hint: String, +} + +pub fn handle_branch_slash_command( + action: Option<&str>, + target: Option<&str>, + cwd: &Path, +) -> io::Result<String> { + match normalize_optional_args(action) { + None | Some("list") => { + let branches = 
git_stdout(cwd, &["branch", "--list", "--verbose"])?; + let trimmed = branches.trim(); + Ok(if trimmed.is_empty() { + "Branch\n Result no branches found".to_string() + } else { + format!("Branch\n Result listed\n\n{}", trimmed) + }) + } + Some("create") => { + let Some(target) = target.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /branch create <name>".to_string()); + }; + git_status_ok(cwd, &["switch", "-c", target])?; + Ok(format!( + "Branch\n Result created and switched\n Branch {target}" + )) + } + Some("switch") => { + let Some(target) = target.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /branch switch <name>".to_string()); + }; + git_status_ok(cwd, &["switch", target])?; + Ok(format!( + "Branch\n Result switched\n Branch {target}" + )) + } + Some(other) => Ok(format!( + "Unknown /branch action '{other}'. Use /branch list, /branch create <name>, or /branch switch <name>." + )), + } +} + +pub fn handle_worktree_slash_command( + action: Option<&str>, + path: Option<&str>, + branch: Option<&str>, + cwd: &Path, +) -> io::Result<String> { + match normalize_optional_args(action) { + None | Some("list") => { + let worktrees = git_stdout(cwd, &["worktree", "list"])?; + let trimmed = worktrees.trim(); + Ok(if trimmed.is_empty() { + "Worktree\n Result no worktrees found".to_string() + } else { + format!("Worktree\n Result listed\n\n{}", trimmed) + }) + } + Some("add") => { + let Some(path) = path.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /worktree add <path> [branch]".to_string()); + }; + if let Some(branch) = branch.filter(|value| !value.trim().is_empty()) { + if branch_exists(cwd, branch) { + git_status_ok(cwd, &["worktree", "add", path, branch])?; + } else { + git_status_ok(cwd, &["worktree", "add", path, "-b", branch])?; + } + Ok(format!( + "Worktree\n Result added\n Path {path}\n Branch {branch}" + )) + } else { + git_status_ok(cwd, &["worktree", "add", path])?; + Ok(format!( + "Worktree\n 
Result added\n Path {path}" + )) + } + } + Some("remove") => { + let Some(path) = path.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /worktree remove <path>".to_string()); + }; + git_status_ok(cwd, &["worktree", "remove", path])?; + Ok(format!( + "Worktree\n Result removed\n Path {path}" + )) + } + Some("prune") => { + git_status_ok(cwd, &["worktree", "prune"])?; + Ok("Worktree\n Result pruned".to_string()) + } + Some(other) => Ok(format!( + "Unknown /worktree action '{other}'. Use /worktree list, /worktree add <path> [branch], /worktree remove <path>, or /worktree prune." + )), + } +} + +pub fn handle_commit_slash_command(message: &str, cwd: &Path) -> io::Result<String> { + let status = git_stdout(cwd, &["status", "--short"])?; + if status.trim().is_empty() { + return Ok( + "Commit\n Result skipped\n Reason no workspace changes" + .to_string(), + ); + } + + let message = message.trim(); + if message.is_empty() { + return Err(io::Error::other("generated commit message was empty")); + } + + git_status_ok(cwd, &["add", "-A"])?; + let path = write_temp_text_file("claw-commit-message", "txt", message)?; + let path_string = path.to_string_lossy().into_owned(); + git_status_ok(cwd, &["commit", "--file", path_string.as_str()])?; + + Ok(format!( + "Commit\n Result created\n Message file {}\n\n{}", + path.display(), + message + )) +} + +pub fn handle_commit_push_pr_slash_command( + request: &CommitPushPrRequest, + cwd: &Path, +) -> io::Result<String> { + if !command_exists("gh") { + return Err(io::Error::other("gh CLI is required for /commit-push-pr")); + } + + let default_branch = detect_default_branch(cwd)?; + let mut branch = current_branch(cwd)?; + let mut created_branch = false; + if branch == default_branch { + let hint = if request.branch_name_hint.trim().is_empty() { + request.pr_title.as_str() + } else { + request.branch_name_hint.as_str() + }; + let next_branch = build_branch_name(hint); + git_status_ok(cwd, &["switch", "-c", 
next_branch.as_str()])?; + branch = next_branch; + created_branch = true; + } + + let workspace_has_changes = !git_stdout(cwd, &["status", "--short"])?.trim().is_empty(); + let commit_report = if workspace_has_changes { + let Some(message) = request.commit_message.as_deref() else { + return Err(io::Error::other( + "commit message is required when workspace changes are present", + )); + }; + Some(handle_commit_slash_command(message, cwd)?) + } else { + None + }; + + let branch_diff = git_stdout( + cwd, + &["diff", "--stat", &format!("{default_branch}...HEAD")], + )?; + if branch_diff.trim().is_empty() { + return Ok( + "Commit/Push/PR\n Result skipped\n Reason no branch changes to push or open as a pull request" + .to_string(), + ); + } + + git_status_ok(cwd, &["push", "--set-upstream", "origin", branch.as_str()])?; + + let body_path = write_temp_text_file("claw-pr-body", "md", request.pr_body.trim())?; + let body_path_string = body_path.to_string_lossy().into_owned(); + let create = Command::new("gh") + .args([ + "pr", + "create", + "--title", + request.pr_title.as_str(), + "--body-file", + body_path_string.as_str(), + "--base", + default_branch.as_str(), + ]) + .current_dir(cwd) + .output()?; + + let (result, url) = if create.status.success() { + ( + "created", + parse_pr_url(&String::from_utf8_lossy(&create.stdout)) + .unwrap_or_else(|| "<unknown>".to_string()), + ) + } else { + let view = Command::new("gh") + .args(["pr", "view", "--json", "url"]) + .current_dir(cwd) + .output()?; + if !view.status.success() { + return Err(io::Error::other(command_failure( + "gh", + &["pr", "create"], + &create, + ))); + } + ( + "existing", + parse_pr_json_url(&String::from_utf8_lossy(&view.stdout)) + .unwrap_or_else(|| "<unknown>".to_string()), + ) + }; + + let mut lines = vec![ + "Commit/Push/PR".to_string(), + format!(" Result {result}"), + format!(" Branch {branch}"), + format!(" Base {default_branch}"), + format!(" Body file {}", body_path.display()), + format!(" URL 
{url}"), + ]; + if created_branch { + lines.insert(2, " Branch action created and switched".to_string()); + } + if let Some(report) = commit_report { + lines.push(String::new()); + lines.push(report); + } + Ok(lines.join("\n")) +} + +pub fn detect_default_branch(cwd: &Path) -> io::Result<String> { + if let Ok(reference) = git_stdout(cwd, &["symbolic-ref", "refs/remotes/origin/HEAD"]) { + if let Some(branch) = reference + .trim() + .rsplit('/') + .next() + .filter(|value| !value.is_empty()) + { + return Ok(branch.to_string()); + } + } + + for branch in ["main", "master"] { + if branch_exists(cwd, branch) { + return Ok(branch.to_string()); + } + } + + current_branch(cwd) +} + +fn git_stdout(cwd: &Path, args: &[&str]) -> io::Result<String> { + run_command_stdout("git", args, cwd) +} + +fn git_status_ok(cwd: &Path, args: &[&str]) -> io::Result<()> { + run_command_success("git", args, cwd) +} + +fn run_command_stdout(program: &str, args: &[&str], cwd: &Path) -> io::Result<String> { + let output = Command::new(program).args(args).current_dir(cwd).output()?; + if !output.status.success() { + return Err(io::Error::other(command_failure(program, args, &output))); + } + String::from_utf8(output.stdout) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)) +} + +fn run_command_success(program: &str, args: &[&str], cwd: &Path) -> io::Result<()> { + let output = Command::new(program).args(args).current_dir(cwd).output()?; + if !output.status.success() { + return Err(io::Error::other(command_failure(program, args, &output))); + } + Ok(()) +} + +fn command_failure(program: &str, args: &[&str], output: &std::process::Output) -> String { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let detail = if stderr.is_empty() { stdout } else { stderr }; + if detail.is_empty() { + format!("{program} {} failed", args.join(" ")) + } else { + format!("{program} {} failed: 
{detail}", args.join(" ")) + } +} + +fn branch_exists(cwd: &Path, branch: &str) -> bool { + Command::new("git") + .args([ + "show-ref", + "--verify", + "--quiet", + &format!("refs/heads/{branch}"), + ]) + .current_dir(cwd) + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +fn current_branch(cwd: &Path) -> io::Result<String> { + let branch = git_stdout(cwd, &["branch", "--show-current"])?; + let branch = branch.trim(); + if branch.is_empty() { + Err(io::Error::other("unable to determine current git branch")) + } else { + Ok(branch.to_string()) + } +} + +fn command_exists(name: &str) -> bool { + Command::new(name) + .arg("--version") + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +fn write_temp_text_file(prefix: &str, extension: &str, contents: &str) -> io::Result<PathBuf> { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_nanos()) + .unwrap_or_default(); + let path = env::temp_dir().join(format!("{prefix}-{nanos}.{extension}")); + fs::write(&path, contents)?; + Ok(path) +} + +fn build_branch_name(hint: &str) -> String { + let slug = slugify(hint); + let owner = env::var("SAFEUSER") + .ok() + .filter(|value| !value.trim().is_empty()) + .or_else(|| { + env::var("USER") + .ok() + .filter(|value| !value.trim().is_empty()) + }); + match owner { + Some(owner) => format!("{owner}/{slug}"), + None => slug, + } +} + +fn slugify(value: &str) -> String { + let mut slug = String::new(); + let mut last_was_dash = false; + for ch in value.chars() { + if ch.is_ascii_alphanumeric() { + slug.push(ch.to_ascii_lowercase()); + last_was_dash = false; + } else if !last_was_dash { + slug.push('-'); + last_was_dash = true; + } + } + let slug = slug.trim_matches('-').to_string(); + if slug.is_empty() { + "change".to_string() + } else { + slug + } +} + +fn parse_pr_url(stdout: &str) -> Option<String> { + stdout + .lines() + .map(str::trim) + .find(|line| line.starts_with("http://") || 
line.starts_with("https://")) + .map(ToOwned::to_owned) +} + +fn parse_pr_json_url(stdout: &str) -> Option<String> { + serde_json::from_str::<serde_json::Value>(stdout) + .ok()? + .get("url")? + .as_str() + .map(ToOwned::to_owned) +} + +#[must_use] +pub fn render_plugins_report(plugins: &[PluginSummary]) -> String { + let mut lines = vec!["Plugins".to_string()]; + if plugins.is_empty() { + lines.push(" No plugins installed.".to_string()); + return lines.join("\n"); + } + for plugin in plugins { + let enabled = if plugin.enabled { + "enabled" + } else { + "disabled" + }; + lines.push(format!( + " {name:<20} v{version:<10} {enabled}", + name = plugin.metadata.name, + version = plugin.metadata.version, + )); + } + lines.join("\n") +} + +fn render_plugin_install_report(plugin_id: &str, plugin: Option<&PluginSummary>) -> String { + let name = plugin.map_or(plugin_id, |plugin| plugin.metadata.name.as_str()); + let version = plugin.map_or("unknown", |plugin| plugin.metadata.version.as_str()); + let enabled = plugin.is_some_and(|plugin| plugin.enabled); + format!( + "Plugins\n Result installed {plugin_id}\n Name {name}\n Version {version}\n Status {}", + if enabled { "enabled" } else { "disabled" } + ) +} + +fn resolve_plugin_target( + manager: &PluginManager, + target: &str, +) -> Result<PluginSummary, PluginError> { + let mut matches = manager + .list_installed_plugins()? 
+ .into_iter() + .filter(|plugin| plugin.metadata.id == target || plugin.metadata.name == target) + .collect::<Vec<_>>(); + match matches.len() { + 1 => Ok(matches.remove(0)), + 0 => Err(PluginError::NotFound(format!( + "plugin `{target}` is not installed or discoverable" + ))), + _ => Err(PluginError::InvalidManifest(format!( + "plugin name `{target}` is ambiguous; use the full plugin id" + ))), + } +} + +fn discover_definition_roots(cwd: &Path, leaf: &str) -> Vec<(DefinitionSource, PathBuf)> { + let mut roots = Vec::new(); + + for ancestor in cwd.ancestors() { + push_unique_root( + &mut roots, + DefinitionSource::ProjectCodex, + ancestor.join(".codex").join(leaf), + ); + push_unique_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join(leaf), + ); + } + + if let Ok(codex_home) = env::var("CODEX_HOME") { + push_unique_root( + &mut roots, + DefinitionSource::UserCodexHome, + PathBuf::from(codex_home).join(leaf), + ); + } + + if let Some(home) = env::var_os("HOME") { + let home = PathBuf::from(home); + push_unique_root( + &mut roots, + DefinitionSource::UserCodex, + home.join(".codex").join(leaf), + ); + push_unique_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join(leaf), + ); + } + + roots +} + +fn discover_skill_roots(cwd: &Path) -> Vec<SkillRoot> { + let mut roots = Vec::new(); + + for ancestor in cwd.ancestors() { + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectCodex, + ancestor.join(".codex").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectCodex, + ancestor.join(".codex").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + 
} + + if let Ok(codex_home) = env::var("CODEX_HOME") { + let codex_home = PathBuf::from(codex_home); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodexHome, + codex_home.join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodexHome, + codex_home.join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + } + + if let Some(home) = env::var_os("HOME") { + let home = PathBuf::from(home); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodex, + home.join(".codex").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodex, + home.join(".codex").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + } + + roots +} + +fn push_unique_root( + roots: &mut Vec<(DefinitionSource, PathBuf)>, + source: DefinitionSource, + path: PathBuf, +) { + if path.is_dir() && !roots.iter().any(|(_, existing)| existing == &path) { + roots.push((source, path)); + } +} + +fn push_unique_skill_root( + roots: &mut Vec<SkillRoot>, + source: DefinitionSource, + path: PathBuf, + origin: SkillOrigin, +) { + if path.is_dir() && !roots.iter().any(|existing| existing.path == path) { + roots.push(SkillRoot { + source, + path, + origin, + }); + } +} + +fn load_agents_from_roots( + roots: &[(DefinitionSource, PathBuf)], +) -> std::io::Result<Vec<AgentSummary>> { + let mut agents = Vec::new(); + let mut active_sources = BTreeMap::<String, DefinitionSource>::new(); + + for (source, root) in roots { + let mut root_agents = Vec::new(); + for entry in fs::read_dir(root)? 
{ + let entry = entry?; + if entry.path().extension().is_none_or(|ext| ext != "toml") { + continue; + } + let contents = fs::read_to_string(entry.path())?; + let fallback_name = entry.path().file_stem().map_or_else( + || entry.file_name().to_string_lossy().to_string(), + |stem| stem.to_string_lossy().to_string(), + ); + root_agents.push(AgentSummary { + name: parse_toml_string(&contents, "name").unwrap_or(fallback_name), + description: parse_toml_string(&contents, "description"), + model: parse_toml_string(&contents, "model"), + reasoning_effort: parse_toml_string(&contents, "model_reasoning_effort"), + source: *source, + shadowed_by: None, + }); + } + root_agents.sort_by(|left, right| left.name.cmp(&right.name)); + + for mut agent in root_agents { + let key = agent.name.to_ascii_lowercase(); + if let Some(existing) = active_sources.get(&key) { + agent.shadowed_by = Some(*existing); + } else { + active_sources.insert(key, agent.source); + } + agents.push(agent); + } + } + + Ok(agents) +} + +fn load_skills_from_roots(roots: &[SkillRoot]) -> std::io::Result<Vec<SkillSummary>> { + let mut skills = Vec::new(); + let mut active_sources = BTreeMap::<String, DefinitionSource>::new(); + + for root in roots { + let mut root_skills = Vec::new(); + for entry in fs::read_dir(&root.path)? 
{ + let entry = entry?; + match root.origin { + SkillOrigin::SkillsDir => { + if !entry.path().is_dir() { + continue; + } + let skill_path = entry.path().join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + let contents = fs::read_to_string(skill_path)?; + let (name, description) = parse_skill_frontmatter(&contents); + root_skills.push(SkillSummary { + name: name + .unwrap_or_else(|| entry.file_name().to_string_lossy().to_string()), + description, + source: root.source, + shadowed_by: None, + origin: root.origin, + }); + } + SkillOrigin::LegacyCommandsDir => { + let path = entry.path(); + let markdown_path = if path.is_dir() { + let skill_path = path.join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + skill_path + } else if path + .extension() + .is_some_and(|ext| ext.to_string_lossy().eq_ignore_ascii_case("md")) + { + path + } else { + continue; + }; + + let contents = fs::read_to_string(&markdown_path)?; + let fallback_name = markdown_path.file_stem().map_or_else( + || entry.file_name().to_string_lossy().to_string(), + |stem| stem.to_string_lossy().to_string(), + ); + let (name, description) = parse_skill_frontmatter(&contents); + root_skills.push(SkillSummary { + name: name.unwrap_or(fallback_name), + description, + source: root.source, + shadowed_by: None, + origin: root.origin, + }); + } + } + } + root_skills.sort_by(|left, right| left.name.cmp(&right.name)); + + for mut skill in root_skills { + let key = skill.name.to_ascii_lowercase(); + if let Some(existing) = active_sources.get(&key) { + skill.shadowed_by = Some(*existing); + } else { + active_sources.insert(key, skill.source); + } + skills.push(skill); + } + } + + Ok(skills) +} + +fn parse_toml_string(contents: &str, key: &str) -> Option<String> { + let prefix = format!("{key} ="); + for line in contents.lines() { + let trimmed = line.trim(); + if trimmed.starts_with('#') { + continue; + } + let Some(value) = trimmed.strip_prefix(&prefix) else { + continue; + }; + let value = 
value.trim(); + let Some(value) = value + .strip_prefix('"') + .and_then(|value| value.strip_suffix('"')) + else { + continue; + }; + if !value.is_empty() { + return Some(value.to_string()); + } + } + None +} + +fn parse_skill_frontmatter(contents: &str) -> (Option<String>, Option<String>) { + let mut lines = contents.lines(); + if lines.next().map(str::trim) != Some("---") { + return (None, None); + } + + let mut name = None; + let mut description = None; + for line in lines { + let trimmed = line.trim(); + if trimmed == "---" { + break; + } + if let Some(value) = trimmed.strip_prefix("name:") { + let value = unquote_frontmatter_value(value.trim()); + if !value.is_empty() { + name = Some(value); + } + continue; + } + if let Some(value) = trimmed.strip_prefix("description:") { + let value = unquote_frontmatter_value(value.trim()); + if !value.is_empty() { + description = Some(value); + } + } + } + + (name, description) +} + +fn unquote_frontmatter_value(value: &str) -> String { + value + .strip_prefix('"') + .and_then(|trimmed| trimmed.strip_suffix('"')) + .or_else(|| { + value + .strip_prefix('\'') + .and_then(|trimmed| trimmed.strip_suffix('\'')) + }) + .unwrap_or(value) + .trim() + .to_string() +} + +fn render_agents_report(agents: &[AgentSummary]) -> String { + if agents.is_empty() { + return "No agents found.".to_string(); + } + + let total_active = agents + .iter() + .filter(|agent| agent.shadowed_by.is_none()) + .count(); + let mut lines = vec![ + "Agents".to_string(), + format!(" {total_active} active agents"), + String::new(), + ]; + + for source in [ + DefinitionSource::ProjectCodex, + DefinitionSource::ProjectClaw, + DefinitionSource::UserCodexHome, + DefinitionSource::UserCodex, + DefinitionSource::UserClaw, + ] { + let group = agents + .iter() + .filter(|agent| agent.source == source) + .collect::<Vec<_>>(); + if group.is_empty() { + continue; + } + + lines.push(format!("{}:", source.label())); + for agent in group { + let detail = agent_detail(agent); 
+ match agent.shadowed_by { + Some(winner) => lines.push(format!(" (shadowed by {}) {detail}", winner.label())), + None => lines.push(format!(" {detail}")), + } + } + lines.push(String::new()); + } + + lines.join("\n").trim_end().to_string() +} + +fn agent_detail(agent: &AgentSummary) -> String { + let mut parts = vec![agent.name.clone()]; + if let Some(description) = &agent.description { + parts.push(description.clone()); + } + if let Some(model) = &agent.model { + parts.push(model.clone()); + } + if let Some(reasoning) = &agent.reasoning_effort { + parts.push(reasoning.clone()); + } + parts.join(" · ") +} + +fn render_skills_report(skills: &[SkillSummary]) -> String { + if skills.is_empty() { + return "No skills found.".to_string(); + } + + let total_active = skills + .iter() + .filter(|skill| skill.shadowed_by.is_none()) + .count(); + let mut lines = vec![ + "Skills".to_string(), + format!(" {total_active} available skills"), + String::new(), + ]; + + for source in [ + DefinitionSource::ProjectCodex, + DefinitionSource::ProjectClaw, + DefinitionSource::UserCodexHome, + DefinitionSource::UserCodex, + DefinitionSource::UserClaw, + ] { + let group = skills + .iter() + .filter(|skill| skill.source == source) + .collect::<Vec<_>>(); + if group.is_empty() { + continue; + } + + lines.push(format!("{}:", source.label())); + for skill in group { + let mut parts = vec![skill.name.clone()]; + if let Some(description) = &skill.description { + parts.push(description.clone()); + } + if let Some(detail) = skill.origin.detail_label() { + parts.push(detail.to_string()); + } + let detail = parts.join(" · "); + match skill.shadowed_by { + Some(winner) => lines.push(format!(" (shadowed by {}) {detail}", winner.label())), + None => lines.push(format!(" {detail}")), + } + } + lines.push(String::new()); + } + + lines.join("\n").trim_end().to_string() +} + +fn normalize_optional_args(args: Option<&str>) -> Option<&str> { + args.map(str::trim).filter(|value| !value.is_empty()) +} + +fn 
render_agents_usage(unexpected: Option<&str>) -> String { + let mut lines = vec![ + "Agents".to_string(), + " Usage /agents".to_string(), + " Direct CLI claw agents".to_string(), + " Sources .codex/agents, .claw/agents, $CODEX_HOME/agents".to_string(), + ]; + if let Some(args) = unexpected { + lines.push(format!(" Unexpected {args}")); + } + lines.join("\n") +} + +fn render_skills_usage(unexpected: Option<&str>) -> String { + let mut lines = vec![ + "Skills".to_string(), + " Usage /skills".to_string(), + " Direct CLI claw skills".to_string(), + " Sources .codex/skills, .claw/skills, legacy /commands".to_string(), + ]; + if let Some(args) = unexpected { + lines.push(format!(" Unexpected {args}")); + } + lines.join("\n") +} + #[must_use] pub fn handle_slash_command( input: &str, @@ -279,6 +1758,16 @@ pub fn handle_slash_command( session: session.clone(), }), SlashCommand::Status + | SlashCommand::Branch { .. } + | SlashCommand::Bughunter { .. } + | SlashCommand::Worktree { .. } + | SlashCommand::Commit + | SlashCommand::CommitPushPr { .. } + | SlashCommand::Pr { .. } + | SlashCommand::Issue { .. } + | SlashCommand::Ultraplan { .. } + | SlashCommand::Teleport { .. } + | SlashCommand::DebugToolCall | SlashCommand::Model { .. } | SlashCommand::Permissions { .. } | SlashCommand::Clear { .. } @@ -291,6 +1780,9 @@ pub fn handle_slash_command( | SlashCommand::Version | SlashCommand::Export { .. } | SlashCommand::Session { .. } + | SlashCommand::Plugins { .. } + | SlashCommand::Agents { .. } + | SlashCommand::Skills { .. 
} | SlashCommand::Unknown(_) => None, } } @@ -298,19 +1790,237 @@ pub fn handle_slash_command( #[cfg(test)] mod tests { use super::{ - handle_slash_command, render_slash_command_help, resume_supported_slash_commands, - slash_command_specs, SlashCommand, + handle_branch_slash_command, handle_commit_push_pr_slash_command, + handle_commit_slash_command, handle_plugins_slash_command, handle_slash_command, + handle_worktree_slash_command, load_agents_from_roots, load_skills_from_roots, + render_agents_report, render_plugins_report, render_skills_report, + render_slash_command_help, resume_supported_slash_commands, slash_command_specs, + suggest_slash_commands, CommitPushPrRequest, DefinitionSource, SkillOrigin, SkillRoot, + SlashCommand, }; + use plugins::{PluginKind, PluginManager, PluginManagerConfig, PluginMetadata, PluginSummary}; use runtime::{CompactionConfig, ContentBlock, ConversationMessage, MessageRole, Session}; + use std::env; + use std::fs; + use std::path::{Path, PathBuf}; + use std::process::Command; + use std::sync::{Mutex, OnceLock}; + use std::time::{SystemTime, UNIX_EPOCH}; + #[cfg(unix)] + use std::os::unix::fs::PermissionsExt; + + fn temp_dir(label: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("commands-plugin-{label}-{nanos}")) + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + fn run_command(cwd: &Path, program: &str, args: &[&str]) -> String { + let output = Command::new(program) + .args(args) + .current_dir(cwd) + .output() + .expect("command should run"); + assert!( + output.status.success(), + "{} {} failed: {}", + program, + args.join(" "), + String::from_utf8_lossy(&output.stderr) + ); + String::from_utf8(output.stdout).expect("stdout should be utf8") + } + + fn 
init_git_repo(label: &str) -> PathBuf { + let root = temp_dir(label); + fs::create_dir_all(&root).expect("repo root"); + + let init = Command::new("git") + .args(["init", "-b", "main"]) + .current_dir(&root) + .output() + .expect("git init should run"); + if !init.status.success() { + let fallback = Command::new("git") + .arg("init") + .current_dir(&root) + .output() + .expect("fallback git init should run"); + assert!( + fallback.status.success(), + "fallback git init should succeed" + ); + let rename = Command::new("git") + .args(["branch", "-m", "main"]) + .current_dir(&root) + .output() + .expect("git branch -m should run"); + assert!(rename.status.success(), "git branch -m main should succeed"); + } + + run_command(&root, "git", &["config", "user.name", "Claw Tests"]); + run_command(&root, "git", &["config", "user.email", "claw@example.com"]); + fs::write(root.join("README.md"), "seed\n").expect("seed file"); + run_command(&root, "git", &["add", "README.md"]); + run_command(&root, "git", &["commit", "-m", "chore: seed repo"]); + root + } + + fn init_bare_repo(label: &str) -> PathBuf { + let root = temp_dir(label); + let output = Command::new("git") + .args(["init", "--bare"]) + .arg(&root) + .output() + .expect("bare repo should initialize"); + assert!(output.status.success(), "git init --bare should succeed"); + root + } + + #[cfg(unix)] + fn write_fake_gh(bin_dir: &Path, log_path: &Path, url: &str) { + fs::create_dir_all(bin_dir).expect("bin dir"); + let script = format!( + "#!/bin/sh\nif [ \"$1\" = \"--version\" ]; then\n echo 'gh 1.0.0'\n exit 0\nfi\nprintf '%s\\n' \"$*\" >> \"{}\"\nif [ \"$1\" = \"pr\" ] && [ \"$2\" = \"create\" ]; then\n echo '{}'\n exit 0\nfi\nif [ \"$1\" = \"pr\" ] && [ \"$2\" = \"view\" ]; then\n echo '{{\"url\":\"{}\"}}'\n exit 0\nfi\nexit 0\n", + log_path.display(), + url, + url, + ); + let path = bin_dir.join("gh"); + fs::write(&path, script).expect("gh stub"); + let mut permissions = 
fs::metadata(&path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&path, permissions).expect("chmod"); + } + + fn write_external_plugin(root: &Path, name: &str, version: &str) { + fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fs::write( + root.join(".claw-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"commands plugin\"\n}}" + ), + ) + .expect("write manifest"); + } + + fn write_bundled_plugin(root: &Path, name: &str, version: &str, default_enabled: bool) { + fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fs::write( + root.join(".claw-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"bundled commands plugin\",\n \"defaultEnabled\": {}\n}}", + if default_enabled { "true" } else { "false" } + ), + ) + .expect("write bundled manifest"); + } + + fn write_agent(root: &Path, name: &str, description: &str, model: &str, reasoning: &str) { + fs::create_dir_all(root).expect("agent root"); + fs::write( + root.join(format!("{name}.toml")), + format!( + "name = \"{name}\"\ndescription = \"{description}\"\nmodel = \"{model}\"\nmodel_reasoning_effort = \"{reasoning}\"\n" + ), + ) + .expect("write agent"); + } + + fn write_skill(root: &Path, name: &str, description: &str) { + let skill_root = root.join(name); + fs::create_dir_all(&skill_root).expect("skill root"); + fs::write( + skill_root.join("SKILL.md"), + format!("---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"), + ) + .expect("write skill"); + } + + fn write_legacy_command(root: &Path, name: &str, description: &str) { + fs::create_dir_all(root).expect("commands root"); + fs::write( + root.join(format!("{name}.md")), + format!("---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"), + ) + .expect("write command"); + } + + #[allow(clippy::too_many_lines)] #[test] 
fn parses_supported_slash_commands() { assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help)); assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status)); assert_eq!( - SlashCommand::parse("/model claude-opus"), + SlashCommand::parse("/bughunter runtime"), + Some(SlashCommand::Bughunter { + scope: Some("runtime".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/branch create feature/demo"), + Some(SlashCommand::Branch { + action: Some("create".to_string()), + target: Some("feature/demo".to_string()), + }) + ); + assert_eq!( + SlashCommand::parse("/worktree add ../demo wt-demo"), + Some(SlashCommand::Worktree { + action: Some("add".to_string()), + path: Some("../demo".to_string()), + branch: Some("wt-demo".to_string()), + }) + ); + assert_eq!(SlashCommand::parse("/commit"), Some(SlashCommand::Commit)); + assert_eq!( + SlashCommand::parse("/commit-push-pr ready for review"), + Some(SlashCommand::CommitPushPr { + context: Some("ready for review".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/pr ready for review"), + Some(SlashCommand::Pr { + context: Some("ready for review".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/issue flaky test"), + Some(SlashCommand::Issue { + context: Some("flaky test".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/ultraplan ship both features"), + Some(SlashCommand::Ultraplan { + task: Some("ship both features".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/teleport conversation.rs"), + Some(SlashCommand::Teleport { + target: Some("conversation.rs".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/debug-tool-call"), + Some(SlashCommand::DebugToolCall) + ); + assert_eq!( + SlashCommand::parse("/model opus"), Some(SlashCommand::Model { - model: Some("claude-opus".to_string()), + model: Some("opus".to_string()), }) ); assert_eq!( @@ -365,29 +2075,85 @@ mod tests { target: Some("abc123".to_string()) }) ); + assert_eq!( + 
SlashCommand::parse("/plugins install demo"), + Some(SlashCommand::Plugins { + action: Some("install".to_string()), + target: Some("demo".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/plugins list"), + Some(SlashCommand::Plugins { + action: Some("list".to_string()), + target: None + }) + ); + assert_eq!( + SlashCommand::parse("/plugins enable demo"), + Some(SlashCommand::Plugins { + action: Some("enable".to_string()), + target: Some("demo".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/plugins disable demo"), + Some(SlashCommand::Plugins { + action: Some("disable".to_string()), + target: Some("demo".to_string()) + }) + ); } #[test] fn renders_help_from_shared_specs() { let help = render_slash_command_help(); - assert!(help.contains("works with --resume SESSION.json")); + assert!(help.contains("available via claw --resume SESSION.json")); + assert!(help.contains("Core flow")); + assert!(help.contains("Workspace & memory")); + assert!(help.contains("Sessions & output")); + assert!(help.contains("Git & GitHub")); + assert!(help.contains("Automation & discovery")); assert!(help.contains("/help")); assert!(help.contains("/status")); assert!(help.contains("/compact")); + assert!(help.contains("/bughunter [scope]")); + assert!(help.contains("/branch [list|create <name>|switch <name>]")); + assert!(help.contains("/worktree [list|add <path> [branch]|remove <path>|prune]")); + assert!(help.contains("/commit")); + assert!(help.contains("/commit-push-pr [context]")); + assert!(help.contains("/pr [context]")); + assert!(help.contains("/issue [context]")); + assert!(help.contains("/ultraplan [task]")); + assert!(help.contains("/teleport <symbol-or-path>")); + assert!(help.contains("/debug-tool-call")); assert!(help.contains("/model [model]")); assert!(help.contains("/permissions [read-only|workspace-write|danger-full-access]")); assert!(help.contains("/clear [--confirm]")); assert!(help.contains("/cost")); assert!(help.contains("/resume 
<session-path>")); - assert!(help.contains("/config [env|hooks|model]")); + assert!(help.contains("/config [env|hooks|model|plugins]")); assert!(help.contains("/memory")); assert!(help.contains("/init")); assert!(help.contains("/diff")); assert!(help.contains("/version")); assert!(help.contains("/export [file]")); assert!(help.contains("/session [list|switch <session-id>]")); - assert_eq!(slash_command_specs().len(), 15); - assert_eq!(resume_supported_slash_commands().len(), 11); + assert!(help.contains( + "/plugin [list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]" + )); + assert!(help.contains("aliases: /plugins, /marketplace")); + assert!(help.contains("/agents")); + assert!(help.contains("/skills")); + assert_eq!(slash_command_specs().len(), 28); + assert_eq!(resume_supported_slash_commands().len(), 13); + } + + #[test] + fn suggests_close_slash_commands() { + let suggestions = suggest_slash_commands("stats", 3); + assert!(!suggestions.is_empty()); + assert_eq!(suggestions[0], "/status"); } #[test] @@ -435,7 +2201,35 @@ mod tests { assert!(handle_slash_command("/unknown", &session, CompactionConfig::default()).is_none()); assert!(handle_slash_command("/status", &session, CompactionConfig::default()).is_none()); assert!( - handle_slash_command("/model claude", &session, CompactionConfig::default()).is_none() + handle_slash_command("/branch list", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/bughunter", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/worktree list", &session, CompactionConfig::default()).is_none() + ); + assert!(handle_slash_command("/commit", &session, CompactionConfig::default()).is_none()); + assert!(handle_slash_command( + "/commit-push-pr review notes", + &session, + CompactionConfig::default() + ) + .is_none()); + assert!(handle_slash_command("/pr", &session, CompactionConfig::default()).is_none()); + 
assert!(handle_slash_command("/issue", &session, CompactionConfig::default()).is_none()); + assert!( + handle_slash_command("/ultraplan", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/teleport foo", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/debug-tool-call", &session, CompactionConfig::default()) + .is_none() + ); + assert!( + handle_slash_command("/model sonnet", &session, CompactionConfig::default()).is_none() ); assert!(handle_slash_command( "/permissions read-only", @@ -468,5 +2262,406 @@ mod tests { assert!( handle_slash_command("/session list", &session, CompactionConfig::default()).is_none() ); + assert!( + handle_slash_command("/plugins list", &session, CompactionConfig::default()).is_none() + ); + } + + #[test] + fn renders_plugins_report_with_name_version_and_status() { + let rendered = render_plugins_report(&[ + PluginSummary { + metadata: PluginMetadata { + id: "demo@external".to_string(), + name: "demo".to_string(), + version: "1.2.3".to_string(), + description: "demo plugin".to_string(), + kind: PluginKind::External, + source: "demo".to_string(), + default_enabled: false, + root: None, + }, + enabled: true, + }, + PluginSummary { + metadata: PluginMetadata { + id: "sample@external".to_string(), + name: "sample".to_string(), + version: "0.9.0".to_string(), + description: "sample plugin".to_string(), + kind: PluginKind::External, + source: "sample".to_string(), + default_enabled: false, + root: None, + }, + enabled: false, + }, + ]); + + assert!(rendered.contains("demo")); + assert!(rendered.contains("v1.2.3")); + assert!(rendered.contains("enabled")); + assert!(rendered.contains("sample")); + assert!(rendered.contains("v0.9.0")); + assert!(rendered.contains("disabled")); + } + + #[test] + fn lists_agents_from_project_and_user_roots() { + let workspace = temp_dir("agents-workspace"); + let project_agents = workspace.join(".codex").join("agents"); + let user_home 
= temp_dir("agents-home"); + let user_agents = user_home.join(".codex").join("agents"); + + write_agent( + &project_agents, + "planner", + "Project planner", + "gpt-5.4", + "medium", + ); + write_agent( + &user_agents, + "planner", + "User planner", + "gpt-5.4-mini", + "high", + ); + write_agent( + &user_agents, + "verifier", + "Verification agent", + "gpt-5.4-mini", + "high", + ); + + let roots = vec![ + (DefinitionSource::ProjectCodex, project_agents), + (DefinitionSource::UserCodex, user_agents), + ]; + let report = + render_agents_report(&load_agents_from_roots(&roots).expect("agent roots should load")); + + assert!(report.contains("Agents")); + assert!(report.contains("2 active agents")); + assert!(report.contains("Project (.codex):")); + assert!(report.contains("planner · Project planner · gpt-5.4 · medium")); + assert!(report.contains("User (~/.codex):")); + assert!(report.contains("(shadowed by Project (.codex)) planner · User planner")); + assert!(report.contains("verifier · Verification agent · gpt-5.4-mini · high")); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(user_home); + } + + #[test] + fn lists_skills_from_project_and_user_roots() { + let workspace = temp_dir("skills-workspace"); + let project_skills = workspace.join(".codex").join("skills"); + let project_commands = workspace.join(".claw").join("commands"); + let user_home = temp_dir("skills-home"); + let user_skills = user_home.join(".codex").join("skills"); + + write_skill(&project_skills, "plan", "Project planning guidance"); + write_legacy_command(&project_commands, "deploy", "Legacy deployment guidance"); + write_skill(&user_skills, "plan", "User planning guidance"); + write_skill(&user_skills, "help", "Help guidance"); + + let roots = vec![ + SkillRoot { + source: DefinitionSource::ProjectCodex, + path: project_skills, + origin: SkillOrigin::SkillsDir, + }, + SkillRoot { + source: DefinitionSource::ProjectClaw, + path: project_commands, + origin: 
SkillOrigin::LegacyCommandsDir, + }, + SkillRoot { + source: DefinitionSource::UserCodex, + path: user_skills, + origin: SkillOrigin::SkillsDir, + }, + ]; + let report = + render_skills_report(&load_skills_from_roots(&roots).expect("skill roots should load")); + + assert!(report.contains("Skills")); + assert!(report.contains("3 available skills")); + assert!(report.contains("Project (.codex):")); + assert!(report.contains("plan · Project planning guidance")); + assert!(report.contains("Project (.claw):")); + assert!(report.contains("deploy · Legacy deployment guidance · legacy /commands")); + assert!(report.contains("User (~/.codex):")); + assert!(report.contains("(shadowed by Project (.codex)) plan · User planning guidance")); + assert!(report.contains("help · Help guidance")); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(user_home); + } + + #[test] + fn agents_and_skills_usage_support_help_and_unexpected_args() { + let cwd = temp_dir("slash-usage"); + + let agents_help = + super::handle_agents_slash_command(Some("help"), &cwd).expect("agents help"); + assert!(agents_help.contains("Usage /agents")); + assert!(agents_help.contains("Direct CLI claw agents")); + + let agents_unexpected = + super::handle_agents_slash_command(Some("show planner"), &cwd).expect("agents usage"); + assert!(agents_unexpected.contains("Unexpected show planner")); + + let skills_help = + super::handle_skills_slash_command(Some("--help"), &cwd).expect("skills help"); + assert!(skills_help.contains("Usage /skills")); + assert!(skills_help.contains("legacy /commands")); + + let skills_unexpected = + super::handle_skills_slash_command(Some("show help"), &cwd).expect("skills usage"); + assert!(skills_unexpected.contains("Unexpected show help")); + + let _ = fs::remove_dir_all(cwd); + } + + #[test] + fn parses_quoted_skill_frontmatter_values() { + let contents = "---\nname: \"hud\"\ndescription: 'Quoted description'\n---\n"; + let (name, description) = 
super::parse_skill_frontmatter(contents); + assert_eq!(name.as_deref(), Some("hud")); + assert_eq!(description.as_deref(), Some("Quoted description")); + } + + #[test] + fn installs_plugin_from_path_and_lists_it() { + let config_home = temp_dir("home"); + let source_root = temp_dir("source"); + write_external_plugin(&source_root, "demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install = handle_plugins_slash_command( + Some("install"), + Some(source_root.to_str().expect("utf8 path")), + &mut manager, + ) + .expect("install command should succeed"); + assert!(install.reload_runtime); + assert!(install.message.contains("installed demo@external")); + assert!(install.message.contains("Name demo")); + assert!(install.message.contains("Version 1.0.0")); + assert!(install.message.contains("Status enabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(!list.reload_runtime); + assert!(list.message.contains("demo")); + assert!(list.message.contains("v1.0.0")); + assert!(list.message.contains("enabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn enables_and_disables_plugin_by_name() { + let config_home = temp_dir("toggle-home"); + let source_root = temp_dir("toggle-source"); + write_external_plugin(&source_root, "demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + handle_plugins_slash_command( + Some("install"), + Some(source_root.to_str().expect("utf8 path")), + &mut manager, + ) + .expect("install command should succeed"); + + let disable = handle_plugins_slash_command(Some("disable"), Some("demo"), &mut manager) + .expect("disable command should succeed"); + assert!(disable.reload_runtime); + assert!(disable.message.contains("disabled demo@external")); + assert!(disable.message.contains("Name demo")); + 
assert!(disable.message.contains("Status disabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(list.message.contains("demo")); + assert!(list.message.contains("disabled")); + + let enable = handle_plugins_slash_command(Some("enable"), Some("demo"), &mut manager) + .expect("enable command should succeed"); + assert!(enable.reload_runtime); + assert!(enable.message.contains("enabled demo@external")); + assert!(enable.message.contains("Name demo")); + assert!(enable.message.contains("Status enabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(list.message.contains("demo")); + assert!(list.message.contains("enabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn lists_auto_installed_bundled_plugins_with_status() { + let config_home = temp_dir("bundled-home"); + let bundled_root = temp_dir("bundled-root"); + let bundled_plugin = bundled_root.join("starter"); + write_bundled_plugin(&bundled_plugin, "starter", "0.1.0", false); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let mut manager = PluginManager::new(config); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(!list.reload_runtime); + assert!(list.message.contains("starter")); + assert!(list.message.contains("v0.1.0")); + assert!(list.message.contains("disabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn branch_and_worktree_commands_manage_git_state() { + // given + let repo = init_git_repo("branch-worktree"); + let worktree_path = repo + .parent() + .expect("repo should have parent") + .join("branch-worktree-linked"); + + // when + let branch_list = + 
handle_branch_slash_command(Some("list"), None, &repo).expect("branch list succeeds"); + let created = handle_branch_slash_command(Some("create"), Some("feature/demo"), &repo) + .expect("branch create succeeds"); + let switched = handle_branch_slash_command(Some("switch"), Some("main"), &repo) + .expect("branch switch succeeds"); + let added = handle_worktree_slash_command( + Some("add"), + Some(worktree_path.to_str().expect("utf8 path")), + Some("wt-demo"), + &repo, + ) + .expect("worktree add succeeds"); + let listed_worktrees = + handle_worktree_slash_command(Some("list"), None, None, &repo).expect("list succeeds"); + let removed = handle_worktree_slash_command( + Some("remove"), + Some(worktree_path.to_str().expect("utf8 path")), + None, + &repo, + ) + .expect("remove succeeds"); + + // then + assert!(branch_list.contains("main")); + assert!(created.contains("feature/demo")); + assert!(switched.contains("main")); + assert!(added.contains("wt-demo")); + assert!(listed_worktrees.contains(worktree_path.to_str().expect("utf8 path"))); + assert!(removed.contains("Result removed")); + + let _ = fs::remove_dir_all(repo); + let _ = fs::remove_dir_all(worktree_path); + } + + #[test] + fn commit_command_stages_and_commits_changes() { + // given + let repo = init_git_repo("commit-command"); + fs::write(repo.join("notes.txt"), "hello\n").expect("write notes"); + + // when + let report = + handle_commit_slash_command("feat: add notes", &repo).expect("commit succeeds"); + let status = run_command(&repo, "git", &["status", "--short"]); + let message = run_command(&repo, "git", &["log", "-1", "--pretty=%B"]); + + // then + assert!(report.contains("Result created")); + assert!(status.trim().is_empty()); + assert_eq!(message.trim(), "feat: add notes"); + + let _ = fs::remove_dir_all(repo); + } + + #[cfg(unix)] + #[test] + fn commit_push_pr_command_commits_pushes_and_creates_pr() { + // given + let _guard = env_lock(); + let repo = init_git_repo("commit-push-pr"); + let remote = 
init_bare_repo("commit-push-pr-remote"); + run_command( + &repo, + "git", + &[ + "remote", + "add", + "origin", + remote.to_str().expect("utf8 remote"), + ], + ); + run_command(&repo, "git", &["push", "-u", "origin", "main"]); + fs::write(repo.join("feature.txt"), "feature\n").expect("write feature file"); + + let fake_bin = temp_dir("fake-gh-bin"); + let gh_log = fake_bin.join("gh.log"); + write_fake_gh(&fake_bin, &gh_log, "https://example.com/pr/123"); + + let previous_path = env::var_os("PATH"); + let mut new_path = fake_bin.display().to_string(); + if let Some(path) = &previous_path { + new_path.push(':'); + new_path.push_str(&path.to_string_lossy()); + } + env::set_var("PATH", &new_path); + let previous_safeuser = env::var_os("SAFEUSER"); + env::set_var("SAFEUSER", "tester"); + + let request = CommitPushPrRequest { + commit_message: Some("feat: add feature file".to_string()), + pr_title: "Add feature file".to_string(), + pr_body: "## Summary\n- add feature file".to_string(), + branch_name_hint: "Add feature file".to_string(), + }; + + // when + let report = + handle_commit_push_pr_slash_command(&request, &repo).expect("commit-push-pr succeeds"); + let branch = run_command(&repo, "git", &["branch", "--show-current"]); + let message = run_command(&repo, "git", &["log", "-1", "--pretty=%B"]); + let gh_invocations = fs::read_to_string(&gh_log).expect("gh log should exist"); + + // then + assert!(report.contains("Result created")); + assert!(report.contains("URL https://example.com/pr/123")); + assert_eq!(branch.trim(), "tester/add-feature-file"); + assert_eq!(message.trim(), "feat: add feature file"); + assert!(gh_invocations.contains("pr create")); + assert!(gh_invocations.contains("--base main")); + + if let Some(path) = previous_path { + env::set_var("PATH", path); + } else { + env::remove_var("PATH"); + } + if let Some(safeuser) = previous_safeuser { + env::set_var("SAFEUSER", safeuser); + } else { + env::remove_var("SAFEUSER"); + } + + let _ = 
fs::remove_dir_all(repo); + let _ = fs::remove_dir_all(remote); + let _ = fs::remove_dir_all(fake_bin); } } diff --git a/crates/lsp/Cargo.toml b/crates/lsp/Cargo.toml new file mode 100644 index 0000000..a2f1aec --- /dev/null +++ b/crates/lsp/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "lsp" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +lsp-types.workspace = true +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true +tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "sync", "time"] } +url = "2" + +[lints] +workspace = true diff --git a/crates/lsp/src/client.rs b/crates/lsp/src/client.rs new file mode 100644 index 0000000..7ec663b --- /dev/null +++ b/crates/lsp/src/client.rs @@ -0,0 +1,463 @@ +use std::collections::BTreeMap; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::sync::Arc; +use std::sync::atomic::{AtomicI64, Ordering}; + +use lsp_types::{ + Diagnostic, GotoDefinitionResponse, Location, LocationLink, Position, PublishDiagnosticsParams, +}; +use serde_json::{json, Value}; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::{oneshot, Mutex}; + +use crate::error::LspError; +use crate::types::{LspServerConfig, SymbolLocation}; + +pub(crate) struct LspClient { + config: LspServerConfig, + writer: Mutex<BufWriter<ChildStdin>>, + child: Mutex<Child>, + pending_requests: Arc<Mutex<BTreeMap<i64, oneshot::Sender<Result<Value, LspError>>>>>, + diagnostics: Arc<Mutex<BTreeMap<String, Vec<Diagnostic>>>>, + open_documents: Mutex<BTreeMap<PathBuf, i32>>, + next_request_id: AtomicI64, +} + +impl LspClient { + pub(crate) async fn connect(config: LspServerConfig) -> Result<Self, LspError> { + let mut command = Command::new(&config.command); + command + .args(&config.args) + 
.current_dir(&config.workspace_root) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .envs(config.env.clone()); + + let mut child = command.spawn()?; + let stdin = child + .stdin + .take() + .ok_or_else(|| LspError::Protocol("missing LSP stdin pipe".to_string()))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| LspError::Protocol("missing LSP stdout pipe".to_string()))?; + let stderr = child.stderr.take(); + + let client = Self { + config, + writer: Mutex::new(BufWriter::new(stdin)), + child: Mutex::new(child), + pending_requests: Arc::new(Mutex::new(BTreeMap::new())), + diagnostics: Arc::new(Mutex::new(BTreeMap::new())), + open_documents: Mutex::new(BTreeMap::new()), + next_request_id: AtomicI64::new(1), + }; + + client.spawn_reader(stdout); + if let Some(stderr) = stderr { + client.spawn_stderr_drain(stderr); + } + client.initialize().await?; + Ok(client) + } + + pub(crate) async fn ensure_document_open(&self, path: &Path) -> Result<(), LspError> { + if self.is_document_open(path).await { + return Ok(()); + } + + let contents = std::fs::read_to_string(path)?; + self.open_document(path, &contents).await + } + + pub(crate) async fn open_document(&self, path: &Path, text: &str) -> Result<(), LspError> { + let uri = file_url(path)?; + let language_id = self + .config + .language_id_for(path) + .ok_or_else(|| LspError::UnsupportedDocument(path.to_path_buf()))?; + + self.notify( + "textDocument/didOpen", + json!({ + "textDocument": { + "uri": uri, + "languageId": language_id, + "version": 1, + "text": text, + } + }), + ) + .await?; + + self.open_documents + .lock() + .await + .insert(path.to_path_buf(), 1); + Ok(()) + } + + pub(crate) async fn change_document(&self, path: &Path, text: &str) -> Result<(), LspError> { + if !self.is_document_open(path).await { + return self.open_document(path, text).await; + } + + let uri = file_url(path)?; + let next_version = { + let mut open_documents = self.open_documents.lock().await; + let 
version = open_documents + .entry(path.to_path_buf()) + .and_modify(|value| *value += 1) + .or_insert(1); + *version + }; + + self.notify( + "textDocument/didChange", + json!({ + "textDocument": { + "uri": uri, + "version": next_version, + }, + "contentChanges": [{ + "text": text, + }], + }), + ) + .await + } + + pub(crate) async fn save_document(&self, path: &Path) -> Result<(), LspError> { + if !self.is_document_open(path).await { + return Ok(()); + } + + self.notify( + "textDocument/didSave", + json!({ + "textDocument": { + "uri": file_url(path)?, + } + }), + ) + .await + } + + pub(crate) async fn close_document(&self, path: &Path) -> Result<(), LspError> { + if !self.is_document_open(path).await { + return Ok(()); + } + + self.notify( + "textDocument/didClose", + json!({ + "textDocument": { + "uri": file_url(path)?, + } + }), + ) + .await?; + + self.open_documents.lock().await.remove(path); + Ok(()) + } + + pub(crate) async fn is_document_open(&self, path: &Path) -> bool { + self.open_documents.lock().await.contains_key(path) + } + + pub(crate) async fn go_to_definition( + &self, + path: &Path, + position: Position, + ) -> Result<Vec<SymbolLocation>, LspError> { + self.ensure_document_open(path).await?; + let response = self + .request::<Option<GotoDefinitionResponse>>( + "textDocument/definition", + json!({ + "textDocument": { "uri": file_url(path)? 
}, + "position": position, + }), + ) + .await?; + + Ok(match response { + Some(GotoDefinitionResponse::Scalar(location)) => { + location_to_symbol_locations(vec![location]) + } + Some(GotoDefinitionResponse::Array(locations)) => location_to_symbol_locations(locations), + Some(GotoDefinitionResponse::Link(links)) => location_links_to_symbol_locations(links), + None => Vec::new(), + }) + } + + pub(crate) async fn find_references( + &self, + path: &Path, + position: Position, + include_declaration: bool, + ) -> Result<Vec<SymbolLocation>, LspError> { + self.ensure_document_open(path).await?; + let response = self + .request::<Option<Vec<Location>>>( + "textDocument/references", + json!({ + "textDocument": { "uri": file_url(path)? }, + "position": position, + "context": { + "includeDeclaration": include_declaration, + }, + }), + ) + .await?; + + Ok(location_to_symbol_locations(response.unwrap_or_default())) + } + + pub(crate) async fn diagnostics_snapshot(&self) -> BTreeMap<String, Vec<Diagnostic>> { + self.diagnostics.lock().await.clone() + } + + pub(crate) async fn shutdown(&self) -> Result<(), LspError> { + let _ = self.request::<Value>("shutdown", json!({})).await; + let _ = self.notify("exit", Value::Null).await; + + let mut child = self.child.lock().await; + if child.kill().await.is_err() { + let _ = child.wait().await; + return Ok(()); + } + let _ = child.wait().await; + Ok(()) + } + + fn spawn_reader(&self, stdout: ChildStdout) { + let diagnostics = &self.diagnostics; + let pending_requests = &self.pending_requests; + + let diagnostics = diagnostics.clone(); + let pending_requests = pending_requests.clone(); + tokio::spawn(async move { + let mut reader = BufReader::new(stdout); + let result = async { + while let Some(message) = read_message(&mut reader).await? 
{ + if let Some(id) = message.get("id").and_then(Value::as_i64) { + let response = if let Some(error) = message.get("error") { + Err(LspError::Protocol(error.to_string())) + } else { + Ok(message.get("result").cloned().unwrap_or(Value::Null)) + }; + + if let Some(sender) = pending_requests.lock().await.remove(&id) { + let _ = sender.send(response); + } + continue; + } + + let Some(method) = message.get("method").and_then(Value::as_str) else { + continue; + }; + if method != "textDocument/publishDiagnostics" { + continue; + } + + let params = message.get("params").cloned().unwrap_or(Value::Null); + let notification = serde_json::from_value::<PublishDiagnosticsParams>(params)?; + let mut diagnostics_map = diagnostics.lock().await; + if notification.diagnostics.is_empty() { + diagnostics_map.remove(¬ification.uri.to_string()); + } else { + diagnostics_map.insert(notification.uri.to_string(), notification.diagnostics); + } + } + Ok::<(), LspError>(()) + } + .await; + + if let Err(error) = result { + let mut pending = pending_requests.lock().await; + let drained = pending + .iter() + .map(|(id, _)| *id) + .collect::<Vec<_>>(); + for id in drained { + if let Some(sender) = pending.remove(&id) { + let _ = sender.send(Err(LspError::Protocol(error.to_string()))); + } + } + } + }); + } + + fn spawn_stderr_drain<R>(&self, stderr: R) + where + R: AsyncRead + Unpin + Send + 'static, + { + tokio::spawn(async move { + let mut reader = BufReader::new(stderr); + let mut sink = Vec::new(); + let _ = reader.read_to_end(&mut sink).await; + }); + } + + async fn initialize(&self) -> Result<(), LspError> { + let workspace_uri = file_url(&self.config.workspace_root)?; + let _ = self + .request::<Value>( + "initialize", + json!({ + "processId": std::process::id(), + "rootUri": workspace_uri, + "rootPath": self.config.workspace_root, + "workspaceFolders": [{ + "uri": workspace_uri, + "name": self.config.name, + }], + "initializationOptions": 
self.config.initialization_options.clone().unwrap_or(Value::Null), + "capabilities": { + "textDocument": { + "publishDiagnostics": { + "relatedInformation": true, + }, + "definition": { + "linkSupport": true, + }, + "references": {} + }, + "workspace": { + "configuration": false, + "workspaceFolders": true, + }, + "general": { + "positionEncodings": ["utf-16"], + } + } + }), + ) + .await?; + self.notify("initialized", json!({})).await + } + + async fn request<T>(&self, method: &str, params: Value) -> Result<T, LspError> + where + T: for<'de> serde::Deserialize<'de>, + { + let id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + let (sender, receiver) = oneshot::channel(); + self.pending_requests.lock().await.insert(id, sender); + + if let Err(error) = self + .send_message(&json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + })) + .await + { + self.pending_requests.lock().await.remove(&id); + return Err(error); + } + + let response = receiver + .await + .map_err(|_| LspError::Protocol(format!("request channel closed for {method}")))??; + Ok(serde_json::from_value(response)?) 
+ } + + async fn notify(&self, method: &str, params: Value) -> Result<(), LspError> { + self.send_message(&json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + })) + .await + } + + async fn send_message(&self, payload: &Value) -> Result<(), LspError> { + let body = serde_json::to_vec(payload)?; + let mut writer = self.writer.lock().await; + writer + .write_all(format!("Content-Length: {}\r\n\r\n", body.len()).as_bytes()) + .await?; + writer.write_all(&body).await?; + writer.flush().await?; + Ok(()) + } +} + +async fn read_message<R>(reader: &mut BufReader<R>) -> Result<Option<Value>, LspError> +where + R: AsyncRead + Unpin, +{ + let mut content_length = None; + + loop { + let mut line = String::new(); + let read = reader.read_line(&mut line).await?; + if read == 0 { + return Ok(None); + } + + if line == "\r\n" { + break; + } + + let trimmed = line.trim_end_matches(['\r', '\n']); + if let Some((name, value)) = trimmed.split_once(':') { + if name.eq_ignore_ascii_case("Content-Length") { + let value = value.trim().to_string(); + content_length = Some( + value + .parse::<usize>() + .map_err(|_| LspError::InvalidContentLength(value.clone()))?, + ); + } + } else { + return Err(LspError::InvalidHeader(trimmed.to_string())); + } + } + + let content_length = content_length.ok_or(LspError::MissingContentLength)?; + let mut body = vec![0_u8; content_length]; + reader.read_exact(&mut body).await?; + Ok(Some(serde_json::from_slice(&body)?)) +} + +fn file_url(path: &Path) -> Result<String, LspError> { + url::Url::from_file_path(path) + .map(|url| url.to_string()) + .map_err(|()| LspError::PathToUrl(path.to_path_buf())) +} + +fn location_to_symbol_locations(locations: Vec<Location>) -> Vec<SymbolLocation> { + locations + .into_iter() + .filter_map(|location| { + uri_to_path(&location.uri.to_string()).map(|path| SymbolLocation { + path, + range: location.range, + }) + }) + .collect() +} + +fn location_links_to_symbol_locations(links: Vec<LocationLink>) -> 
Vec<SymbolLocation> { + links.into_iter() + .filter_map(|link| { + uri_to_path(&link.target_uri.to_string()).map(|path| SymbolLocation { + path, + range: link.target_selection_range, + }) + }) + .collect() +} + +fn uri_to_path(uri: &str) -> Option<PathBuf> { + url::Url::parse(uri).ok()?.to_file_path().ok() +} diff --git a/crates/lsp/src/error.rs b/crates/lsp/src/error.rs new file mode 100644 index 0000000..6be1413 --- /dev/null +++ b/crates/lsp/src/error.rs @@ -0,0 +1,62 @@ +use std::fmt::{Display, Formatter}; +use std::path::PathBuf; + +#[derive(Debug)] +pub enum LspError { + Io(std::io::Error), + Json(serde_json::Error), + InvalidHeader(String), + MissingContentLength, + InvalidContentLength(String), + UnsupportedDocument(PathBuf), + UnknownServer(String), + DuplicateExtension { + extension: String, + existing_server: String, + new_server: String, + }, + PathToUrl(PathBuf), + Protocol(String), +} + +impl Display for LspError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Io(error) => write!(f, "{error}"), + Self::Json(error) => write!(f, "{error}"), + Self::InvalidHeader(header) => write!(f, "invalid LSP header: {header}"), + Self::MissingContentLength => write!(f, "missing LSP Content-Length header"), + Self::InvalidContentLength(value) => { + write!(f, "invalid LSP Content-Length value: {value}") + } + Self::UnsupportedDocument(path) => { + write!(f, "no LSP server configured for {}", path.display()) + } + Self::UnknownServer(name) => write!(f, "unknown LSP server: {name}"), + Self::DuplicateExtension { + extension, + existing_server, + new_server, + } => write!( + f, + "duplicate LSP extension mapping for {extension}: {existing_server} and {new_server}" + ), + Self::PathToUrl(path) => write!(f, "failed to convert path to file URL: {}", path.display()), + Self::Protocol(message) => write!(f, "LSP protocol error: {message}"), + } + } +} + +impl std::error::Error for LspError {} + +impl From<std::io::Error> for LspError { + 
    fn from(value: std::io::Error) -> Self {
        Self::Io(value)
    }
}

impl From<serde_json::Error> for LspError {
    fn from(value: serde_json::Error) -> Self {
        Self::Json(value)
    }
}
diff --git a/crates/lsp/src/lib.rs b/crates/lsp/src/lib.rs
new file mode 100644
index 0000000..9b1b099
--- /dev/null
+++ b/crates/lsp/src/lib.rs
@@ -0,0 +1,283 @@
mod client;
mod error;
mod manager;
mod types;

pub use error::LspError;
pub use manager::LspManager;
pub use types::{
    FileDiagnostics, LspContextEnrichment, LspServerConfig, SymbolLocation, WorkspaceDiagnostics,
};

#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::fs;
    use std::path::PathBuf;
    use std::process::Command;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use lsp_types::{DiagnosticSeverity, Position};

    use crate::{LspManager, LspServerConfig};

    // Unique scratch directory per test invocation; the nanosecond timestamp
    // keeps concurrent test runs from colliding.
    fn temp_dir(label: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("lsp-{label}-{nanos}"))
    }

    // Find a working python3 interpreter; tests return early (skip) when
    // none is available on this machine.
    fn python3_path() -> Option<String> {
        let candidates = ["python3", "/usr/bin/python3"];
        candidates.iter().find_map(|candidate| {
            Command::new(candidate)
                .arg("--version")
                .output()
                .ok()
                .filter(|output| output.status.success())
                .map(|_| (*candidate).to_string())
        })
    }

    // Writes a minimal stdio LSP server (Python) that answers `initialize`,
    // publishes exactly one diagnostic on `didOpen`, and serves canned
    // definition/references responses.
    // NOTE(review): the script's internal indentation was reconstructed from
    // a whitespace-mangled source — verify it still parses as Python.
    fn write_mock_server_script(root: &std::path::Path) -> PathBuf {
        let script_path = root.join("mock_lsp_server.py");
        fs::write(
            &script_path,
            r#"import json
import sys


def read_message():
    headers = {}
    while True:
        line = sys.stdin.buffer.readline()
        if not line:
            return None
        if line == b"\r\n":
            break
        key, value = line.decode("utf-8").split(":", 1)
        headers[key.lower()] = value.strip()
    length = int(headers["content-length"])
    body = sys.stdin.buffer.read(length)
    return json.loads(body)


def write_message(payload):
    raw = json.dumps(payload).encode("utf-8")
    sys.stdout.buffer.write(f"Content-Length: {len(raw)}\r\n\r\n".encode("utf-8"))
    sys.stdout.buffer.write(raw)
    sys.stdout.buffer.flush()


while True:
    message = read_message()
    if message is None:
        break

    method = message.get("method")
    if method == "initialize":
        write_message({
            "jsonrpc": "2.0",
            "id": message["id"],
            "result": {
                "capabilities": {
                    "definitionProvider": True,
                    "referencesProvider": True,
                    "textDocumentSync": 1,
                }
            },
        })
    elif method == "initialized":
        continue
    elif method == "textDocument/didOpen":
        document = message["params"]["textDocument"]
        write_message({
            "jsonrpc": "2.0",
            "method": "textDocument/publishDiagnostics",
            "params": {
                "uri": document["uri"],
                "diagnostics": [
                    {
                        "range": {
                            "start": {"line": 0, "character": 0},
                            "end": {"line": 0, "character": 3},
                        },
                        "severity": 1,
                        "source": "mock-server",
                        "message": "mock error",
                    }
                ],
            },
        })
    elif method == "textDocument/didChange":
        continue
    elif method == "textDocument/didSave":
        continue
    elif method == "textDocument/definition":
        uri = message["params"]["textDocument"]["uri"]
        write_message({
            "jsonrpc": "2.0",
            "id": message["id"],
            "result": [
                {
                    "uri": uri,
                    "range": {
                        "start": {"line": 0, "character": 0},
                        "end": {"line": 0, "character": 3},
                    },
                }
            ],
        })
    elif method == "textDocument/references":
        uri = message["params"]["textDocument"]["uri"]
        write_message({
            "jsonrpc": "2.0",
            "id": message["id"],
            "result": [
                {
                    "uri": uri,
                    "range": {
                        "start": {"line": 0, "character": 0},
                        "end": {"line": 0, "character": 3},
                    },
                },
                {
                    "uri": uri,
                    "range": {
                        "start": {"line": 1, "character": 4},
                        "end": {"line": 1, "character": 7},
                    },
                },
            ],
        })
    elif method == "shutdown":
        write_message({"jsonrpc": "2.0", "id": message["id"], "result": None})
    elif method == "exit":
        break
"#,
        )
        .expect("mock server should be written");
        script_path
    }

    // Polls until at least one diagnostic arrives, bounded by a 2s timeout
    // so a broken mock server fails the test instead of hanging it.
    async fn wait_for_diagnostics(manager: &LspManager) {
        tokio::time::timeout(Duration::from_secs(2), async {
            loop {
                if manager
                    .collect_workspace_diagnostics()
                    .await
                    .expect("diagnostics snapshot should load")
                    .total_diagnostics()
                    > 0
                {
                    break;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("diagnostics should arrive from mock server");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn collects_diagnostics_and_symbol_navigation_from_mock_server() {
        let Some(python) = python3_path() else {
            return;
        };

        // given
        let root = temp_dir("manager");
        fs::create_dir_all(root.join("src")).expect("workspace root should exist");
        let script_path = write_mock_server_script(&root);
        let source_path = root.join("src").join("main.rs");
        fs::write(&source_path, "fn main() {}\nlet value = 1;\n").expect("source file should exist");
        let manager = LspManager::new(vec![LspServerConfig {
            name: "rust-analyzer".to_string(),
            command: python,
            args: vec![script_path.display().to_string()],
            env: BTreeMap::new(),
            workspace_root: root.clone(),
            initialization_options: None,
            extension_to_language: BTreeMap::from([(".rs".to_string(), "rust".to_string())]),
        }])
        .expect("manager should build");
        manager
            .open_document(&source_path, &fs::read_to_string(&source_path).expect("source read should succeed"))
            .await
            .expect("document should open");
        wait_for_diagnostics(&manager).await;

        // when
        let diagnostics = manager
            .collect_workspace_diagnostics()
            .await
            .expect("diagnostics should be available");
        let definitions = manager
            .go_to_definition(&source_path, Position::new(0, 0))
            .await
            .expect("definition request should succeed");
        let references = manager
            .find_references(&source_path, Position::new(0, 0), true)
            .await
            .expect("references request should succeed");

        // then: one file, one error diagnostic; definition maps to the
        // mock's 0-based line 0 (rendered 1-based as 1); two references.
        assert_eq!(diagnostics.files.len(), 1);
        assert_eq!(diagnostics.total_diagnostics(), 1);
        assert_eq!(diagnostics.files[0].diagnostics[0].severity, Some(DiagnosticSeverity::ERROR));
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].start_line(), 1);
        assert_eq!(references.len(), 2);

        manager.shutdown().await.expect("shutdown should succeed");
        fs::remove_dir_all(root).expect("temp workspace should be removed");
    }

    #[tokio::test(flavor = "current_thread")]
    async fn renders_runtime_context_enrichment_for_prompt_usage() {
        let Some(python) = python3_path() else {
            return;
        };

        // given
        let root = temp_dir("prompt");
        fs::create_dir_all(root.join("src")).expect("workspace root should exist");
        let script_path = write_mock_server_script(&root);
        let source_path = root.join("src").join("lib.rs");
        fs::write(&source_path, "pub fn answer() -> i32 { 42 }\n").expect("source file should exist");
        let manager = LspManager::new(vec![LspServerConfig {
            name: "rust-analyzer".to_string(),
            command: python,
            args: vec![script_path.display().to_string()],
            env: BTreeMap::new(),
            workspace_root: root.clone(),
            initialization_options: None,
            extension_to_language: BTreeMap::from([(".rs".to_string(), "rust".to_string())]),
        }])
        .expect("manager should build");
        manager
            .open_document(&source_path, &fs::read_to_string(&source_path).expect("source read should succeed"))
            .await
            .expect("document should open");
        wait_for_diagnostics(&manager).await;

        // when
        let enrichment = manager
            .context_enrichment(&source_path, Position::new(0, 0))
            .await
            .expect("context enrichment should succeed");
        let rendered = enrichment.render_prompt_section();

        // then: rendered section contains the header, summary line, both
        // navigation sections, and the mock diagnostic message.
        assert!(rendered.contains("# LSP context"));
        assert!(rendered.contains("Workspace diagnostics: 1 across 1 file(s)"));
        assert!(rendered.contains("Definitions:"));
        assert!(rendered.contains("References:"));
        assert!(rendered.contains("mock error"));

        manager.shutdown().await.expect("shutdown should succeed");
        fs::remove_dir_all(root).expect("temp 
workspace should be removed");
    }
}
diff --git a/crates/lsp/src/manager.rs b/crates/lsp/src/manager.rs
new file mode 100644
index 0000000..3c99f96
--- /dev/null
+++ b/crates/lsp/src/manager.rs
@@ -0,0 +1,191 @@
use std::collections::{BTreeMap, BTreeSet};
use std::path::Path;
use std::sync::Arc;

use lsp_types::Position;
use tokio::sync::Mutex;

use crate::client::LspClient;
use crate::error::LspError;
use crate::types::{
    normalize_extension, FileDiagnostics, LspContextEnrichment, LspServerConfig, SymbolLocation,
    WorkspaceDiagnostics,
};

/// Routes document operations to the right language server based on file
/// extension, spawning one client per configured server lazily on first use.
pub struct LspManager {
    // Server configurations keyed by server name.
    server_configs: BTreeMap<String, LspServerConfig>,
    // Normalized extension (".rs") -> owning server name.
    extension_map: BTreeMap<String, String>,
    // Live clients, created on demand by `client_for_path`.
    clients: Mutex<BTreeMap<String, Arc<LspClient>>>,
}

impl LspManager {
    /// Builds the manager, validating that no two servers claim the same
    /// file extension.
    ///
    /// # Errors
    /// Returns `LspError::DuplicateExtension` on a conflicting mapping.
    pub fn new(server_configs: Vec<LspServerConfig>) -> Result<Self, LspError> {
        let mut configs_by_name = BTreeMap::new();
        let mut extension_map = BTreeMap::new();

        for config in server_configs {
            for extension in config.extension_to_language.keys() {
                let normalized = normalize_extension(extension);
                // `insert` returning Some means another server already owned
                // this extension — reject the whole configuration.
                if let Some(existing_server) = extension_map.insert(normalized.clone(), config.name.clone()) {
                    return Err(LspError::DuplicateExtension {
                        extension: normalized,
                        existing_server,
                        new_server: config.name.clone(),
                    });
                }
            }
            configs_by_name.insert(config.name.clone(), config);
        }

        Ok(Self {
            server_configs: configs_by_name,
            extension_map,
            clients: Mutex::new(BTreeMap::new()),
        })
    }

    /// True when some configured server handles this path's extension.
    #[must_use]
    pub fn supports_path(&self, path: &Path) -> bool {
        path.extension().is_some_and(|extension| {
            let normalized = normalize_extension(extension.to_string_lossy().as_ref());
            self.extension_map.contains_key(&normalized)
        })
    }

    /// Sends `textDocument/didOpen` for `path` with the given contents.
    pub async fn open_document(&self, path: &Path, text: &str) -> Result<(), LspError> {
        self.client_for_path(path).await?.open_document(path, text).await
    }

    /// Re-reads `path` from disk and notifies the server of the change and
    /// save, keeping the server's view in sync with the filesystem.
    pub async fn sync_document_from_disk(&self, path: &Path) -> Result<(), LspError> {
        let contents = std::fs::read_to_string(path)?;
        self.change_document(path, &contents).await?;
        self.save_document(path).await
    }

    /// Sends `textDocument/didChange` with the full new text.
    pub async fn change_document(&self, path: &Path, text: &str) -> Result<(), LspError> {
        self.client_for_path(path).await?.change_document(path, text).await
    }

    /// Sends `textDocument/didSave`.
    pub async fn save_document(&self, path: &Path) -> Result<(), LspError> {
        self.client_for_path(path).await?.save_document(path).await
    }

    /// Sends `textDocument/didClose`.
    pub async fn close_document(&self, path: &Path) -> Result<(), LspError> {
        self.client_for_path(path).await?.close_document(path).await
    }

    /// Resolves definition locations for the symbol at `position`,
    /// deduplicated by (path, range).
    pub async fn go_to_definition(
        &self,
        path: &Path,
        position: Position,
    ) -> Result<Vec<SymbolLocation>, LspError> {
        let mut locations = self.client_for_path(path).await?.go_to_definition(path, position).await?;
        dedupe_locations(&mut locations);
        Ok(locations)
    }

    /// Resolves reference locations for the symbol at `position`,
    /// deduplicated by (path, range).
    pub async fn find_references(
        &self,
        path: &Path,
        position: Position,
        include_declaration: bool,
    ) -> Result<Vec<SymbolLocation>, LspError> {
        let mut locations = self
            .client_for_path(path)
            .await?
            .find_references(path, position, include_declaration)
            .await?;
        dedupe_locations(&mut locations);
        Ok(locations)
    }

    /// Snapshots the current diagnostics across every live client, sorted by
    /// file path. Files with zero diagnostics and URIs that are not local
    /// file paths are silently skipped.
    pub async fn collect_workspace_diagnostics(&self) -> Result<WorkspaceDiagnostics, LspError> {
        // Clone the Arc handles out so the clients lock is not held across
        // the per-client awaits below.
        let clients = self.clients.lock().await.values().cloned().collect::<Vec<_>>();
        let mut files = Vec::new();

        for client in clients {
            for (uri, diagnostics) in client.diagnostics_snapshot().await {
                // to_file_path's Err type is (), so it is mapped into a dummy
                // ParseError to flow through the same Result.
                let Ok(path) = url::Url::parse(&uri)
                    .and_then(|url| url.to_file_path().map_err(|()| url::ParseError::RelativeUrlWithoutBase))
                else {
                    continue;
                };
                if diagnostics.is_empty() {
                    continue;
                }
                files.push(FileDiagnostics {
                    path,
                    uri,
                    diagnostics,
                });
            }
        }

        files.sort_by(|left, right| left.path.cmp(&right.path));
        Ok(WorkspaceDiagnostics { files })
    }

    /// Bundles diagnostics plus definition/reference lookups for one
    /// position into a single enrichment payload for prompt rendering.
    pub async fn context_enrichment(
        &self,
        path: &Path,
        position: Position,
    ) -> Result<LspContextEnrichment, LspError> {
        Ok(LspContextEnrichment {
            file_path: path.to_path_buf(),
            diagnostics: self.collect_workspace_diagnostics().await?,
            definitions: self.go_to_definition(path, position).await?,
            references: self.find_references(path, position, true).await?,
        })
    }

    /// Shuts down every live client. The client map is drained and the lock
    /// released before awaiting shutdowns, so new lookups are not blocked.
    /// Stops at the first shutdown error (remaining clients are dropped).
    pub async fn shutdown(&self) -> Result<(), LspError> {
        let mut clients = self.clients.lock().await;
        let drained = clients.values().cloned().collect::<Vec<_>>();
        clients.clear();
        drop(clients);

        for client in drained {
            client.shutdown().await?;
        }
        Ok(())
    }

    /// Returns the (possibly newly spawned) client responsible for `path`.
    ///
    /// # Errors
    /// `UnsupportedDocument` when no server claims the extension;
    /// `UnknownServer` when the mapping points at a missing config.
    async fn client_for_path(&self, path: &Path) -> Result<Arc<LspClient>, LspError> {
        let extension = path
            .extension()
            .map(|extension| normalize_extension(extension.to_string_lossy().as_ref()))
            .ok_or_else(|| LspError::UnsupportedDocument(path.to_path_buf()))?;
        let server_name = self
            .extension_map
            .get(&extension)
            .cloned()
            .ok_or_else(|| LspError::UnsupportedDocument(path.to_path_buf()))?;

        // Fast path: reuse an already-connected client.
        let mut clients = self.clients.lock().await;
        if let Some(client) = clients.get(&server_name) {
            return Ok(client.clone());
        }

        // Slow path: spawn and cache. NOTE(review): the lock is held across
        // `LspClient::connect`, serializing first-connects — confirm that is
        // acceptable for startup latency.
        let config = self
            .server_configs
            .get(&server_name)
            .cloned()
            .ok_or_else(|| LspError::UnknownServer(server_name.clone()))?;
        let client = Arc::new(LspClient::connect(config).await?);
        clients.insert(server_name, client.clone());
        Ok(client)
    }
}

/// Removes duplicate locations in place, keeping first occurrence; identity
/// is (path, start line/char, end line/char).
fn dedupe_locations(locations: &mut Vec<SymbolLocation>) {
    let mut seen = BTreeSet::new();
    locations.retain(|location| {
        seen.insert((
            location.path.clone(),
            location.range.start.line,
            location.range.start.character,
            location.range.end.line,
            location.range.end.character,
        ))
    });
}
diff --git a/crates/lsp/src/types.rs b/crates/lsp/src/types.rs
new file mode 100644
index 0000000..ab2573f
--- /dev/null
+++ b/crates/lsp/src/types.rs
@@ -0,0 +1,186 @@
use std::collections::BTreeMap;
use std::fmt::{Display, Formatter};
use std::path::{Path, PathBuf};

use lsp_types::{Diagnostic, Range};
use serde_json::Value;

/// Static configuration for one language server process.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LspServerConfig {
    pub name: String,
    pub command: String,
    pub args: Vec<String>,
    pub env: BTreeMap<String, String>,
    pub workspace_root: PathBuf,
    pub initialization_options: Option<Value>,
    // Maps normalized extensions (".rs") to LSP language ids ("rust").
    pub extension_to_language: BTreeMap<String, String>,
}

impl LspServerConfig {
    /// LSP language id for `path`, or `None` when this server does not
    /// handle the path's extension (or the path has no extension).
    #[must_use]
    pub fn language_id_for(&self, path: &Path) -> Option<&str> {
        let extension = normalize_extension(path.extension()?.to_string_lossy().as_ref());
        self.extension_to_language
            .get(&extension)
            .map(String::as_str)
    }
}

/// Diagnostics for a single file, keyed both by path and original URI.
#[derive(Debug, Clone, PartialEq)]
pub struct FileDiagnostics {
    pub path: PathBuf,
    pub uri: String,
    pub diagnostics: Vec<Diagnostic>,
}

/// Snapshot of diagnostics across every open file in the workspace.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct WorkspaceDiagnostics {
    pub files: Vec<FileDiagnostics>,
}

impl WorkspaceDiagnostics {
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.files.is_empty()
    }

    /// Total diagnostic count summed over all files.
    #[must_use]
    pub fn total_diagnostics(&self) -> usize {
        self.files.iter().map(|file| 
file.diagnostics.len()).sum()
    }
}

/// A resolved symbol position expressed as a filesystem path plus LSP range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SymbolLocation {
    pub path: PathBuf,
    pub range: Range,
}

impl SymbolLocation {
    /// 1-based line of the range start (LSP positions are 0-based).
    #[must_use]
    pub fn start_line(&self) -> u32 {
        self.range.start.line + 1
    }

    /// 1-based column of the range start.
    #[must_use]
    pub fn start_character(&self) -> u32 {
        self.range.start.character + 1
    }
}

impl Display for SymbolLocation {
    /// Renders as `path:line:col` with 1-based coordinates.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}:{}:{}",
            self.path.display(),
            self.start_line(),
            self.start_character()
        )
    }
}

/// Aggregated LSP context (diagnostics plus navigation results) for one
/// focus file, renderable as a prompt section.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LspContextEnrichment {
    pub file_path: PathBuf,
    pub diagnostics: WorkspaceDiagnostics,
    pub definitions: Vec<SymbolLocation>,
    pub references: Vec<SymbolLocation>,
}

impl LspContextEnrichment {
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.diagnostics.is_empty() && self.definitions.is_empty() && self.references.is_empty()
    }

    /// Renders a human-readable `# LSP context` section. Long diagnostic and
    /// location lists are truncated, with an explicit "omitted" marker when
    /// anything was cut.
    #[must_use]
    pub fn render_prompt_section(&self) -> String {
        const MAX_RENDERED_DIAGNOSTICS: usize = 12;
        const MAX_RENDERED_LOCATIONS: usize = 12;

        let mut lines = vec!["# LSP context".to_string()];
        lines.push(format!(" - Focus file: {}", self.file_path.display()));
        lines.push(format!(
            " - Workspace diagnostics: {} across {} file(s)",
            self.diagnostics.total_diagnostics(),
            self.diagnostics.files.len()
        ));

        if !self.diagnostics.files.is_empty() {
            lines.push(String::new());
            lines.push("Diagnostics:".to_string());
            // Flatten (file, diagnostic) pairs so the truncation cap applies
            // uniformly across file boundaries. The previous per-file nested
            // loop dropped the "omitted" marker when the cap was reached
            // exactly at a file boundary while later files still held
            // diagnostics.
            let total = self.diagnostics.total_diagnostics();
            let rendered = self
                .diagnostics
                .files
                .iter()
                .flat_map(|file| file.diagnostics.iter().map(move |diagnostic| (file, diagnostic)))
                .take(MAX_RENDERED_DIAGNOSTICS);
            for (file, diagnostic) in rendered {
                let severity = diagnostic_severity_label(diagnostic.severity);
                lines.push(format!(
                    " - {}:{}:{} [{}] {}",
                    file.path.display(),
                    diagnostic.range.start.line + 1,
                    diagnostic.range.start.character + 1,
                    severity,
                    // Keep each diagnostic on one rendered line.
                    diagnostic.message.replace('\n', " ")
                ));
            }
            if total > MAX_RENDERED_DIAGNOSTICS {
                lines.push(" - Additional diagnostics omitted for brevity.".to_string());
            }
        }

        if !self.definitions.is_empty() {
            lines.push(String::new());
            lines.push("Definitions:".to_string());
            lines.extend(
                self.definitions
                    .iter()
                    .take(MAX_RENDERED_LOCATIONS)
                    .map(|location| format!(" - {location}")),
            );
            if self.definitions.len() > MAX_RENDERED_LOCATIONS {
                lines.push(" - Additional definitions omitted for brevity.".to_string());
            }
        }

        if !self.references.is_empty() {
            lines.push(String::new());
            lines.push("References:".to_string());
            lines.extend(
                self.references
                    .iter()
                    .take(MAX_RENDERED_LOCATIONS)
                    .map(|location| format!(" - {location}")),
            );
            if self.references.len() > MAX_RENDERED_LOCATIONS {
                lines.push(" - Additional references omitted for brevity.".to_string());
            }
        }

        lines.join("\n")
    }
}

/// Lowercases an extension and guarantees a leading dot (`"RS"` -> `".rs"`,
/// `".Rs"` -> `".rs"`) so map lookups are canonical.
#[must_use]
pub(crate) fn normalize_extension(extension: &str) -> String {
    if extension.starts_with('.') {
        extension.to_ascii_lowercase()
    } else {
        format!(".{}", extension.to_ascii_lowercase())
    }
}

/// Short label for an optional LSP severity; `"unknown"` when absent or
/// unrecognized.
fn diagnostic_severity_label(severity: Option<lsp_types::DiagnosticSeverity>) -> &'static str {
    match severity {
        Some(lsp_types::DiagnosticSeverity::ERROR) => "error",
        Some(lsp_types::DiagnosticSeverity::WARNING) => "warning",
        Some(lsp_types::DiagnosticSeverity::INFORMATION) => "info",
        Some(lsp_types::DiagnosticSeverity::HINT) => "hint",
        _ => "unknown",
    }
}
diff --git a/crates/plugins/Cargo.toml b/crates/plugins/Cargo.toml
new file mode 100644
index 0000000..11213b5
--- /dev/null
+++ b/crates/plugins/Cargo.toml
@@ -0,0 +1,13 @@
[package]
name = "plugins"
version.workspace = true
edition.workspace = true
license.workspace = true
publish.workspace = true

[dependencies]
serde = { version = "1", features = ["derive"] }
serde_json.workspace = true

[lints]
workspace = true
diff --git 
a/crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json b/crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json
new file mode 100644
index 0000000..81a4220
--- /dev/null
+++ b/crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json
@@ -0,0 +1,10 @@
{
  "name": "example-bundled",
  "version": "0.1.0",
  "description": "Example bundled plugin scaffold for the Rust plugin system",
  "defaultEnabled": false,
  "hooks": {
    "PreToolUse": ["./hooks/pre.sh"],
    "PostToolUse": ["./hooks/post.sh"]
  }
}
diff --git a/crates/plugins/bundled/example-bundled/hooks/post.sh b/crates/plugins/bundled/example-bundled/hooks/post.sh
new file mode 100644
index 0000000..c9eb66f
--- /dev/null
+++ b/crates/plugins/bundled/example-bundled/hooks/post.sh
@@ -0,0 +1,2 @@
#!/bin/sh
printf '%s\n' 'example bundled post hook'
diff --git a/crates/plugins/bundled/example-bundled/hooks/pre.sh b/crates/plugins/bundled/example-bundled/hooks/pre.sh
new file mode 100644
index 0000000..af6b46b
--- /dev/null
+++ b/crates/plugins/bundled/example-bundled/hooks/pre.sh
@@ -0,0 +1,2 @@
#!/bin/sh
printf '%s\n' 'example bundled pre hook'
diff --git a/crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json b/crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json
new file mode 100644
index 0000000..555f5df
--- /dev/null
+++ b/crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json
@@ -0,0 +1,10 @@
{
  "name": "sample-hooks",
  "version": "0.1.0",
  "description": "Bundled sample plugin scaffold for hook integration tests.",
  "defaultEnabled": false,
  "hooks": {
    "PreToolUse": ["./hooks/pre.sh"],
    "PostToolUse": ["./hooks/post.sh"]
  }
}
diff --git a/crates/plugins/bundled/sample-hooks/hooks/post.sh b/crates/plugins/bundled/sample-hooks/hooks/post.sh
new file mode 100644
index 0000000..c968e6d
--- /dev/null
+++ b/crates/plugins/bundled/sample-hooks/hooks/post.sh
@@ -0,0 +1,2 @@
#!/bin/sh
printf 'sample bundled post hook'
diff --git a/crates/plugins/bundled/sample-hooks/hooks/pre.sh b/crates/plugins/bundled/sample-hooks/hooks/pre.sh
new file mode 100644
index 0000000..9560881
--- /dev/null
+++ b/crates/plugins/bundled/sample-hooks/hooks/pre.sh
@@ -0,0 +1,2 @@
#!/bin/sh
printf 'sample bundled pre hook'
diff --git a/crates/plugins/src/hooks.rs b/crates/plugins/src/hooks.rs
new file mode 100644
index 0000000..fde23e8
--- /dev/null
+++ b/crates/plugins/src/hooks.rs
@@ -0,0 +1,395 @@
use std::ffi::OsStr;
use std::path::Path;
use std::process::Command;

use serde_json::json;

use crate::{PluginError, PluginHooks, PluginRegistry};

/// The two lifecycle points at which hook commands run around a tool call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HookEvent {
    PreToolUse,
    PostToolUse,
}

impl HookEvent {
    // Wire/display name of the event, matching the manifest keys.
    fn as_str(self) -> &'static str {
        match self {
            Self::PreToolUse => "PreToolUse",
            Self::PostToolUse => "PostToolUse",
        }
    }
}

/// Aggregate outcome of running all hooks for one event: whether the tool
/// call was denied, plus any messages produced by hook commands.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HookRunResult {
    denied: bool,
    messages: Vec<String>,
}

impl HookRunResult {
    /// An allowing result carrying the given messages.
    #[must_use]
    pub fn allow(messages: Vec<String>) -> Self {
        Self {
            denied: false,
            messages,
        }
    }

    /// True when some hook vetoed the tool call.
    #[must_use]
    pub fn is_denied(&self) -> bool {
        self.denied
    }

    /// Messages emitted by hooks, in execution order.
    #[must_use]
    pub fn messages(&self) -> &[String] {
        &self.messages
    }
}

/// Executes the aggregated hook commands of all enabled plugins around tool
/// invocations.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct HookRunner {
    hooks: PluginHooks,
}

impl HookRunner {
    /// Wraps an explicit hook set (mainly for tests).
    #[must_use]
    pub fn new(hooks: PluginHooks) -> Self {
        Self { hooks }
    }

    /// Builds a runner from the hooks aggregated across every enabled
    /// plugin in the registry.
    pub fn from_registry(plugin_registry: &PluginRegistry) -> Result<Self, PluginError> {
        Ok(Self::new(plugin_registry.aggregated_hooks()?))
    }

    /// Runs all PreToolUse hooks; a hook exiting with status 2 denies the
    /// tool call (see `run_commands`).
    #[must_use]
    pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult {
        self.run_commands(
            HookEvent::PreToolUse,
            &self.hooks.pre_tool_use,
            tool_name,
            tool_input,
            None,
            false,
        )
    }

    /// Runs all PostToolUse hooks with the tool's output attached.
    #[must_use]
    pub fn run_post_tool_use(
        &self,
        tool_name: &str,
        tool_input: &str,
        tool_output: &str,
        is_error: bool,
    ) -> HookRunResult {
        self.run_commands(
            HookEvent::PostToolUse,
            &self.hooks.post_tool_use,
            tool_name,
            tool_input,
            Some(tool_output),
            is_error,
        )
    }

    /// Runs each command in order. Semantics: exit 0 = allow (stdout becomes
    /// a message), exit 2 = deny and stop immediately, any other exit or
    /// spawn failure = warn but keep going.
    fn run_commands(
        &self,
        event: HookEvent,
        commands: &[String],
        tool_name: &str,
        tool_input: &str,
        tool_output: Option<&str>,
        is_error: bool,
    ) -> HookRunResult {
        if commands.is_empty() {
            return HookRunResult::allow(Vec::new());
        }

        // One JSON payload shared by every command, written to each hook's
        // stdin. `tool_input` is included both parsed and as the raw string.
        let payload = json!({
            "hook_event_name": event.as_str(),
            "tool_name": tool_name,
            "tool_input": parse_tool_input(tool_input),
            "tool_input_json": tool_input,
            "tool_output": tool_output,
            "tool_result_is_error": is_error,
        })
        .to_string();

        let mut messages = Vec::new();

        for command in commands {
            match self.run_command(
                command,
                event,
                tool_name,
                tool_input,
                tool_output,
                is_error,
                &payload,
            ) {
                HookCommandOutcome::Allow { message } => {
                    if let Some(message) = message {
                        messages.push(message);
                    }
                }
                HookCommandOutcome::Deny { message } => {
                    // A deny short-circuits: later hooks do not run.
                    messages.push(message.unwrap_or_else(|| {
                        format!("{} hook denied tool `{tool_name}`", event.as_str())
                    }));
                    return HookRunResult {
                        denied: true,
                        messages,
                    };
                }
                HookCommandOutcome::Warn { message } => messages.push(message),
            }
        }

        HookRunResult::allow(messages)
    }

    /// Executes one hook command with the event context in its environment
    /// and the JSON payload on stdin, mapping its exit status to an outcome.
    #[allow(clippy::too_many_arguments, clippy::unused_self)]
    fn run_command(
        &self,
        command: &str,
        event: HookEvent,
        tool_name: &str,
        tool_input: &str,
        tool_output: Option<&str>,
        is_error: bool,
        payload: &str,
    ) -> HookCommandOutcome {
        let mut child = shell_command(command);
        child.stdin(std::process::Stdio::piped());
        child.stdout(std::process::Stdio::piped());
        child.stderr(std::process::Stdio::piped());
        child.env("HOOK_EVENT", event.as_str());
        child.env("HOOK_TOOL_NAME", tool_name);
        child.env("HOOK_TOOL_INPUT", tool_input);
        child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" });
        if let Some(tool_output) = tool_output {
            child.env("HOOK_TOOL_OUTPUT", tool_output);
        }

        match child.output_with_stdin(payload.as_bytes()) {
            Ok(output) => {
                let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string();
                let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
                let message = (!stdout.is_empty()).then_some(stdout);
                match output.status.code() {
                    Some(0) => HookCommandOutcome::Allow { message },
                    // Exit code 2 is the documented deny convention.
                    Some(2) => HookCommandOutcome::Deny { message },
                    Some(code) => HookCommandOutcome::Warn {
                        message: format_hook_warning(
                            command,
                            code,
                            message.as_deref(),
                            stderr.as_str(),
                        ),
                    },
                    // No exit code: killed by a signal (Unix).
                    None => HookCommandOutcome::Warn {
                        message: format!(
                            "{} hook `{command}` terminated by signal while handling `{tool_name}`",
                            event.as_str()
                        ),
                    },
                }
            }
            Err(error) => HookCommandOutcome::Warn {
                message: format!(
                    "{} hook `{command}` failed to start for `{tool_name}`: {error}",
                    event.as_str()
                ),
            },
        }
    }
}

// Per-command verdict produced by `run_command`.
enum HookCommandOutcome {
    Allow { message: Option<String> },
    Deny { message: Option<String> },
    Warn { message: String },
}

/// Parses the tool input as JSON; non-JSON input is wrapped as
/// `{"raw": "<input>"}` so hooks always receive an object.
fn parse_tool_input(tool_input: &str) -> serde_json::Value {
    serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input }))
}

/// Builds the warning line for a non-zero, non-deny exit, preferring stdout
/// over stderr for the detail suffix.
fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String {
    let mut message =
        format!("Hook `{command}` exited with status {code}; allowing tool execution to continue");
    if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) {
        message.push_str(": ");
        message.push_str(stdout);
    } else if !stderr.is_empty() {
        message.push_str(": ");
        message.push_str(stderr);
    }
    message
}

/// Builds the platform-appropriate shell invocation for a hook command:
/// `cmd /C` on Windows; on Unix, `sh <path>` when the command names an
/// existing file (so scripts need not be executable), otherwise `sh -lc`.
fn shell_command(command: &str) -> CommandWithStdin {
    #[cfg(windows)]
    let command_builder = {
        let mut command_builder = Command::new("cmd");
        command_builder.arg("/C").arg(command);
        CommandWithStdin::new(command_builder)
    };

    #[cfg(not(windows))]
    let command_builder = if Path::new(command).exists() {
        let mut command_builder = Command::new("sh");
        command_builder.arg(command);
        CommandWithStdin::new(command_builder)
    } else {
        let mut command_builder = Command::new("sh");
        command_builder.arg("-lc").arg(command);
        CommandWithStdin::new(command_builder)
    };

    command_builder
}

// Thin wrapper over `std::process::Command` that supports feeding stdin
// before collecting the full output.
struct CommandWithStdin {
    command: Command,
}

impl CommandWithStdin {
    fn new(command: Command) -> Self {
        Self { command }
    }

    fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdin(cfg);
        self
    }

    fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stdout(cfg);
        self
    }

    fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self {
        self.command.stderr(cfg);
        self
    }

    fn env<K, V>(&mut self, key: K, value: V) -> &mut Self
    where
        K: AsRef<OsStr>,
        V: AsRef<OsStr>,
    {
        self.command.env(key, value);
        self
    }

    // Spawns the child, writes `stdin` fully, closes the pipe (drop), then
    // waits for completion and collects output.
    // NOTE(review): writing all of stdin before reading output can deadlock
    // if a hook fills its stdout pipe before consuming stdin — presumably
    // payloads are small; confirm or move the write to a thread.
    fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> {
        let mut child = self.command.spawn()?;
        if let Some(mut child_stdin) = child.stdin.take() {
            use std::io::Write as _;
            child_stdin.write_all(stdin)?;
        }
        child.wait_with_output()
    }
}

#[cfg(test)]
mod tests {
    use super::{HookRunResult, HookRunner};
    use crate::{PluginManager, PluginManagerConfig};
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Unique per-test scratch directory.
    fn temp_dir(label: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("time should be after epoch")
            .as_nanos();
        std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{nanos}"))
    }

    // Writes a minimal plugin (manifest + pre/post shell hooks) under `root`.
    fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) {
        fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir");
        fs::create_dir_all(root.join("hooks")).expect("hooks dir");
        fs::write(
            root.join("hooks").join("pre.sh"),
            format!("#!/bin/sh\nprintf '%s\\n' '{pre_message}'\n"),
        )
        .expect("write pre hook");
        fs::write(
            root.join("hooks").join("post.sh"),
            format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"),
        )
        .expect("write post hook");
        fs::write(
            root.join(".claw-plugin").join("plugin.json"),
            format!(
                "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}"
            ),
        )
        .expect("write plugin manifest");
    }

    #[test]
    fn collects_and_runs_hooks_from_enabled_plugins() {
        // given: two installed plugins, each with pre+post hooks
        let config_home = temp_dir("config");
        let first_source_root = temp_dir("source-a");
        let second_source_root = temp_dir("source-b");
        write_hook_plugin(
            &first_source_root,
            "first",
            "plugin pre one",
            "plugin post one",
        );
        write_hook_plugin(
            &second_source_root,
            "second",
            "plugin pre two",
            "plugin post two",
        );

        let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home));
        manager
            .install(first_source_root.to_str().expect("utf8 path"))
            .expect("first plugin install should succeed");
        manager
            .install(second_source_root.to_str().expect("utf8 path"))
            .expect("second plugin install should succeed");
        let registry = manager.plugin_registry().expect("registry should build");

        let runner = HookRunner::from_registry(&registry).expect("plugin hooks should load");

        // then: hooks run in plugin order and their stdout becomes messages
        assert_eq!(
            runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#),
            HookRunResult::allow(vec![
                "plugin pre one".to_string(),
                "plugin pre two".to_string(),
            ])
        );
        assert_eq!(
            runner.run_post_tool_use("Read", r#"{"path":"README.md"}"#, "ok", false),
            HookRunResult::allow(vec![
                "plugin post one".to_string(),
                "plugin post two".to_string(),
            ])
        );

        // Best-effort cleanup; failures are ignored on purpose.
        let _ = fs::remove_dir_all(config_home);
        let _ = fs::remove_dir_all(first_source_root);
        let _ = fs::remove_dir_all(second_source_root);
    }

    #[test]
    fn pre_tool_use_denies_when_plugin_hook_exits_two() {
        // given: a hook that prints and exits 2 (the deny convention)
        let runner = HookRunner::new(crate::PluginHooks {
            pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()],
            post_tool_use: Vec::new(),
        });

        let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#);

        assert!(result.is_denied());
        assert_eq!(result.messages(), &["blocked by plugin".to_string()]);
    }
}
diff --git a/crates/plugins/src/lib.rs b/crates/plugins/src/lib.rs
new file mode 100644
index 0000000..6105ad9
--- /dev/null
+++ b/crates/plugins/src/lib.rs
@@ -0,0 +1,2943 @@
mod hooks;

use std::collections::{BTreeMap, BTreeSet};
use std::fmt::{Display, Formatter};
use std::fs;
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::time::{SystemTime, UNIX_EPOCH};

use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};

pub use hooks::{HookEvent, HookRunResult, HookRunner};

// Marketplace identifiers, one per plugin origin.
const EXTERNAL_MARKETPLACE: &str = "external";
const BUILTIN_MARKETPLACE: &str = "builtin";
const BUNDLED_MARKETPLACE: &str = "bundled";
// Well-known file names used by the plugin store and manifests.
const SETTINGS_FILE_NAME: &str = "settings.json";
const REGISTRY_FILE_NAME: &str = "installed.json";
const MANIFEST_FILE_NAME: &str = "plugin.json";
const MANIFEST_RELATIVE_PATH: &str = ".claw-plugin/plugin.json";

/// Where a plugin comes from: compiled in, shipped with the binary's data,
/// or installed by the user.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PluginKind {
    Builtin,
    Bundled,
    External,
}

impl Display for PluginKind {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Builtin => write!(f, "builtin"),
            Self::Bundled => write!(f, "bundled"),
            Self::External => write!(f, "external"),
        }
    }
}

impl PluginKind {
    // Marketplace bucket that stores plugins of this kind.
    #[must_use]
    fn marketplace(self) -> &'static str {
        match self {
            Self::Builtin => BUILTIN_MARKETPLACE,
            Self::Bundled => BUNDLED_MARKETPLACE,
            Self::External => EXTERNAL_MARKETPLACE,
        }
    }
}

/// Resolved identity and location of one installed plugin.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PluginMetadata {
    pub id: String,
    pub name: String,
    pub version: String,
    pub description: String,
    pub kind: PluginKind,
    pub source: 
String, + pub default_enabled: bool, + pub root: Option<PathBuf>, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PluginHooks { + #[serde(rename = "PreToolUse", default)] + pub pre_tool_use: Vec<String>, + #[serde(rename = "PostToolUse", default)] + pub post_tool_use: Vec<String>, +} + +impl PluginHooks { + #[must_use] + pub fn is_empty(&self) -> bool { + self.pre_tool_use.is_empty() && self.post_tool_use.is_empty() + } + + #[must_use] + pub fn merged_with(&self, other: &Self) -> Self { + let mut merged = self.clone(); + merged + .pre_tool_use + .extend(other.pre_tool_use.iter().cloned()); + merged + .post_tool_use + .extend(other.post_tool_use.iter().cloned()); + merged + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PluginLifecycle { + #[serde(rename = "Init", default)] + pub init: Vec<String>, + #[serde(rename = "Shutdown", default)] + pub shutdown: Vec<String>, +} + +impl PluginLifecycle { + #[must_use] + pub fn is_empty(&self) -> bool { + self.init.is_empty() && self.shutdown.is_empty() + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PluginManifest { + pub name: String, + pub version: String, + pub description: String, + pub permissions: Vec<PluginPermission>, + #[serde(rename = "defaultEnabled", default)] + pub default_enabled: bool, + #[serde(default)] + pub hooks: PluginHooks, + #[serde(default)] + pub lifecycle: PluginLifecycle, + #[serde(default)] + pub tools: Vec<PluginToolManifest>, + #[serde(default)] + pub commands: Vec<PluginCommandManifest>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum PluginPermission { + Read, + Write, + Execute, +} + +impl PluginPermission { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::Read => "read", + Self::Write => "write", + Self::Execute => "execute", + } + } + + fn parse(value: 
&str) -> Option<Self> { + match value { + "read" => Some(Self::Read), + "write" => Some(Self::Write), + "execute" => Some(Self::Execute), + _ => None, + } + } +} + +impl AsRef<str> for PluginPermission { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PluginToolManifest { + pub name: String, + pub description: String, + #[serde(rename = "inputSchema")] + pub input_schema: Value, + pub command: String, + #[serde(default)] + pub args: Vec<String>, + pub required_permission: PluginToolPermission, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum PluginToolPermission { + ReadOnly, + WorkspaceWrite, + DangerFullAccess, +} + +impl PluginToolPermission { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::ReadOnly => "read-only", + Self::WorkspaceWrite => "workspace-write", + Self::DangerFullAccess => "danger-full-access", + } + } + + fn parse(value: &str) -> Option<Self> { + match value { + "read-only" => Some(Self::ReadOnly), + "workspace-write" => Some(Self::WorkspaceWrite), + "danger-full-access" => Some(Self::DangerFullAccess), + _ => None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PluginToolDefinition { + pub name: String, + #[serde(default)] + pub description: Option<String>, + #[serde(rename = "inputSchema")] + pub input_schema: Value, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PluginCommandManifest { + pub name: String, + pub description: String, + pub command: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct RawPluginManifest { + pub name: String, + pub version: String, + pub description: String, + #[serde(default)] + pub permissions: Vec<String>, + #[serde(rename = "defaultEnabled", default)] + pub default_enabled: bool, + #[serde(default)] + pub hooks: 
PluginHooks, + #[serde(default)] + pub lifecycle: PluginLifecycle, + #[serde(default)] + pub tools: Vec<RawPluginToolManifest>, + #[serde(default)] + pub commands: Vec<PluginCommandManifest>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct RawPluginToolManifest { + pub name: String, + pub description: String, + #[serde(rename = "inputSchema")] + pub input_schema: Value, + pub command: String, + #[serde(default)] + pub args: Vec<String>, + #[serde( + rename = "requiredPermission", + default = "default_tool_permission_label" + )] + pub required_permission: String, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PluginTool { + plugin_id: String, + plugin_name: String, + definition: PluginToolDefinition, + command: String, + args: Vec<String>, + required_permission: PluginToolPermission, + root: Option<PathBuf>, +} + +impl PluginTool { + #[must_use] + pub fn new( + plugin_id: impl Into<String>, + plugin_name: impl Into<String>, + definition: PluginToolDefinition, + command: impl Into<String>, + args: Vec<String>, + required_permission: PluginToolPermission, + root: Option<PathBuf>, + ) -> Self { + Self { + plugin_id: plugin_id.into(), + plugin_name: plugin_name.into(), + definition, + command: command.into(), + args, + required_permission, + root, + } + } + + #[must_use] + pub fn plugin_id(&self) -> &str { + &self.plugin_id + } + + #[must_use] + pub fn definition(&self) -> &PluginToolDefinition { + &self.definition + } + + #[must_use] + pub fn required_permission(&self) -> &str { + self.required_permission.as_str() + } + + pub fn execute(&self, input: &Value) -> Result<String, PluginError> { + let input_json = input.to_string(); + let mut process = Command::new(&self.command); + process + .args(&self.args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .env("CLAW_PLUGIN_ID", &self.plugin_id) + .env("CLAW_PLUGIN_NAME", &self.plugin_name) + .env("CLAW_TOOL_NAME", &self.definition.name) + 
.env("CLAW_TOOL_INPUT", &input_json); + if let Some(root) = &self.root { + process + .current_dir(root) + .env("CLAW_PLUGIN_ROOT", root.display().to_string()); + } + + let mut child = process.spawn()?; + if let Some(stdin) = child.stdin.as_mut() { + use std::io::Write as _; + stdin.write_all(input_json.as_bytes())?; + } + + let output = child.wait_with_output()?; + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).trim().to_string()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + Err(PluginError::CommandFailed(format!( + "plugin tool `{}` from `{}` failed for `{}`: {}", + self.definition.name, + self.plugin_id, + self.command, + if stderr.is_empty() { + format!("exit status {}", output.status) + } else { + stderr + } + ))) + } + } +} + +fn default_tool_permission_label() -> String { + "danger-full-access".to_string() +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum PluginInstallSource { + LocalPath { path: PathBuf }, + GitUrl { url: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct InstalledPluginRecord { + #[serde(default = "default_plugin_kind")] + pub kind: PluginKind, + pub id: String, + pub name: String, + pub version: String, + pub description: String, + pub install_path: PathBuf, + pub source: PluginInstallSource, + pub installed_at_unix_ms: u128, + pub updated_at_unix_ms: u128, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct InstalledPluginRegistry { + #[serde(default)] + pub plugins: BTreeMap<String, InstalledPluginRecord>, +} + +fn default_plugin_kind() -> PluginKind { + PluginKind::External +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BuiltinPlugin { + metadata: PluginMetadata, + hooks: PluginHooks, + lifecycle: PluginLifecycle, + tools: Vec<PluginTool>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BundledPlugin { 
+ metadata: PluginMetadata, + hooks: PluginHooks, + lifecycle: PluginLifecycle, + tools: Vec<PluginTool>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ExternalPlugin { + metadata: PluginMetadata, + hooks: PluginHooks, + lifecycle: PluginLifecycle, + tools: Vec<PluginTool>, +} + +pub trait Plugin { + fn metadata(&self) -> &PluginMetadata; + fn hooks(&self) -> &PluginHooks; + fn lifecycle(&self) -> &PluginLifecycle; + fn tools(&self) -> &[PluginTool]; + fn validate(&self) -> Result<(), PluginError>; + fn initialize(&self) -> Result<(), PluginError>; + fn shutdown(&self) -> Result<(), PluginError>; +} + +#[derive(Debug, Clone, PartialEq)] +pub enum PluginDefinition { + Builtin(BuiltinPlugin), + Bundled(BundledPlugin), + External(ExternalPlugin), +} + +impl Plugin for BuiltinPlugin { + fn metadata(&self) -> &PluginMetadata { + &self.metadata + } + + fn hooks(&self) -> &PluginHooks { + &self.hooks + } + + fn lifecycle(&self) -> &PluginLifecycle { + &self.lifecycle + } + + fn tools(&self) -> &[PluginTool] { + &self.tools + } + + fn validate(&self) -> Result<(), PluginError> { + Ok(()) + } + + fn initialize(&self) -> Result<(), PluginError> { + Ok(()) + } + + fn shutdown(&self) -> Result<(), PluginError> { + Ok(()) + } +} + +impl Plugin for BundledPlugin { + fn metadata(&self) -> &PluginMetadata { + &self.metadata + } + + fn hooks(&self) -> &PluginHooks { + &self.hooks + } + + fn lifecycle(&self) -> &PluginLifecycle { + &self.lifecycle + } + + fn tools(&self) -> &[PluginTool] { + &self.tools + } + + fn validate(&self) -> Result<(), PluginError> { + validate_hook_paths(self.metadata.root.as_deref(), &self.hooks)?; + validate_lifecycle_paths(self.metadata.root.as_deref(), &self.lifecycle)?; + validate_tool_paths(self.metadata.root.as_deref(), &self.tools) + } + + fn initialize(&self) -> Result<(), PluginError> { + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "init", + &self.lifecycle.init, + ) + } + + fn shutdown(&self) -> Result<(), PluginError> 
{ + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "shutdown", + &self.lifecycle.shutdown, + ) + } +} + +impl Plugin for ExternalPlugin { + fn metadata(&self) -> &PluginMetadata { + &self.metadata + } + + fn hooks(&self) -> &PluginHooks { + &self.hooks + } + + fn lifecycle(&self) -> &PluginLifecycle { + &self.lifecycle + } + + fn tools(&self) -> &[PluginTool] { + &self.tools + } + + fn validate(&self) -> Result<(), PluginError> { + validate_hook_paths(self.metadata.root.as_deref(), &self.hooks)?; + validate_lifecycle_paths(self.metadata.root.as_deref(), &self.lifecycle)?; + validate_tool_paths(self.metadata.root.as_deref(), &self.tools) + } + + fn initialize(&self) -> Result<(), PluginError> { + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "init", + &self.lifecycle.init, + ) + } + + fn shutdown(&self) -> Result<(), PluginError> { + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "shutdown", + &self.lifecycle.shutdown, + ) + } +} + +impl Plugin for PluginDefinition { + fn metadata(&self) -> &PluginMetadata { + match self { + Self::Builtin(plugin) => plugin.metadata(), + Self::Bundled(plugin) => plugin.metadata(), + Self::External(plugin) => plugin.metadata(), + } + } + + fn hooks(&self) -> &PluginHooks { + match self { + Self::Builtin(plugin) => plugin.hooks(), + Self::Bundled(plugin) => plugin.hooks(), + Self::External(plugin) => plugin.hooks(), + } + } + + fn lifecycle(&self) -> &PluginLifecycle { + match self { + Self::Builtin(plugin) => plugin.lifecycle(), + Self::Bundled(plugin) => plugin.lifecycle(), + Self::External(plugin) => plugin.lifecycle(), + } + } + + fn tools(&self) -> &[PluginTool] { + match self { + Self::Builtin(plugin) => plugin.tools(), + Self::Bundled(plugin) => plugin.tools(), + Self::External(plugin) => plugin.tools(), + } + } + + fn validate(&self) -> Result<(), PluginError> { + match self { + Self::Builtin(plugin) => plugin.validate(), + Self::Bundled(plugin) => plugin.validate(), + 
Self::External(plugin) => plugin.validate(), + } + } + + fn initialize(&self) -> Result<(), PluginError> { + match self { + Self::Builtin(plugin) => plugin.initialize(), + Self::Bundled(plugin) => plugin.initialize(), + Self::External(plugin) => plugin.initialize(), + } + } + + fn shutdown(&self) -> Result<(), PluginError> { + match self { + Self::Builtin(plugin) => plugin.shutdown(), + Self::Bundled(plugin) => plugin.shutdown(), + Self::External(plugin) => plugin.shutdown(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct RegisteredPlugin { + definition: PluginDefinition, + enabled: bool, +} + +impl RegisteredPlugin { + #[must_use] + pub fn new(definition: PluginDefinition, enabled: bool) -> Self { + Self { + definition, + enabled, + } + } + + #[must_use] + pub fn metadata(&self) -> &PluginMetadata { + self.definition.metadata() + } + + #[must_use] + pub fn hooks(&self) -> &PluginHooks { + self.definition.hooks() + } + + #[must_use] + pub fn tools(&self) -> &[PluginTool] { + self.definition.tools() + } + + #[must_use] + pub fn is_enabled(&self) -> bool { + self.enabled + } + + pub fn validate(&self) -> Result<(), PluginError> { + self.definition.validate() + } + + pub fn initialize(&self) -> Result<(), PluginError> { + self.definition.initialize() + } + + pub fn shutdown(&self) -> Result<(), PluginError> { + self.definition.shutdown() + } + + #[must_use] + pub fn summary(&self) -> PluginSummary { + PluginSummary { + metadata: self.metadata().clone(), + enabled: self.enabled, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginSummary { + pub metadata: PluginMetadata, + pub enabled: bool, +} + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct PluginRegistry { + plugins: Vec<RegisteredPlugin>, +} + +impl PluginRegistry { + #[must_use] + pub fn new(mut plugins: Vec<RegisteredPlugin>) -> Self { + plugins.sort_by(|left, right| left.metadata().id.cmp(&right.metadata().id)); + Self { plugins } + } + + #[must_use] + pub fn 
plugins(&self) -> &[RegisteredPlugin] { + &self.plugins + } + + #[must_use] + pub fn get(&self, plugin_id: &str) -> Option<&RegisteredPlugin> { + self.plugins + .iter() + .find(|plugin| plugin.metadata().id == plugin_id) + } + + #[must_use] + pub fn contains(&self, plugin_id: &str) -> bool { + self.get(plugin_id).is_some() + } + + #[must_use] + pub fn summaries(&self) -> Vec<PluginSummary> { + self.plugins.iter().map(RegisteredPlugin::summary).collect() + } + + pub fn aggregated_hooks(&self) -> Result<PluginHooks, PluginError> { + self.plugins + .iter() + .filter(|plugin| plugin.is_enabled()) + .try_fold(PluginHooks::default(), |acc, plugin| { + plugin.validate()?; + Ok(acc.merged_with(plugin.hooks())) + }) + } + + pub fn aggregated_tools(&self) -> Result<Vec<PluginTool>, PluginError> { + let mut tools = Vec::new(); + let mut seen_names = BTreeMap::new(); + for plugin in self.plugins.iter().filter(|plugin| plugin.is_enabled()) { + plugin.validate()?; + for tool in plugin.tools() { + if let Some(existing_plugin) = + seen_names.insert(tool.definition().name.clone(), tool.plugin_id().to_string()) + { + return Err(PluginError::InvalidManifest(format!( + "plugin tool `{}` is defined by both `{existing_plugin}` and `{}`", + tool.definition().name, + tool.plugin_id() + ))); + } + tools.push(tool.clone()); + } + } + Ok(tools) + } + + pub fn initialize(&self) -> Result<(), PluginError> { + for plugin in self.plugins.iter().filter(|plugin| plugin.is_enabled()) { + plugin.validate()?; + plugin.initialize()?; + } + Ok(()) + } + + pub fn shutdown(&self) -> Result<(), PluginError> { + for plugin in self + .plugins + .iter() + .rev() + .filter(|plugin| plugin.is_enabled()) + { + plugin.shutdown()?; + } + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginManagerConfig { + pub config_home: PathBuf, + pub enabled_plugins: BTreeMap<String, bool>, + pub external_dirs: Vec<PathBuf>, + pub install_root: Option<PathBuf>, + pub registry_path: Option<PathBuf>, + pub 
bundled_root: Option<PathBuf>, +} + +impl PluginManagerConfig { + #[must_use] + pub fn new(config_home: impl Into<PathBuf>) -> Self { + Self { + config_home: config_home.into(), + enabled_plugins: BTreeMap::new(), + external_dirs: Vec::new(), + install_root: None, + registry_path: None, + bundled_root: None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginManager { + config: PluginManagerConfig, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct InstallOutcome { + pub plugin_id: String, + pub version: String, + pub install_path: PathBuf, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UpdateOutcome { + pub plugin_id: String, + pub old_version: String, + pub new_version: String, + pub install_path: PathBuf, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PluginManifestValidationError { + EmptyField { + field: &'static str, + }, + EmptyEntryField { + kind: &'static str, + field: &'static str, + name: Option<String>, + }, + InvalidPermission { + permission: String, + }, + DuplicatePermission { + permission: String, + }, + DuplicateEntry { + kind: &'static str, + name: String, + }, + MissingPath { + kind: &'static str, + path: PathBuf, + }, + InvalidToolInputSchema { + tool_name: String, + }, + InvalidToolRequiredPermission { + tool_name: String, + permission: String, + }, +} + +impl Display for PluginManifestValidationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::EmptyField { field } => { + write!(f, "plugin manifest {field} cannot be empty") + } + Self::EmptyEntryField { kind, field, name } => match name { + Some(name) if !name.is_empty() => { + write!(f, "plugin {kind} `{name}` {field} cannot be empty") + } + _ => write!(f, "plugin {kind} {field} cannot be empty"), + }, + Self::InvalidPermission { permission } => { + write!( + f, + "plugin manifest permission `{permission}` must be one of read, write, or execute" + ) + } + Self::DuplicatePermission { permission } => { + 
write!(f, "plugin manifest permission `{permission}` is duplicated") + } + Self::DuplicateEntry { kind, name } => { + write!(f, "plugin {kind} `{name}` is duplicated") + } + Self::MissingPath { kind, path } => { + write!(f, "{kind} path `{}` does not exist", path.display()) + } + Self::InvalidToolInputSchema { tool_name } => { + write!( + f, + "plugin tool `{tool_name}` inputSchema must be a JSON object" + ) + } + Self::InvalidToolRequiredPermission { + tool_name, + permission, + } => write!( + f, + "plugin tool `{tool_name}` requiredPermission `{permission}` must be read-only, workspace-write, or danger-full-access" + ), + } + } +} + +#[derive(Debug)] +pub enum PluginError { + Io(std::io::Error), + Json(serde_json::Error), + ManifestValidation(Vec<PluginManifestValidationError>), + InvalidManifest(String), + NotFound(String), + CommandFailed(String), +} + +impl Display for PluginError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Io(error) => write!(f, "{error}"), + Self::Json(error) => write!(f, "{error}"), + Self::ManifestValidation(errors) => { + for (index, error) in errors.iter().enumerate() { + if index > 0 { + write!(f, "; ")?; + } + write!(f, "{error}")?; + } + Ok(()) + } + Self::InvalidManifest(message) + | Self::NotFound(message) + | Self::CommandFailed(message) => write!(f, "{message}"), + } + } +} + +impl std::error::Error for PluginError {} + +impl From<std::io::Error> for PluginError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} + +impl From<serde_json::Error> for PluginError { + fn from(value: serde_json::Error) -> Self { + Self::Json(value) + } +} + +impl PluginManager { + #[must_use] + pub fn new(config: PluginManagerConfig) -> Self { + Self { config } + } + + #[must_use] + pub fn bundled_root() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("bundled") + } + + #[must_use] + pub fn install_root(&self) -> PathBuf { + self.config + .install_root + .clone() + 
.unwrap_or_else(|| self.config.config_home.join("plugins").join("installed")) + } + + #[must_use] + pub fn registry_path(&self) -> PathBuf { + self.config.registry_path.clone().unwrap_or_else(|| { + self.config + .config_home + .join("plugins") + .join(REGISTRY_FILE_NAME) + }) + } + + #[must_use] + pub fn settings_path(&self) -> PathBuf { + self.config.config_home.join(SETTINGS_FILE_NAME) + } + + pub fn plugin_registry(&self) -> Result<PluginRegistry, PluginError> { + Ok(PluginRegistry::new( + self.discover_plugins()? + .into_iter() + .map(|plugin| { + let enabled = self.is_enabled(plugin.metadata()); + RegisteredPlugin::new(plugin, enabled) + }) + .collect(), + )) + } + + pub fn list_plugins(&self) -> Result<Vec<PluginSummary>, PluginError> { + Ok(self.plugin_registry()?.summaries()) + } + + pub fn list_installed_plugins(&self) -> Result<Vec<PluginSummary>, PluginError> { + Ok(self.installed_plugin_registry()?.summaries()) + } + + pub fn discover_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> { + self.sync_bundled_plugins()?; + let mut plugins = builtin_plugins(); + plugins.extend(self.discover_installed_plugins()?); + plugins.extend(self.discover_external_directory_plugins(&plugins)?); + Ok(plugins) + } + + pub fn aggregated_hooks(&self) -> Result<PluginHooks, PluginError> { + self.plugin_registry()?.aggregated_hooks() + } + + pub fn aggregated_tools(&self) -> Result<Vec<PluginTool>, PluginError> { + self.plugin_registry()?.aggregated_tools() + } + + pub fn validate_plugin_source(&self, source: &str) -> Result<PluginManifest, PluginError> { + let path = resolve_local_source(source)?; + load_plugin_from_directory(&path) + } + + pub fn install(&mut self, source: &str) -> Result<InstallOutcome, PluginError> { + let install_source = parse_install_source(source)?; + let temp_root = self.install_root().join(".tmp"); + let staged_source = materialize_source(&install_source, &temp_root)?; + let cleanup_source = matches!(install_source, 
PluginInstallSource::GitUrl { .. }); + let manifest = load_plugin_from_directory(&staged_source)?; + + let plugin_id = plugin_id(&manifest.name, EXTERNAL_MARKETPLACE); + let install_path = self.install_root().join(sanitize_plugin_id(&plugin_id)); + if install_path.exists() { + fs::remove_dir_all(&install_path)?; + } + copy_dir_all(&staged_source, &install_path)?; + if cleanup_source { + let _ = fs::remove_dir_all(&staged_source); + } + + let now = unix_time_ms(); + let record = InstalledPluginRecord { + kind: PluginKind::External, + id: plugin_id.clone(), + name: manifest.name, + version: manifest.version.clone(), + description: manifest.description, + install_path: install_path.clone(), + source: install_source, + installed_at_unix_ms: now, + updated_at_unix_ms: now, + }; + + let mut registry = self.load_registry()?; + registry.plugins.insert(plugin_id.clone(), record); + self.store_registry(®istry)?; + self.write_enabled_state(&plugin_id, Some(true))?; + self.config.enabled_plugins.insert(plugin_id.clone(), true); + + Ok(InstallOutcome { + plugin_id, + version: manifest.version, + install_path, + }) + } + + pub fn enable(&mut self, plugin_id: &str) -> Result<(), PluginError> { + self.ensure_known_plugin(plugin_id)?; + self.write_enabled_state(plugin_id, Some(true))?; + self.config + .enabled_plugins + .insert(plugin_id.to_string(), true); + Ok(()) + } + + pub fn disable(&mut self, plugin_id: &str) -> Result<(), PluginError> { + self.ensure_known_plugin(plugin_id)?; + self.write_enabled_state(plugin_id, Some(false))?; + self.config + .enabled_plugins + .insert(plugin_id.to_string(), false); + Ok(()) + } + + pub fn uninstall(&mut self, plugin_id: &str) -> Result<(), PluginError> { + let mut registry = self.load_registry()?; + let record = registry.plugins.remove(plugin_id).ok_or_else(|| { + PluginError::NotFound(format!("plugin `{plugin_id}` is not installed")) + })?; + if record.kind == PluginKind::Bundled { + registry.plugins.insert(plugin_id.to_string(), 
record); + return Err(PluginError::CommandFailed(format!( + "plugin `{plugin_id}` is bundled and managed automatically; disable it instead" + ))); + } + if record.install_path.exists() { + fs::remove_dir_all(&record.install_path)?; + } + self.store_registry(®istry)?; + self.write_enabled_state(plugin_id, None)?; + self.config.enabled_plugins.remove(plugin_id); + Ok(()) + } + + pub fn update(&mut self, plugin_id: &str) -> Result<UpdateOutcome, PluginError> { + let mut registry = self.load_registry()?; + let record = registry.plugins.get(plugin_id).cloned().ok_or_else(|| { + PluginError::NotFound(format!("plugin `{plugin_id}` is not installed")) + })?; + + let temp_root = self.install_root().join(".tmp"); + let staged_source = materialize_source(&record.source, &temp_root)?; + let cleanup_source = matches!(record.source, PluginInstallSource::GitUrl { .. }); + let manifest = load_plugin_from_directory(&staged_source)?; + + if record.install_path.exists() { + fs::remove_dir_all(&record.install_path)?; + } + copy_dir_all(&staged_source, &record.install_path)?; + if cleanup_source { + let _ = fs::remove_dir_all(&staged_source); + } + + let updated_record = InstalledPluginRecord { + version: manifest.version.clone(), + description: manifest.description, + updated_at_unix_ms: unix_time_ms(), + ..record.clone() + }; + registry + .plugins + .insert(plugin_id.to_string(), updated_record); + self.store_registry(®istry)?; + + Ok(UpdateOutcome { + plugin_id: plugin_id.to_string(), + old_version: record.version, + new_version: manifest.version, + install_path: record.install_path, + }) + } + + fn discover_installed_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> { + let mut registry = self.load_registry()?; + let mut plugins = Vec::new(); + let mut seen_ids = BTreeSet::<String>::new(); + let mut seen_paths = BTreeSet::<PathBuf>::new(); + let mut stale_registry_ids = Vec::new(); + + for install_path in discover_plugin_dirs(&self.install_root())? 
{ + let matched_record = registry + .plugins + .values() + .find(|record| record.install_path == install_path); + let kind = matched_record.map_or(PluginKind::External, |record| record.kind); + let source = matched_record.map_or_else( + || install_path.display().to_string(), + |record| describe_install_source(&record.source), + ); + let plugin = load_plugin_definition(&install_path, kind, source, kind.marketplace())?; + if seen_ids.insert(plugin.metadata().id.clone()) { + seen_paths.insert(install_path); + plugins.push(plugin); + } + } + + for record in registry.plugins.values() { + if seen_paths.contains(&record.install_path) { + continue; + } + if !record.install_path.exists() || plugin_manifest_path(&record.install_path).is_err() + { + stale_registry_ids.push(record.id.clone()); + continue; + } + let plugin = load_plugin_definition( + &record.install_path, + record.kind, + describe_install_source(&record.source), + record.kind.marketplace(), + )?; + if seen_ids.insert(plugin.metadata().id.clone()) { + seen_paths.insert(record.install_path.clone()); + plugins.push(plugin); + } + } + + if !stale_registry_ids.is_empty() { + for plugin_id in stale_registry_ids { + registry.plugins.remove(&plugin_id); + } + self.store_registry(®istry)?; + } + + Ok(plugins) + } + + fn discover_external_directory_plugins( + &self, + existing_plugins: &[PluginDefinition], + ) -> Result<Vec<PluginDefinition>, PluginError> { + let mut plugins = Vec::new(); + + for directory in &self.config.external_dirs { + for root in discover_plugin_dirs(directory)? 
{ + let plugin = load_plugin_definition( + &root, + PluginKind::External, + root.display().to_string(), + EXTERNAL_MARKETPLACE, + )?; + if existing_plugins + .iter() + .chain(plugins.iter()) + .all(|existing| existing.metadata().id != plugin.metadata().id) + { + plugins.push(plugin); + } + } + } + + Ok(plugins) + } + + fn installed_plugin_registry(&self) -> Result<PluginRegistry, PluginError> { + self.sync_bundled_plugins()?; + Ok(PluginRegistry::new( + self.discover_installed_plugins()? + .into_iter() + .map(|plugin| { + let enabled = self.is_enabled(plugin.metadata()); + RegisteredPlugin::new(plugin, enabled) + }) + .collect(), + )) + } + + fn sync_bundled_plugins(&self) -> Result<(), PluginError> { + let bundled_root = self + .config + .bundled_root + .clone() + .unwrap_or_else(Self::bundled_root); + let bundled_plugins = discover_plugin_dirs(&bundled_root)?; + let mut registry = self.load_registry()?; + let mut changed = false; + let install_root = self.install_root(); + let mut active_bundled_ids = BTreeSet::new(); + + for source_root in bundled_plugins { + let manifest = load_plugin_from_directory(&source_root)?; + let plugin_id = plugin_id(&manifest.name, BUNDLED_MARKETPLACE); + active_bundled_ids.insert(plugin_id.clone()); + let install_path = install_root.join(sanitize_plugin_id(&plugin_id)); + let now = unix_time_ms(); + let existing_record = registry.plugins.get(&plugin_id); + let installed_copy_is_valid = + install_path.exists() && load_plugin_from_directory(&install_path).is_ok(); + let needs_sync = existing_record.is_none_or(|record| { + record.kind != PluginKind::Bundled + || record.version != manifest.version + || record.name != manifest.name + || record.description != manifest.description + || record.install_path != install_path + || !record.install_path.exists() + || !installed_copy_is_valid + }); + + if !needs_sync { + continue; + } + + if install_path.exists() { + fs::remove_dir_all(&install_path)?; + } + copy_dir_all(&source_root, 
&install_path)?; + + let installed_at_unix_ms = + existing_record.map_or(now, |record| record.installed_at_unix_ms); + registry.plugins.insert( + plugin_id.clone(), + InstalledPluginRecord { + kind: PluginKind::Bundled, + id: plugin_id, + name: manifest.name, + version: manifest.version, + description: manifest.description, + install_path, + source: PluginInstallSource::LocalPath { path: source_root }, + installed_at_unix_ms, + updated_at_unix_ms: now, + }, + ); + changed = true; + } + + let stale_bundled_ids = registry + .plugins + .iter() + .filter_map(|(plugin_id, record)| { + (record.kind == PluginKind::Bundled && !active_bundled_ids.contains(plugin_id)) + .then_some(plugin_id.clone()) + }) + .collect::<Vec<_>>(); + + for plugin_id in stale_bundled_ids { + if let Some(record) = registry.plugins.remove(&plugin_id) { + if record.install_path.exists() { + fs::remove_dir_all(&record.install_path)?; + } + changed = true; + } + } + + if changed { + self.store_registry(®istry)?; + } + + Ok(()) + } + + fn is_enabled(&self, metadata: &PluginMetadata) -> bool { + self.config + .enabled_plugins + .get(&metadata.id) + .copied() + .unwrap_or(match metadata.kind { + PluginKind::External => false, + PluginKind::Builtin | PluginKind::Bundled => metadata.default_enabled, + }) + } + + fn ensure_known_plugin(&self, plugin_id: &str) -> Result<(), PluginError> { + if self.plugin_registry()?.contains(plugin_id) { + Ok(()) + } else { + Err(PluginError::NotFound(format!( + "plugin `{plugin_id}` is not installed or discoverable" + ))) + } + } + + fn load_registry(&self) -> Result<InstalledPluginRegistry, PluginError> { + let path = self.registry_path(); + match fs::read_to_string(&path) { + Ok(contents) if contents.trim().is_empty() => Ok(InstalledPluginRegistry::default()), + Ok(contents) => Ok(serde_json::from_str(&contents)?), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => { + Ok(InstalledPluginRegistry::default()) + } + Err(error) => Err(PluginError::Io(error)), + 
// NOTE(review): collapsed diff hunk — covers `store_registry`, `write_enabled_state`, the
// `builtin_plugins` scaffold, and the head of `load_plugin_definition`.
// NOTE(review): `store_registry` rewrites installed.json in place with `fs::write`; a crash
// mid-write can leave a truncated registry. A temp-file + rename would make the update
// atomic — TODO confirm whether concurrent CLI invocations are expected here.
// NOTE(review): `update_settings_json` / `ensure_object` / `resolve_hooks` /
// `resolve_lifecycle` / `resolve_tools` are defined elsewhere in this file (outside this
// view); presumably `update_settings_json` creates settings.json on demand — verify.
} + } + + fn store_registry(&self, registry: &InstalledPluginRegistry) -> Result<(), PluginError> { + let path = self.registry_path(); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(path, serde_json::to_string_pretty(registry)?)?; + Ok(()) + } + + fn write_enabled_state( + &self, + plugin_id: &str, + enabled: Option<bool>, + ) -> Result<(), PluginError> { + update_settings_json(&self.settings_path(), |root| { + let enabled_plugins = ensure_object(root, "enabledPlugins"); + match enabled { + Some(value) => { + enabled_plugins.insert(plugin_id.to_string(), Value::Bool(value)); + } + None => { + enabled_plugins.remove(plugin_id); + } + } + }) + } +} + +#[must_use] +pub fn builtin_plugins() -> Vec<PluginDefinition> { + vec![PluginDefinition::Builtin(BuiltinPlugin { + metadata: PluginMetadata { + id: plugin_id("example-builtin", BUILTIN_MARKETPLACE), + name: "example-builtin".to_string(), + version: "0.1.0".to_string(), + description: "Example built-in plugin scaffold for the Rust plugin system".to_string(), + kind: PluginKind::Builtin, + source: BUILTIN_MARKETPLACE.to_string(), + default_enabled: false, + root: None, + }, + hooks: PluginHooks::default(), + lifecycle: PluginLifecycle::default(), + tools: Vec::new(), + })] +} + +fn load_plugin_definition( + root: &Path, + kind: PluginKind, + source: String, + marketplace: &str, +) -> Result<PluginDefinition, PluginError> { + let manifest = load_plugin_from_directory(root)?; + let metadata = PluginMetadata { + id: plugin_id(&manifest.name, marketplace), + name: manifest.name, + version: manifest.version, + description: manifest.description, + kind, + source, + default_enabled: manifest.default_enabled, + root: Some(root.to_path_buf()), + }; + let hooks = resolve_hooks(root, &manifest.hooks); + let lifecycle = resolve_lifecycle(root, &manifest.lifecycle); + let tools = resolve_tools(root, &metadata.id, &metadata.name, &manifest.tools); + Ok(match kind { + PluginKind::Builtin => 
PluginDefinition::Builtin(BuiltinPlugin { + metadata, + hooks, + lifecycle, + tools, + }), + PluginKind::Bundled => PluginDefinition::Bundled(BundledPlugin { + metadata, + hooks, + lifecycle, + tools, + }), + PluginKind::External => PluginDefinition::External(ExternalPlugin { + metadata, + hooks, + lifecycle, + tools, + }), + }) +} + +pub fn load_plugin_from_directory(root: &Path) -> Result<PluginManifest, PluginError> { + load_manifest_from_directory(root) +} + +fn load_manifest_from_directory(root: &Path) -> Result<PluginManifest, PluginError> { + let manifest_path = plugin_manifest_path(root)?; + load_manifest_from_path(root, &manifest_path) +} + +fn load_manifest_from_path( + root: &Path, + manifest_path: &Path, +) -> Result<PluginManifest, PluginError> { + let contents = fs::read_to_string(manifest_path).map_err(|error| { + PluginError::NotFound(format!( + "plugin manifest not found at {}: {error}", + manifest_path.display() + )) + })?; + let raw_manifest: RawPluginManifest = serde_json::from_str(&contents)?; + build_plugin_manifest(root, raw_manifest) +} + +fn plugin_manifest_path(root: &Path) -> Result<PathBuf, PluginError> { + let direct_path = root.join(MANIFEST_FILE_NAME); + if direct_path.exists() { + return Ok(direct_path); + } + + let packaged_path = root.join(MANIFEST_RELATIVE_PATH); + if packaged_path.exists() { + return Ok(packaged_path); + } + + Err(PluginError::NotFound(format!( + "plugin manifest not found at {} or {}", + direct_path.display(), + packaged_path.display() + ))) +} + +fn build_plugin_manifest( + root: &Path, + raw: RawPluginManifest, +) -> Result<PluginManifest, PluginError> { + let mut errors = Vec::new(); + + validate_required_manifest_field("name", &raw.name, &mut errors); + validate_required_manifest_field("version", &raw.version, &mut errors); + validate_required_manifest_field("description", &raw.description, &mut errors); + + let permissions = build_manifest_permissions(&raw.permissions, &mut errors); + 
validate_command_entries(root, raw.hooks.pre_tool_use.iter(), "hook", &mut errors); + validate_command_entries(root, raw.hooks.post_tool_use.iter(), "hook", &mut errors); + validate_command_entries( + root, + raw.lifecycle.init.iter(), + "lifecycle command", + &mut errors, + ); + validate_command_entries( + root, + raw.lifecycle.shutdown.iter(), + "lifecycle command", + &mut errors, + ); + let tools = build_manifest_tools(root, raw.tools, &mut errors); + let commands = build_manifest_commands(root, raw.commands, &mut errors); + + if !errors.is_empty() { + return Err(PluginError::ManifestValidation(errors)); + } + + Ok(PluginManifest { + name: raw.name, + version: raw.version, + description: raw.description, + permissions, + default_enabled: raw.default_enabled, + hooks: raw.hooks, + lifecycle: raw.lifecycle, + tools, + commands, + }) +} + +fn validate_required_manifest_field( + field: &'static str, + value: &str, + errors: &mut Vec<PluginManifestValidationError>, +) { + if value.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyField { field }); + } +} + +fn build_manifest_permissions( + permissions: &[String], + errors: &mut Vec<PluginManifestValidationError>, +) -> Vec<PluginPermission> { + let mut seen = BTreeSet::new(); + let mut validated = Vec::new(); + + for permission in permissions { + let permission = permission.trim(); + if permission.is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "permission", + field: "value", + name: None, + }); + continue; + } + if !seen.insert(permission.to_string()) { + errors.push(PluginManifestValidationError::DuplicatePermission { + permission: permission.to_string(), + }); + continue; + } + match PluginPermission::parse(permission) { + Some(permission) => validated.push(permission), + None => errors.push(PluginManifestValidationError::InvalidPermission { + permission: permission.to_string(), + }), + } + } + + validated +} + +fn build_manifest_tools( + root: &Path, + 
tools: Vec<RawPluginToolManifest>, + errors: &mut Vec<PluginManifestValidationError>, +) -> Vec<PluginToolManifest> { + let mut seen = BTreeSet::new(); + let mut validated = Vec::new(); + + for tool in tools { + let name = tool.name.trim().to_string(); + if name.is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "tool", + field: "name", + name: None, + }); + continue; + } + if !seen.insert(name.clone()) { + errors.push(PluginManifestValidationError::DuplicateEntry { kind: "tool", name }); + continue; + } + if tool.description.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "tool", + field: "description", + name: Some(name.clone()), + }); + } + if tool.command.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "tool", + field: "command", + name: Some(name.clone()), + }); + } else { + validate_command_entry(root, &tool.command, "tool", errors); + } + if !tool.input_schema.is_object() { + errors.push(PluginManifestValidationError::InvalidToolInputSchema { + tool_name: name.clone(), + }); + } + let Some(required_permission) = + PluginToolPermission::parse(tool.required_permission.trim()) + else { + errors.push( + PluginManifestValidationError::InvalidToolRequiredPermission { + tool_name: name.clone(), + permission: tool.required_permission.trim().to_string(), + }, + ); + continue; + }; + + validated.push(PluginToolManifest { + name, + description: tool.description, + input_schema: tool.input_schema, + command: tool.command, + args: tool.args, + required_permission, + }); + } + + validated +} + +fn build_manifest_commands( + root: &Path, + commands: Vec<PluginCommandManifest>, + errors: &mut Vec<PluginManifestValidationError>, +) -> Vec<PluginCommandManifest> { + let mut seen = BTreeSet::new(); + let mut validated = Vec::new(); + + for command in commands { + let name = command.name.trim().to_string(); + if name.is_empty() { + 
errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "command", + field: "name", + name: None, + }); + continue; + } + if !seen.insert(name.clone()) { + errors.push(PluginManifestValidationError::DuplicateEntry { + kind: "command", + name, + }); + continue; + } + if command.description.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "command", + field: "description", + name: Some(name.clone()), + }); + } + if command.command.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "command", + field: "command", + name: Some(name.clone()), + }); + } else { + validate_command_entry(root, &command.command, "command", errors); + } + validated.push(command); + } + + validated +} + +fn validate_command_entries<'a>( + root: &Path, + entries: impl Iterator<Item = &'a String>, + kind: &'static str, + errors: &mut Vec<PluginManifestValidationError>, +) { + for entry in entries { + validate_command_entry(root, entry, kind, errors); + } +} + +fn validate_command_entry( + root: &Path, + entry: &str, + kind: &'static str, + errors: &mut Vec<PluginManifestValidationError>, +) { + if entry.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind, + field: "command", + name: None, + }); + return; + } + if is_literal_command(entry) { + return; + } + + let path = if Path::new(entry).is_absolute() { + PathBuf::from(entry) + } else { + root.join(entry) + }; + if !path.exists() { + errors.push(PluginManifestValidationError::MissingPath { kind, path }); + } +} + +fn resolve_hooks(root: &Path, hooks: &PluginHooks) -> PluginHooks { + PluginHooks { + pre_tool_use: hooks + .pre_tool_use + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + post_tool_use: hooks + .post_tool_use + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + } +} + +fn resolve_lifecycle(root: &Path, lifecycle: &PluginLifecycle) -> PluginLifecycle { + 
PluginLifecycle { + init: lifecycle + .init + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + shutdown: lifecycle + .shutdown + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + } +} + +fn resolve_tools( + root: &Path, + plugin_id: &str, + plugin_name: &str, + tools: &[PluginToolManifest], +) -> Vec<PluginTool> { + tools + .iter() + .map(|tool| { + PluginTool::new( + plugin_id, + plugin_name, + PluginToolDefinition { + name: tool.name.clone(), + description: Some(tool.description.clone()), + input_schema: tool.input_schema.clone(), + }, + resolve_hook_entry(root, &tool.command), + tool.args.clone(), + tool.required_permission, + Some(root.to_path_buf()), + ) + }) + .collect() +} + +fn validate_hook_paths(root: Option<&Path>, hooks: &PluginHooks) -> Result<(), PluginError> { + let Some(root) = root else { + return Ok(()); + }; + for entry in hooks.pre_tool_use.iter().chain(hooks.post_tool_use.iter()) { + validate_command_path(root, entry, "hook")?; + } + Ok(()) +} + +fn validate_lifecycle_paths( + root: Option<&Path>, + lifecycle: &PluginLifecycle, +) -> Result<(), PluginError> { + let Some(root) = root else { + return Ok(()); + }; + for entry in lifecycle.init.iter().chain(lifecycle.shutdown.iter()) { + validate_command_path(root, entry, "lifecycle command")?; + } + Ok(()) +} + +fn validate_tool_paths(root: Option<&Path>, tools: &[PluginTool]) -> Result<(), PluginError> { + let Some(root) = root else { + return Ok(()); + }; + for tool in tools { + validate_command_path(root, &tool.command, "tool")?; + } + Ok(()) +} + +fn validate_command_path(root: &Path, entry: &str, kind: &str) -> Result<(), PluginError> { + if is_literal_command(entry) { + return Ok(()); + } + let path = if Path::new(entry).is_absolute() { + PathBuf::from(entry) + } else { + root.join(entry) + }; + if !path.exists() { + return Err(PluginError::InvalidManifest(format!( + "{kind} path `{}` does not exist", + path.display() + ))); + } + Ok(()) +} + +fn 
resolve_hook_entry(root: &Path, entry: &str) -> String { + if is_literal_command(entry) { + entry.to_string() + } else { + root.join(entry).display().to_string() + } +} + +fn is_literal_command(entry: &str) -> bool { + !entry.starts_with("./") && !entry.starts_with("../") && !Path::new(entry).is_absolute() +} + +fn run_lifecycle_commands( + metadata: &PluginMetadata, + lifecycle: &PluginLifecycle, + phase: &str, + commands: &[String], +) -> Result<(), PluginError> { + if lifecycle.is_empty() || commands.is_empty() { + return Ok(()); + } + + for command in commands { + let mut process = if Path::new(command).exists() { + if cfg!(windows) { + let mut process = Command::new("cmd"); + process.arg("/C").arg(command); + process + } else { + let mut process = Command::new("sh"); + process.arg(command); + process + } + } else if cfg!(windows) { + let mut process = Command::new("cmd"); + process.arg("/C").arg(command); + process + } else { + let mut process = Command::new("sh"); + process.arg("-lc").arg(command); + process + }; + if let Some(root) = &metadata.root { + process.current_dir(root); + } + let output = process.output()?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(PluginError::CommandFailed(format!( + "plugin `{}` {} failed for `{}`: {}", + metadata.id, + phase, + command, + if stderr.is_empty() { + format!("exit status {}", output.status) + } else { + stderr + } + ))); + } + } + + Ok(()) +} + +fn resolve_local_source(source: &str) -> Result<PathBuf, PluginError> { + let path = PathBuf::from(source); + if path.exists() { + Ok(path) + } else { + Err(PluginError::NotFound(format!( + "plugin source `{source}` was not found" + ))) + } +} + +fn parse_install_source(source: &str) -> Result<PluginInstallSource, PluginError> { + if source.starts_with("http://") + || source.starts_with("https://") + || source.starts_with("git@") + || Path::new(source) + .extension() + .is_some_and(|extension| 
extension.eq_ignore_ascii_case("git")) + { + Ok(PluginInstallSource::GitUrl { + url: source.to_string(), + }) + } else { + Ok(PluginInstallSource::LocalPath { + path: resolve_local_source(source)?, + }) + } +} + +fn materialize_source( + source: &PluginInstallSource, + temp_root: &Path, +) -> Result<PathBuf, PluginError> { + fs::create_dir_all(temp_root)?; + match source { + PluginInstallSource::LocalPath { path } => Ok(path.clone()), + PluginInstallSource::GitUrl { url } => { + let destination = temp_root.join(format!("plugin-{}", unix_time_ms())); + let output = Command::new("git") + .arg("clone") + .arg("--depth") + .arg("1") + .arg(url) + .arg(&destination) + .output()?; + if !output.status.success() { + return Err(PluginError::CommandFailed(format!( + "git clone failed for `{url}`: {}", + String::from_utf8_lossy(&output.stderr).trim() + ))); + } + Ok(destination) + } + } +} + +fn discover_plugin_dirs(root: &Path) -> Result<Vec<PathBuf>, PluginError> { + match fs::read_dir(root) { + Ok(entries) => { + let mut paths = Vec::new(); + for entry in entries { + let path = entry?.path(); + if path.is_dir() && plugin_manifest_path(&path).is_ok() { + paths.push(path); + } + } + paths.sort(); + Ok(paths) + } + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(Vec::new()), + Err(error) => Err(PluginError::Io(error)), + } +} + +fn plugin_id(name: &str, marketplace: &str) -> String { + format!("{name}@{marketplace}") +} + +fn sanitize_plugin_id(plugin_id: &str) -> String { + plugin_id + .chars() + .map(|ch| match ch { + '/' | '\\' | '@' | ':' => '-', + other => other, + }) + .collect() +} + +fn describe_install_source(source: &PluginInstallSource) -> String { + match source { + PluginInstallSource::LocalPath { path } => path.display().to_string(), + PluginInstallSource::GitUrl { url } => url.clone(), + } +} + +fn unix_time_ms() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_millis() +} + +fn 
copy_dir_all(source: &Path, destination: &Path) -> Result<(), PluginError> { + fs::create_dir_all(destination)?; + for entry in fs::read_dir(source)? { + let entry = entry?; + let target = destination.join(entry.file_name()); + if entry.file_type()?.is_dir() { + copy_dir_all(&entry.path(), &target)?; + } else { + fs::copy(entry.path(), target)?; + } + } + Ok(()) +} + +fn update_settings_json( + path: &Path, + mut update: impl FnMut(&mut Map<String, Value>), +) -> Result<(), PluginError> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + let mut root = match fs::read_to_string(path) { + Ok(contents) if !contents.trim().is_empty() => serde_json::from_str::<Value>(&contents)?, + Ok(_) => Value::Object(Map::new()), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Value::Object(Map::new()), + Err(error) => return Err(PluginError::Io(error)), + }; + + let object = root.as_object_mut().ok_or_else(|| { + PluginError::InvalidManifest(format!( + "settings file {} must contain a JSON object", + path.display() + )) + })?; + update(object); + fs::write(path, serde_json::to_string_pretty(&root)?)?; + Ok(()) +} + +fn ensure_object<'a>(root: &'a mut Map<String, Value>, key: &str) -> &'a mut Map<String, Value> { + if !root.get(key).is_some_and(Value::is_object) { + root.insert(key.to_string(), Value::Object(Map::new())); + } + root.get_mut(key) + .and_then(Value::as_object_mut) + .expect("object should exist") +} + +#[cfg(test)] +mod tests { + use super::*; + + fn temp_dir(label: &str) -> PathBuf { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("plugins-{label}-{nanos}")) + } + + fn write_file(path: &Path, contents: &str) { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).expect("parent dir"); + } + fs::write(path, contents).expect("write file"); + } + + fn write_loader_plugin(root: &Path) { + 
write_file( + root.join("hooks").join("pre.sh").as_path(), + "#!/bin/sh\nprintf 'pre'\n", + ); + write_file( + root.join("tools").join("echo-tool.sh").as_path(), + "#!/bin/sh\ncat\n", + ); + write_file( + root.join("commands").join("sync.sh").as_path(), + "#!/bin/sh\nprintf 'sync'\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "loader-demo", + "version": "1.2.3", + "description": "Manifest loader test plugin", + "permissions": ["read", "write"], + "hooks": { + "PreToolUse": ["./hooks/pre.sh"] + }, + "tools": [ + { + "name": "echo_tool", + "description": "Echoes JSON input", + "inputSchema": { + "type": "object" + }, + "command": "./tools/echo-tool.sh", + "requiredPermission": "workspace-write" + } + ], + "commands": [ + { + "name": "sync", + "description": "Sync command", + "command": "./commands/sync.sh" + } + ] +}"#, + ); + } + + fn write_external_plugin(root: &Path, name: &str, version: &str) { + write_file( + root.join("hooks").join("pre.sh").as_path(), + "#!/bin/sh\nprintf 'pre'\n", + ); + write_file( + root.join("hooks").join("post.sh").as_path(), + "#!/bin/sh\nprintf 'post'\n", + ); + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"test plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" + ) + .as_str(), + ); + } + + fn write_broken_plugin(root: &Path, name: &str) { + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"broken plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/missing.sh\"]\n }}\n}}" + ) + .as_str(), + ); + } + + fn write_lifecycle_plugin(root: &Path, name: &str, version: &str) -> PathBuf { + let log_path = root.join("lifecycle.log"); + write_file( + root.join("lifecycle").join("init.sh").as_path(), + "#!/bin/sh\nprintf 'init\\n' >> 
lifecycle.log\n", + ); + write_file( + root.join("lifecycle").join("shutdown.sh").as_path(), + "#!/bin/sh\nprintf 'shutdown\\n' >> lifecycle.log\n", + ); + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"lifecycle plugin\",\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init.sh\"],\n \"Shutdown\": [\"./lifecycle/shutdown.sh\"]\n }}\n}}" + ) + .as_str(), + ); + log_path + } + + fn write_tool_plugin(root: &Path, name: &str, version: &str) { + write_tool_plugin_with_name(root, name, version, "plugin_echo"); + } + + fn write_tool_plugin_with_name(root: &Path, name: &str, version: &str, tool_name: &str) { + let script_path = root.join("tools").join("echo-json.sh"); + write_file( + &script_path, + "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAW_PLUGIN_ID\" \"$CLAW_TOOL_NAME\" \"$INPUT\"\n", + ); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); + } + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"tool plugin\",\n \"tools\": [\n {{\n \"name\": \"{tool_name}\",\n \"description\": \"Echo JSON input\",\n \"inputSchema\": {{\"type\": \"object\", \"properties\": {{\"message\": {{\"type\": \"string\"}}}}, \"required\": [\"message\"], \"additionalProperties\": false}},\n \"command\": \"./tools/echo-json.sh\",\n \"requiredPermission\": \"workspace-write\"\n }}\n ]\n}}" + ) + .as_str(), + ); + } + + fn write_bundled_plugin(root: &Path, name: &str, version: &str, default_enabled: bool) { + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": 
\"bundled plugin\",\n \"defaultEnabled\": {}\n}}", + if default_enabled { "true" } else { "false" } + ) + .as_str(), + ); + } + + fn load_enabled_plugins(path: &Path) -> BTreeMap<String, bool> { + let contents = fs::read_to_string(path).expect("settings should exist"); + let root: Value = serde_json::from_str(&contents).expect("settings json"); + root.get("enabledPlugins") + .and_then(Value::as_object) + .map(|enabled_plugins| { + enabled_plugins + .iter() + .map(|(plugin_id, value)| { + ( + plugin_id.clone(), + value.as_bool().expect("plugin state should be a bool"), + ) + }) + .collect() + }) + .unwrap_or_default() + } + + #[test] + fn load_plugin_from_directory_validates_required_fields() { + let root = temp_dir("manifest-required"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{"name":"","version":"1.0.0","description":"desc"}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("empty name should fail"); + assert!(error.to_string().contains("name cannot be empty")); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_reads_root_manifest_and_validates_entries() { + let root = temp_dir("manifest-root"); + write_loader_plugin(&root); + + let manifest = load_plugin_from_directory(&root).expect("manifest should load"); + assert_eq!(manifest.name, "loader-demo"); + assert_eq!(manifest.version, "1.2.3"); + assert_eq!( + manifest + .permissions + .iter() + .map(|permission| permission.as_str()) + .collect::<Vec<_>>(), + vec!["read", "write"] + ); + assert_eq!(manifest.hooks.pre_tool_use, vec!["./hooks/pre.sh"]); + assert_eq!(manifest.tools.len(), 1); + assert_eq!(manifest.tools[0].name, "echo_tool"); + assert_eq!( + manifest.tools[0].required_permission, + PluginToolPermission::WorkspaceWrite + ); + assert_eq!(manifest.commands.len(), 1); + assert_eq!(manifest.commands[0].name, "sync"); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn 
load_plugin_from_directory_supports_packaged_manifest_path() { + let root = temp_dir("manifest-packaged"); + write_external_plugin(&root, "packaged-demo", "1.0.0"); + + let manifest = load_plugin_from_directory(&root).expect("packaged manifest should load"); + assert_eq!(manifest.name, "packaged-demo"); + assert!(manifest.tools.is_empty()); + assert!(manifest.commands.is_empty()); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_defaults_optional_fields() { + let root = temp_dir("manifest-defaults"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "minimal", + "version": "0.1.0", + "description": "Minimal manifest" +}"#, + ); + + let manifest = load_plugin_from_directory(&root).expect("minimal manifest should load"); + assert!(manifest.permissions.is_empty()); + assert!(manifest.hooks.is_empty()); + assert!(manifest.tools.is_empty()); + assert!(manifest.commands.is_empty()); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_duplicate_permissions_and_commands() { + let root = temp_dir("manifest-duplicates"); + write_file( + root.join("commands").join("sync.sh").as_path(), + "#!/bin/sh\nprintf 'sync'\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "duplicate-manifest", + "version": "1.0.0", + "description": "Duplicate validation", + "permissions": ["read", "read"], + "commands": [ + {"name": "sync", "description": "Sync one", "command": "./commands/sync.sh"}, + {"name": "sync", "description": "Sync two", "command": "./commands/sync.sh"} + ] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("duplicates should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::DuplicatePermission { permission } + if permission == "read" + ))); + assert!(errors.iter().any(|error| matches!( + error, + 
PluginManifestValidationError::DuplicateEntry { kind, name } + if *kind == "command" && name == "sync" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_missing_tool_or_command_paths() { + let root = temp_dir("manifest-paths"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "missing-paths", + "version": "1.0.0", + "description": "Missing path validation", + "tools": [ + { + "name": "tool_one", + "description": "Missing tool script", + "inputSchema": {"type": "object"}, + "command": "./tools/missing.sh" + } + ] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("missing paths should fail"); + assert!(error.to_string().contains("does not exist")); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_invalid_permissions() { + let root = temp_dir("manifest-invalid-permissions"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "invalid-permissions", + "version": "1.0.0", + "description": "Invalid permission validation", + "permissions": ["admin"] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("invalid permissions should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::InvalidPermission { permission } + if permission == "admin" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_invalid_tool_required_permission() { + let root = temp_dir("manifest-invalid-tool-permission"); + write_file( + root.join("tools").join("echo.sh").as_path(), + "#!/bin/sh\ncat\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "invalid-tool-permission", + "version": 
"1.0.0", + "description": "Invalid tool permission validation", + "tools": [ + { + "name": "echo_tool", + "description": "Echo tool", + "inputSchema": {"type": "object"}, + "command": "./tools/echo.sh", + "requiredPermission": "admin" + } + ] +}"#, + ); + + let error = + load_plugin_from_directory(&root).expect_err("invalid tool permission should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::InvalidToolRequiredPermission { + tool_name, + permission + } if tool_name == "echo_tool" && permission == "admin" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_accumulates_multiple_validation_errors() { + let root = temp_dir("manifest-multi-error"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "", + "version": "1.0.0", + "description": "", + "permissions": ["admin"], + "commands": [ + {"name": "", "description": "", "command": "./commands/missing.sh"} + ] +}"#, + ); + + let error = + load_plugin_from_directory(&root).expect_err("multiple manifest errors should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.len() >= 4); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::EmptyField { field } if *field == "name" + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::EmptyField { field } + if *field == "description" + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::InvalidPermission { permission } + if permission == "admin" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn discovers_builtin_and_bundled_plugins() { + let manager = 
PluginManager::new(PluginManagerConfig::new(temp_dir("discover"))); + let plugins = manager.list_plugins().expect("plugins should list"); + assert!(plugins + .iter() + .any(|plugin| plugin.metadata.kind == PluginKind::Builtin)); + assert!(plugins + .iter() + .any(|plugin| plugin.metadata.kind == PluginKind::Bundled)); + } + + #[test] + fn installs_enables_updates_and_uninstalls_external_plugins() { + let config_home = temp_dir("home"); + let source_root = temp_dir("source"); + write_external_plugin(&source_root, "demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install = manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + assert_eq!(install.plugin_id, "demo@external"); + assert!(manager + .list_plugins() + .expect("list plugins") + .iter() + .any(|plugin| plugin.metadata.id == "demo@external" && plugin.enabled)); + + let hooks = manager.aggregated_hooks().expect("hooks should aggregate"); + assert_eq!(hooks.pre_tool_use.len(), 1); + assert!(hooks.pre_tool_use[0].contains("pre.sh")); + + manager + .disable("demo@external") + .expect("disable should work"); + assert!(manager + .aggregated_hooks() + .expect("hooks after disable") + .is_empty()); + manager.enable("demo@external").expect("enable should work"); + + write_external_plugin(&source_root, "demo", "2.0.0"); + let update = manager.update("demo@external").expect("update should work"); + assert_eq!(update.old_version, "1.0.0"); + assert_eq!(update.new_version, "2.0.0"); + + manager + .uninstall("demo@external") + .expect("uninstall should work"); + assert!(!manager + .list_plugins() + .expect("list plugins") + .iter() + .any(|plugin| plugin.metadata.id == "demo@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn auto_installs_bundled_plugins_into_the_registry() { + let config_home = temp_dir("bundled-home"); + let bundled_root = 
temp_dir("bundled-root"); + write_bundled_plugin(&bundled_root.join("starter"), "starter", "0.1.0", false); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let manager = PluginManager::new(config); + + let installed = manager + .list_installed_plugins() + .expect("bundled plugins should auto-install"); + assert!(installed.iter().any(|plugin| { + plugin.metadata.id == "starter@bundled" + && plugin.metadata.kind == PluginKind::Bundled + && !plugin.enabled + })); + + let registry = manager.load_registry().expect("registry should exist"); + let record = registry + .plugins + .get("starter@bundled") + .expect("bundled plugin should be recorded"); + assert_eq!(record.kind, PluginKind::Bundled); + assert!(record.install_path.exists()); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn default_bundled_root_loads_repo_bundles_as_installed_plugins() { + let config_home = temp_dir("default-bundled-home"); + let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + + let installed = manager + .list_installed_plugins() + .expect("default bundled plugins should auto-install"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "example-bundled@bundled")); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "sample-hooks@bundled")); + + let _ = fs::remove_dir_all(config_home); + } + + #[test] + fn bundled_sync_prunes_removed_bundled_registry_entries() { + let config_home = temp_dir("bundled-prune-home"); + let bundled_root = temp_dir("bundled-prune-root"); + let stale_install_path = config_home + .join("plugins") + .join("installed") + .join("stale-bundled-external"); + write_bundled_plugin(&bundled_root.join("active"), "active", "0.1.0", false); + write_file( + stale_install_path.join(MANIFEST_RELATIVE_PATH).as_path(), + r#"{ + "name": "stale", + "version": "0.1.0", + "description": "stale bundled 
plugin" +}"#, + ); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(config_home.join("plugins").join("installed")); + let manager = PluginManager::new(config); + + let mut registry = InstalledPluginRegistry::default(); + registry.plugins.insert( + "stale@bundled".to_string(), + InstalledPluginRecord { + kind: PluginKind::Bundled, + id: "stale@bundled".to_string(), + name: "stale".to_string(), + version: "0.1.0".to_string(), + description: "stale bundled plugin".to_string(), + install_path: stale_install_path.clone(), + source: PluginInstallSource::LocalPath { + path: bundled_root.join("stale"), + }, + installed_at_unix_ms: 1, + updated_at_unix_ms: 1, + }, + ); + manager.store_registry(®istry).expect("store registry"); + manager + .write_enabled_state("stale@bundled", Some(true)) + .expect("seed bundled enabled state"); + + let installed = manager + .list_installed_plugins() + .expect("bundled sync should succeed"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "active@bundled")); + assert!(!installed + .iter() + .any(|plugin| plugin.metadata.id == "stale@bundled")); + + let registry = manager.load_registry().expect("load registry"); + assert!(!registry.plugins.contains_key("stale@bundled")); + assert!(!stale_install_path.exists()); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn installed_plugin_discovery_keeps_registry_entries_outside_install_root() { + let config_home = temp_dir("registry-fallback-home"); + let bundled_root = temp_dir("registry-fallback-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let external_install_path = temp_dir("registry-fallback-external"); + write_file( + external_install_path.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "registry-fallback", + "version": "1.0.0", + "description": "Registry fallback plugin" +}"#, + ); + + 
let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root.clone()); + let manager = PluginManager::new(config); + + let mut registry = InstalledPluginRegistry::default(); + registry.plugins.insert( + "registry-fallback@external".to_string(), + InstalledPluginRecord { + kind: PluginKind::External, + id: "registry-fallback@external".to_string(), + name: "registry-fallback".to_string(), + version: "1.0.0".to_string(), + description: "Registry fallback plugin".to_string(), + install_path: external_install_path.clone(), + source: PluginInstallSource::LocalPath { + path: external_install_path.clone(), + }, + installed_at_unix_ms: 1, + updated_at_unix_ms: 1, + }, + ); + manager.store_registry(®istry).expect("store registry"); + manager + .write_enabled_state("stale-external@external", Some(true)) + .expect("seed stale external enabled state"); + + let installed = manager + .list_installed_plugins() + .expect("registry fallback plugin should load"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "registry-fallback@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + let _ = fs::remove_dir_all(external_install_path); + } + + #[test] + fn installed_plugin_discovery_prunes_stale_registry_entries() { + let config_home = temp_dir("registry-prune-home"); + let bundled_root = temp_dir("registry-prune-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let missing_install_path = temp_dir("registry-prune-missing"); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root); + let manager = PluginManager::new(config); + + let mut registry = InstalledPluginRegistry::default(); + registry.plugins.insert( + "stale-external@external".to_string(), + InstalledPluginRecord { + kind: PluginKind::External, 
+ id: "stale-external@external".to_string(), + name: "stale-external".to_string(), + version: "1.0.0".to_string(), + description: "stale external plugin".to_string(), + install_path: missing_install_path.clone(), + source: PluginInstallSource::LocalPath { + path: missing_install_path.clone(), + }, + installed_at_unix_ms: 1, + updated_at_unix_ms: 1, + }, + ); + manager.store_registry(®istry).expect("store registry"); + + let installed = manager + .list_installed_plugins() + .expect("stale registry entries should be pruned"); + assert!(!installed + .iter() + .any(|plugin| plugin.metadata.id == "stale-external@external")); + + let registry = manager.load_registry().expect("load registry"); + assert!(!registry.plugins.contains_key("stale-external@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn persists_bundled_plugin_enable_state_across_reloads() { + let config_home = temp_dir("bundled-state-home"); + let bundled_root = temp_dir("bundled-state-root"); + write_bundled_plugin(&bundled_root.join("starter"), "starter", "0.1.0", false); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let mut manager = PluginManager::new(config.clone()); + + manager + .enable("starter@bundled") + .expect("enable bundled plugin should succeed"); + assert_eq!( + load_enabled_plugins(&manager.settings_path()).get("starter@bundled"), + Some(&true) + ); + + let mut reloaded_config = PluginManagerConfig::new(&config_home); + reloaded_config.bundled_root = Some(bundled_root.clone()); + reloaded_config.enabled_plugins = load_enabled_plugins(&manager.settings_path()); + let reloaded_manager = PluginManager::new(reloaded_config); + let reloaded = reloaded_manager + .list_installed_plugins() + .expect("bundled plugins should still be listed"); + assert!(reloaded + .iter() + .any(|plugin| { plugin.metadata.id == "starter@bundled" && plugin.enabled })); + + let _ = 
fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn persists_bundled_plugin_disable_state_across_reloads() { + let config_home = temp_dir("bundled-disabled-home"); + let bundled_root = temp_dir("bundled-disabled-root"); + write_bundled_plugin(&bundled_root.join("starter"), "starter", "0.1.0", true); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let mut manager = PluginManager::new(config); + + manager + .disable("starter@bundled") + .expect("disable bundled plugin should succeed"); + assert_eq!( + load_enabled_plugins(&manager.settings_path()).get("starter@bundled"), + Some(&false) + ); + + let mut reloaded_config = PluginManagerConfig::new(&config_home); + reloaded_config.bundled_root = Some(bundled_root.clone()); + reloaded_config.enabled_plugins = load_enabled_plugins(&manager.settings_path()); + let reloaded_manager = PluginManager::new(reloaded_config); + let reloaded = reloaded_manager + .list_installed_plugins() + .expect("bundled plugins should still be listed"); + assert!(reloaded + .iter() + .any(|plugin| { plugin.metadata.id == "starter@bundled" && !plugin.enabled })); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn validates_plugin_source_before_install() { + let config_home = temp_dir("validate-home"); + let source_root = temp_dir("validate-source"); + write_external_plugin(&source_root, "validator", "1.0.0"); + let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let manifest = manager + .validate_plugin_source(source_root.to_str().expect("utf8 path")) + .expect("manifest should validate"); + assert_eq!(manifest.name, "validator"); + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn plugin_registry_tracks_enabled_state_and_lookup() { + let config_home = temp_dir("registry-home"); + let source_root = 
temp_dir("registry-source"); + write_external_plugin(&source_root, "registry-demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + manager + .disable("registry-demo@external") + .expect("disable should succeed"); + + let registry = manager.plugin_registry().expect("registry should build"); + let plugin = registry + .get("registry-demo@external") + .expect("installed plugin should be discoverable"); + assert_eq!(plugin.metadata().name, "registry-demo"); + assert!(!plugin.is_enabled()); + assert!(registry.contains("registry-demo@external")); + assert!(!registry.contains("missing@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn rejects_plugin_sources_with_missing_hook_paths() { + let config_home = temp_dir("broken-home"); + let source_root = temp_dir("broken-source"); + write_broken_plugin(&source_root, "broken"); + + let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let error = manager + .validate_plugin_source(source_root.to_str().expect("utf8 path")) + .expect_err("missing hook file should fail validation"); + assert!(error.to_string().contains("does not exist")); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install_error = manager + .install(source_root.to_str().expect("utf8 path")) + .expect_err("install should reject invalid hook paths"); + assert!(install_error.to_string().contains("does not exist")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn plugin_registry_runs_initialize_and_shutdown_for_enabled_plugins() { + let config_home = temp_dir("lifecycle-home"); + let source_root = temp_dir("lifecycle-source"); + let _ = write_lifecycle_plugin(&source_root, "lifecycle-demo", "1.0.0"); + + let mut manager = 
PluginManager::new(PluginManagerConfig::new(&config_home)); + let install = manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + let log_path = install.install_path.join("lifecycle.log"); + + let registry = manager.plugin_registry().expect("registry should build"); + registry.initialize().expect("init should succeed"); + registry.shutdown().expect("shutdown should succeed"); + + let log = fs::read_to_string(&log_path).expect("lifecycle log should exist"); + assert_eq!(log, "init\nshutdown\n"); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn aggregates_and_executes_plugin_tools() { + let config_home = temp_dir("tool-home"); + let source_root = temp_dir("tool-source"); + write_tool_plugin(&source_root, "tool-demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + + let tools = manager.aggregated_tools().expect("tools should aggregate"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].definition().name, "plugin_echo"); + assert_eq!(tools[0].required_permission(), "workspace-write"); + + let output = tools[0] + .execute(&serde_json::json!({ "message": "hello" })) + .expect("plugin tool should execute"); + let payload: Value = serde_json::from_str(&output).expect("valid json"); + assert_eq!(payload["plugin"], "tool-demo@external"); + assert_eq!(payload["tool"], "plugin_echo"); + assert_eq!(payload["input"]["message"], "hello"); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn list_installed_plugins_scans_install_root_without_registry_entries() { + let config_home = temp_dir("installed-scan-home"); + let bundled_root = temp_dir("installed-scan-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let installed_plugin_root = 
install_root.join("scan-demo"); + write_file( + installed_plugin_root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "scan-demo", + "version": "1.0.0", + "description": "Scanned from install root" +}"#, + ); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root); + let manager = PluginManager::new(config); + + let installed = manager + .list_installed_plugins() + .expect("installed plugins should scan directories"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "scan-demo@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn list_installed_plugins_scans_packaged_manifests_in_install_root() { + let config_home = temp_dir("installed-packaged-scan-home"); + let bundled_root = temp_dir("installed-packaged-scan-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let installed_plugin_root = install_root.join("scan-packaged"); + write_file( + installed_plugin_root.join(MANIFEST_RELATIVE_PATH).as_path(), + r#"{ + "name": "scan-packaged", + "version": "1.0.0", + "description": "Packaged manifest in install root" +}"#, + ); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root); + let manager = PluginManager::new(config); + + let installed = manager + .list_installed_plugins() + .expect("installed plugins should scan packaged manifests"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "scan-packaged@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } +} diff --git a/crates/runtime/Cargo.toml b/crates/runtime/Cargo.toml index 7ce7cd8..025cd03 100644 --- a/crates/runtime/Cargo.toml +++ b/crates/runtime/Cargo.toml @@ -8,9 +8,11 @@ publish.workspace = true [dependencies] sha2 = 
"0.10" glob = "0.3" +lsp = { path = "../lsp" } +plugins = { path = "../plugins" } regex = "1" serde = { version = "1", features = ["derive"] } -serde_json = "1" +serde_json.workspace = true tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] } walkdir = "2" diff --git a/crates/runtime/src/bootstrap.rs b/crates/runtime/src/bootstrap.rs index dfc99ab..760f27e 100644 --- a/crates/runtime/src/bootstrap.rs +++ b/crates/runtime/src/bootstrap.rs @@ -21,7 +21,7 @@ pub struct BootstrapPlan { impl BootstrapPlan { #[must_use] - pub fn claude_code_default() -> Self { + pub fn claw_default() -> Self { Self::from_phases(vec![ BootstrapPhase::CliEntry, BootstrapPhase::FastPathVersion, diff --git a/crates/runtime/src/compact.rs b/crates/runtime/src/compact.rs index e227019..a0792da 100644 --- a/crates/runtime/src/compact.rs +++ b/crates/runtime/src/compact.rs @@ -1,5 +1,10 @@ use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; +const COMPACT_CONTINUATION_PREAMBLE: &str = + "This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.\n\n"; +const COMPACT_RECENT_MESSAGES_NOTE: &str = "Recent messages are preserved verbatim."; +const COMPACT_DIRECT_RESUME_INSTRUCTION: &str = "Continue the conversation from where it left off without asking the user any further questions. 
Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text."; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct CompactionConfig { pub preserve_recent_messages: usize, @@ -30,8 +35,15 @@ pub fn estimate_session_tokens(session: &Session) -> usize { #[must_use] pub fn should_compact(session: &Session, config: CompactionConfig) -> bool { - session.messages.len() > config.preserve_recent_messages - && estimate_session_tokens(session) >= config.max_estimated_tokens + let start = compacted_summary_prefix_len(session); + let compactable = &session.messages[start..]; + + compactable.len() > config.preserve_recent_messages + && compactable + .iter() + .map(estimate_message_tokens) + .sum::<usize>() + >= config.max_estimated_tokens } #[must_use] @@ -56,16 +68,18 @@ pub fn get_compact_continuation_message( recent_messages_preserved: bool, ) -> String { let mut base = format!( - "This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.\n\n{}", + "{COMPACT_CONTINUATION_PREAMBLE}{}", format_compact_summary(summary) ); if recent_messages_preserved { - base.push_str("\n\nRecent messages are preserved verbatim."); + base.push_str("\n\n"); + base.push_str(COMPACT_RECENT_MESSAGES_NOTE); } if suppress_follow_up_questions { - base.push_str("\nContinue the conversation from where it left off without asking the user any further questions. 
Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text."); + base.push('\n'); + base.push_str(COMPACT_DIRECT_RESUME_INSTRUCTION); } base @@ -82,13 +96,19 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio }; } + let existing_summary = session + .messages + .first() + .and_then(extract_existing_compacted_summary); + let compacted_prefix_len = usize::from(existing_summary.is_some()); let keep_from = session .messages .len() .saturating_sub(config.preserve_recent_messages); - let removed = &session.messages[..keep_from]; + let removed = &session.messages[compacted_prefix_len..keep_from]; let preserved = session.messages[keep_from..].to_vec(); - let summary = summarize_messages(removed); + let summary = + merge_compact_summaries(existing_summary.as_deref(), &summarize_messages(removed)); let formatted_summary = format_compact_summary(&summary); let continuation = get_compact_continuation_message(&summary, true, !preserved.is_empty()); @@ -110,6 +130,16 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio } } +fn compacted_summary_prefix_len(session: &Session) -> usize { + usize::from( + session + .messages + .first() + .and_then(extract_existing_compacted_summary) + .is_some(), + ) +} + fn summarize_messages(messages: &[ConversationMessage]) -> String { let user_messages = messages .iter() @@ -197,6 +227,41 @@ fn summarize_messages(messages: &[ConversationMessage]) -> String { lines.join("\n") } +fn merge_compact_summaries(existing_summary: Option<&str>, new_summary: &str) -> String { + let Some(existing_summary) = existing_summary else { + return new_summary.to_string(); + }; + + let previous_highlights = extract_summary_highlights(existing_summary); + let new_formatted_summary = format_compact_summary(new_summary); + let new_highlights = extract_summary_highlights(&new_formatted_summary); + let new_timeline = 
extract_summary_timeline(&new_formatted_summary); + + let mut lines = vec!["<summary>".to_string(), "Conversation summary:".to_string()]; + + if !previous_highlights.is_empty() { + lines.push("- Previously compacted context:".to_string()); + lines.extend( + previous_highlights + .into_iter() + .map(|line| format!(" {line}")), + ); + } + + if !new_highlights.is_empty() { + lines.push("- Newly compacted context:".to_string()); + lines.extend(new_highlights.into_iter().map(|line| format!(" {line}"))); + } + + if !new_timeline.is_empty() { + lines.push("- Key timeline:".to_string()); + lines.extend(new_timeline.into_iter().map(|line| format!(" {line}"))); + } + + lines.push("</summary>".to_string()); + lines.join("\n") +} + fn summarize_block(block: &ContentBlock) -> String { let raw = match block { ContentBlock::Text { text } => text.clone(), @@ -374,11 +439,71 @@ fn collapse_blank_lines(content: &str) -> String { result } +fn extract_existing_compacted_summary(message: &ConversationMessage) -> Option<String> { + if message.role != MessageRole::System { + return None; + } + + let text = first_text_block(message)?; + let summary = text.strip_prefix(COMPACT_CONTINUATION_PREAMBLE)?; + let summary = summary + .split_once(&format!("\n\n{COMPACT_RECENT_MESSAGES_NOTE}")) + .map_or(summary, |(value, _)| value); + let summary = summary + .split_once(&format!("\n{COMPACT_DIRECT_RESUME_INSTRUCTION}")) + .map_or(summary, |(value, _)| value); + Some(summary.trim().to_string()) +} + +fn extract_summary_highlights(summary: &str) -> Vec<String> { + let mut lines = Vec::new(); + let mut in_timeline = false; + + for line in format_compact_summary(summary).lines() { + let trimmed = line.trim_end(); + if trimmed.is_empty() || trimmed == "Summary:" || trimmed == "Conversation summary:" { + continue; + } + if trimmed == "- Key timeline:" { + in_timeline = true; + continue; + } + if in_timeline { + continue; + } + lines.push(trimmed.to_string()); + } + + lines +} + +fn 
extract_summary_timeline(summary: &str) -> Vec<String> { + let mut lines = Vec::new(); + let mut in_timeline = false; + + for line in format_compact_summary(summary).lines() { + let trimmed = line.trim_end(); + if trimmed == "- Key timeline:" { + in_timeline = true; + continue; + } + if !in_timeline { + continue; + } + if trimmed.is_empty() { + break; + } + lines.push(trimmed.to_string()); + } + + lines +} + #[cfg(test)] mod tests { use super::{ collect_key_files, compact_session, estimate_session_tokens, format_compact_summary, - infer_pending_work, should_compact, CompactionConfig, + get_compact_continuation_message, infer_pending_work, should_compact, CompactionConfig, }; use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; @@ -453,6 +578,98 @@ mod tests { ); } + #[test] + fn keeps_previous_compacted_context_when_compacting_again() { + let initial_session = Session { + version: 1, + messages: vec![ + ConversationMessage::user_text("Investigate rust/crates/runtime/src/compact.rs"), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "I will inspect the compact flow.".to_string(), + }]), + ConversationMessage::user_text( + "Also update rust/crates/runtime/src/conversation.rs", + ), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "Next: preserve prior summary context during auto compact.".to_string(), + }]), + ], + }; + let config = CompactionConfig { + preserve_recent_messages: 2, + max_estimated_tokens: 1, + }; + + let first = compact_session(&initial_session, config); + let mut follow_up_messages = first.compacted_session.messages.clone(); + follow_up_messages.extend([ + ConversationMessage::user_text("Please add regression tests for compaction."), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "Working on regression coverage now.".to_string(), + }]), + ]); + + let second = compact_session( + &Session { + version: 1, + messages: follow_up_messages, + }, + config, + ); + + 
assert!(second + .formatted_summary + .contains("Previously compacted context:")); + assert!(second + .formatted_summary + .contains("Scope: 2 earlier messages compacted")); + assert!(second + .formatted_summary + .contains("Newly compacted context:")); + assert!(second + .formatted_summary + .contains("Also update rust/crates/runtime/src/conversation.rs")); + assert!(matches!( + &second.compacted_session.messages[0].blocks[0], + ContentBlock::Text { text } + if text.contains("Previously compacted context:") + && text.contains("Newly compacted context:") + )); + assert!(matches!( + &second.compacted_session.messages[1].blocks[0], + ContentBlock::Text { text } if text.contains("Please add regression tests for compaction.") + )); + } + + #[test] + fn ignores_existing_compacted_summary_when_deciding_to_recompact() { + let summary = "<summary>Conversation summary:\n- Scope: earlier work preserved.\n- Key timeline:\n - user: large preserved context\n</summary>"; + let session = Session { + version: 1, + messages: vec![ + ConversationMessage { + role: MessageRole::System, + blocks: vec![ContentBlock::Text { + text: get_compact_continuation_message(summary, true, true), + }], + usage: None, + }, + ConversationMessage::user_text("tiny"), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "recent".to_string(), + }]), + ], + }; + + assert!(!should_compact( + &session, + CompactionConfig { + preserve_recent_messages: 2, + max_estimated_tokens: 1, + } + )); + } + #[test] fn truncates_long_blocks_in_summary() { let summary = super::summarize_block(&ContentBlock::Text { @@ -465,10 +682,10 @@ mod tests { #[test] fn extracts_key_files_from_message_content() { let files = collect_key_files(&[ConversationMessage::user_text( - "Update rust/crates/runtime/src/compact.rs and rust/crates/rusty-claude-cli/src/main.rs next.", + "Update rust/crates/runtime/src/compact.rs and rust/crates/tools/src/lib.rs next.", )]); 
assert!(files.contains(&"rust/crates/runtime/src/compact.rs".to_string())); - assert!(files.contains(&"rust/crates/rusty-claude-cli/src/main.rs".to_string())); + assert!(files.contains(&"rust/crates/tools/src/lib.rs".to_string())); } #[test] diff --git a/crates/runtime/src/config.rs b/crates/runtime/src/config.rs index 60ef53f..11ec21d 100644 --- a/crates/runtime/src/config.rs +++ b/crates/runtime/src/config.rs @@ -6,7 +6,7 @@ use std::path::{Path, PathBuf}; use crate::json::JsonValue; use crate::sandbox::{FilesystemIsolationMode, SandboxConfig}; -pub const CLAUDE_CODE_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema"; +pub const CLAW_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema"; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum ConfigSource { @@ -35,8 +35,19 @@ pub struct RuntimeConfig { feature_config: RuntimeFeatureConfig, } +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimePluginConfig { + enabled_plugins: BTreeMap<String, bool>, + external_directories: Vec<String>, + install_root: Option<String>, + registry_path: Option<String>, + bundled_root: Option<String>, +} + #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct RuntimeFeatureConfig { + hooks: RuntimeHookConfig, + plugins: RuntimePluginConfig, mcp: McpConfigCollection, oauth: Option<OAuthConfig>, model: Option<String>, @@ -44,6 +55,12 @@ pub struct RuntimeFeatureConfig { sandbox: SandboxConfig, } +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimeHookConfig { + pre_tool_use: Vec<String>, + post_tool_use: Vec<String>, +} + #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct McpConfigCollection { servers: BTreeMap<String, ScopedMcpServerConfig>, @@ -62,7 +79,7 @@ pub enum McpTransport { Http, Ws, Sdk, - ClaudeAiProxy, + ManagedProxy, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -72,7 +89,7 @@ pub enum McpServerConfig { Http(McpRemoteServerConfig), Ws(McpWebSocketServerConfig), Sdk(McpSdkServerConfig), - 
ClaudeAiProxy(McpClaudeAiProxyServerConfig), + ManagedProxy(McpManagedProxyServerConfig), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -103,7 +120,7 @@ pub struct McpSdkServerConfig { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct McpClaudeAiProxyServerConfig { +pub struct McpManagedProxyServerConfig { pub url: String, pub id: String, } @@ -167,25 +184,20 @@ impl ConfigLoader { #[must_use] pub fn default_for(cwd: impl Into<PathBuf>) -> Self { let cwd = cwd.into(); - let config_home = std::env::var_os("CLAUDE_CONFIG_HOME") - .map(PathBuf::from) - .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".claude"))) - .or_else(|| { - if cfg!(target_os = "windows") { - std::env::var_os("USERPROFILE").map(|home| PathBuf::from(home).join(".claude")) - } else { - None - } - }) - .unwrap_or_else(|| PathBuf::from(".claude")); + let config_home = default_config_home(); Self { cwd, config_home } } + #[must_use] + pub fn config_home(&self) -> &Path { + &self.config_home + } + #[must_use] pub fn discover(&self) -> Vec<ConfigEntry> { let user_legacy_path = self.config_home.parent().map_or_else( - || PathBuf::from(".claude.json"), - |parent| parent.join(".claude.json"), + || PathBuf::from(".claw.json"), + |parent| parent.join(".claw.json"), ); vec![ ConfigEntry { @@ -198,15 +210,15 @@ impl ConfigLoader { }, ConfigEntry { source: ConfigSource::Project, - path: self.cwd.join(".claude.json"), + path: self.cwd.join(".claw.json"), }, ConfigEntry { source: ConfigSource::Project, - path: self.cwd.join(".claude").join("settings.json"), + path: self.cwd.join(".claw").join("settings.json"), }, ConfigEntry { source: ConfigSource::Local, - path: self.cwd.join(".claude").join("settings.local.json"), + path: self.cwd.join(".claw").join("settings.local.json"), }, ] } @@ -228,6 +240,8 @@ impl ConfigLoader { let merged_value = JsonValue::Object(merged.clone()); let feature_config = RuntimeFeatureConfig { + hooks: parse_optional_hooks_config(&merged_value)?, + plugins: 
parse_optional_plugin_config(&merged_value)?, mcp: McpConfigCollection { servers: mcp_servers, }, @@ -285,6 +299,16 @@ impl RuntimeConfig { &self.feature_config.mcp } + #[must_use] + pub fn hooks(&self) -> &RuntimeHookConfig { + &self.feature_config.hooks + } + + #[must_use] + pub fn plugins(&self) -> &RuntimePluginConfig { + &self.feature_config.plugins + } + #[must_use] pub fn oauth(&self) -> Option<&OAuthConfig> { self.feature_config.oauth.as_ref() @@ -307,6 +331,28 @@ impl RuntimeConfig { } impl RuntimeFeatureConfig { + #[must_use] + pub fn with_hooks(mut self, hooks: RuntimeHookConfig) -> Self { + self.hooks = hooks; + self + } + + #[must_use] + pub fn with_plugins(mut self, plugins: RuntimePluginConfig) -> Self { + self.plugins = plugins; + self + } + + #[must_use] + pub fn hooks(&self) -> &RuntimeHookConfig { + &self.hooks + } + + #[must_use] + pub fn plugins(&self) -> &RuntimePluginConfig { + &self.plugins + } + #[must_use] pub fn mcp(&self) -> &McpConfigCollection { &self.mcp @@ -333,6 +379,85 @@ impl RuntimeFeatureConfig { } } +impl RuntimePluginConfig { + #[must_use] + pub fn enabled_plugins(&self) -> &BTreeMap<String, bool> { + &self.enabled_plugins + } + + #[must_use] + pub fn external_directories(&self) -> &[String] { + &self.external_directories + } + + #[must_use] + pub fn install_root(&self) -> Option<&str> { + self.install_root.as_deref() + } + + #[must_use] + pub fn registry_path(&self) -> Option<&str> { + self.registry_path.as_deref() + } + + #[must_use] + pub fn bundled_root(&self) -> Option<&str> { + self.bundled_root.as_deref() + } + + pub fn set_plugin_state(&mut self, plugin_id: String, enabled: bool) { + self.enabled_plugins.insert(plugin_id, enabled); + } + + #[must_use] + pub fn state_for(&self, plugin_id: &str, default_enabled: bool) -> bool { + self.enabled_plugins + .get(plugin_id) + .copied() + .unwrap_or(default_enabled) + } +} + +#[must_use] +pub fn default_config_home() -> PathBuf { + std::env::var_os("CLAW_CONFIG_HOME") + 
.map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".claw"))) + .unwrap_or_else(|| PathBuf::from(".claw")) +} + +impl RuntimeHookConfig { + #[must_use] + pub fn new(pre_tool_use: Vec<String>, post_tool_use: Vec<String>) -> Self { + Self { + pre_tool_use, + post_tool_use, + } + } + + #[must_use] + pub fn pre_tool_use(&self) -> &[String] { + &self.pre_tool_use + } + + #[must_use] + pub fn post_tool_use(&self) -> &[String] { + &self.post_tool_use + } + + #[must_use] + pub fn merged(&self, other: &Self) -> Self { + let mut merged = self.clone(); + merged.extend(other); + merged + } + + pub fn extend(&mut self, other: &Self) { + extend_unique(&mut self.pre_tool_use, other.pre_tool_use()); + extend_unique(&mut self.post_tool_use, other.post_tool_use()); + } +} + impl McpConfigCollection { #[must_use] pub fn servers(&self) -> &BTreeMap<String, ScopedMcpServerConfig> { @@ -361,7 +486,7 @@ impl McpServerConfig { Self::Http(_) => McpTransport::Http, Self::Ws(_) => McpTransport::Ws, Self::Sdk(_) => McpTransport::Sdk, - Self::ClaudeAiProxy(_) => McpTransport::ClaudeAiProxy, + Self::ManagedProxy(_) => McpTransport::ManagedProxy, } } } @@ -369,7 +494,7 @@ impl McpServerConfig { fn read_optional_json_object( path: &Path, ) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> { - let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claude.json"); + let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claw.json"); let contents = match fs::read_to_string(path) { Ok(contents) => contents, Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None), @@ -431,6 +556,52 @@ fn parse_optional_model(root: &JsonValue) -> Option<String> { .map(ToOwned::to_owned) } +fn parse_optional_hooks_config(root: &JsonValue) -> Result<RuntimeHookConfig, ConfigError> { + let Some(object) = root.as_object() else { + return Ok(RuntimeHookConfig::default()); + }; + let 
Some(hooks_value) = object.get("hooks") else { + return Ok(RuntimeHookConfig::default()); + }; + let hooks = expect_object(hooks_value, "merged settings.hooks")?; + Ok(RuntimeHookConfig { + pre_tool_use: optional_string_array(hooks, "PreToolUse", "merged settings.hooks")? + .unwrap_or_default(), + post_tool_use: optional_string_array(hooks, "PostToolUse", "merged settings.hooks")? + .unwrap_or_default(), + }) +} + +fn parse_optional_plugin_config(root: &JsonValue) -> Result<RuntimePluginConfig, ConfigError> { + let Some(object) = root.as_object() else { + return Ok(RuntimePluginConfig::default()); + }; + + let mut config = RuntimePluginConfig::default(); + if let Some(enabled_plugins) = object.get("enabledPlugins") { + config.enabled_plugins = parse_bool_map(enabled_plugins, "merged settings.enabledPlugins")?; + } + + let Some(plugins_value) = object.get("plugins") else { + return Ok(config); + }; + let plugins = expect_object(plugins_value, "merged settings.plugins")?; + + if let Some(enabled_value) = plugins.get("enabled") { + config.enabled_plugins = parse_bool_map(enabled_value, "merged settings.plugins.enabled")?; + } + config.external_directories = + optional_string_array(plugins, "externalDirectories", "merged settings.plugins")? 
+ .unwrap_or_default(); + config.install_root = + optional_string(plugins, "installRoot", "merged settings.plugins")?.map(str::to_string); + config.registry_path = + optional_string(plugins, "registryPath", "merged settings.plugins")?.map(str::to_string); + config.bundled_root = + optional_string(plugins, "bundledRoot", "merged settings.plugins")?.map(str::to_string); + Ok(config) +} + fn parse_optional_permission_mode( root: &JsonValue, ) -> Result<Option<ResolvedPermissionMode>, ConfigError> { @@ -553,12 +724,10 @@ fn parse_mcp_server_config( "sdk" => Ok(McpServerConfig::Sdk(McpSdkServerConfig { name: expect_string(object, "name", context)?.to_string(), })), - "claudeai-proxy" => Ok(McpServerConfig::ClaudeAiProxy( - McpClaudeAiProxyServerConfig { - url: expect_string(object, "url", context)?.to_string(), - id: expect_string(object, "id", context)?.to_string(), - }, - )), + "claudeai-proxy" => Ok(McpServerConfig::ManagedProxy(McpManagedProxyServerConfig { + url: expect_string(object, "url", context)?.to_string(), + id: expect_string(object, "id", context)?.to_string(), + })), other => Err(ConfigError::Parse(format!( "{context}: unsupported MCP server type for {server_name}: {other}" ))), @@ -663,6 +832,24 @@ fn optional_u16( } } +fn parse_bool_map(value: &JsonValue, context: &str) -> Result<BTreeMap<String, bool>, ConfigError> { + let Some(map) = value.as_object() else { + return Err(ConfigError::Parse(format!( + "{context}: expected JSON object" + ))); + }; + map.iter() + .map(|(key, value)| { + value + .as_bool() + .map(|enabled| (key.clone(), enabled)) + .ok_or_else(|| { + ConfigError::Parse(format!("{context}: field {key} must be a boolean")) + }) + }) + .collect() +} + fn optional_string_array( object: &BTreeMap<String, JsonValue>, key: &str, @@ -737,11 +924,23 @@ fn deep_merge_objects( } } +fn extend_unique(target: &mut Vec<String>, values: &[String]) { + for value in values { + push_unique(target, value.clone()); + } +} + +fn push_unique(target: &mut 
Vec<String>, value: String) { + if !target.iter().any(|existing| existing == &value) { + target.push(value); + } +} + #[cfg(test)] mod tests { use super::{ ConfigLoader, ConfigSource, McpServerConfig, McpTransport, ResolvedPermissionMode, - CLAUDE_CODE_SETTINGS_SCHEMA_NAME, + CLAW_SETTINGS_SCHEMA_NAME, }; use crate::json::JsonValue; use crate::sandbox::FilesystemIsolationMode; @@ -760,7 +959,7 @@ mod tests { fn rejects_non_object_settings_files() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); + let home = root.join("home").join(".claw"); fs::create_dir_all(&home).expect("home config dir"); fs::create_dir_all(&cwd).expect("project dir"); fs::write(home.join("settings.json"), "[]").expect("write bad settings"); @@ -776,15 +975,15 @@ mod tests { } #[test] - fn loads_and_merges_claude_code_config_files_by_precedence() { + fn loads_and_merges_claw_code_config_files_by_precedence() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); - fs::create_dir_all(cwd.join(".claude")).expect("project config dir"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); fs::create_dir_all(&home).expect("home config dir"); fs::write( - home.parent().expect("home parent").join(".claude.json"), + home.parent().expect("home parent").join(".claw.json"), r#"{"model":"haiku","env":{"A":"1"},"mcpServers":{"home":{"command":"uvx","args":["home"]}}}"#, ) .expect("write user compat config"); @@ -794,17 +993,17 @@ mod tests { ) .expect("write user settings"); fs::write( - cwd.join(".claude.json"), + cwd.join(".claw.json"), r#"{"model":"project-compat","env":{"B":"2"}}"#, ) .expect("write project compat config"); fs::write( - cwd.join(".claude").join("settings.json"), + cwd.join(".claw").join("settings.json"), 
r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#, ) .expect("write project settings"); fs::write( - cwd.join(".claude").join("settings.local.json"), + cwd.join(".claw").join("settings.local.json"), r#"{"model":"opus","permissionMode":"acceptEdits"}"#, ) .expect("write local settings"); @@ -813,7 +1012,7 @@ mod tests { .load() .expect("config should load"); - assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema"); + assert_eq!(CLAW_SETTINGS_SCHEMA_NAME, "SettingsSchema"); assert_eq!(loaded.loaded_entries().len(), 5); assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User); assert_eq!( @@ -843,6 +1042,8 @@ mod tests { .and_then(JsonValue::as_object) .expect("hooks object") .contains_key("PostToolUse")); + assert_eq!(loaded.hooks().pre_tool_use(), &["base".to_string()]); + assert_eq!(loaded.hooks().post_tool_use(), &["project".to_string()]); assert!(loaded.mcp().get("home").is_some()); assert!(loaded.mcp().get("project").is_some()); @@ -853,12 +1054,12 @@ mod tests { fn parses_sandbox_config() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); - fs::create_dir_all(cwd.join(".claude")).expect("project config dir"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); fs::create_dir_all(&home).expect("home config dir"); fs::write( - cwd.join(".claude").join("settings.local.json"), + cwd.join(".claw").join("settings.local.json"), r#"{ "sandbox": { "enabled": true, @@ -891,8 +1092,8 @@ mod tests { fn parses_typed_mcp_and_oauth_config() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); - fs::create_dir_all(cwd.join(".claude")).expect("project config dir"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); fs::create_dir_all(&home).expect("home config 
dir"); fs::write( @@ -929,7 +1130,7 @@ mod tests { ) .expect("write user settings"); fs::write( - cwd.join(".claude").join("settings.local.json"), + cwd.join(".claw").join("settings.local.json"), r#"{ "mcpServers": { "remote-server": { @@ -978,11 +1179,101 @@ mod tests { fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn parses_plugin_config_from_enabled_plugins() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + + fs::write( + home.join("settings.json"), + r#"{ + "enabledPlugins": { + "tool-guard@builtin": true, + "sample-plugin@external": false + } + }"#, + ) + .expect("write user settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + assert_eq!( + loaded.plugins().enabled_plugins().get("tool-guard@builtin"), + Some(&true) + ); + assert_eq!( + loaded + .plugins() + .enabled_plugins() + .get("sample-plugin@external"), + Some(&false) + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn parses_plugin_config() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + + fs::write( + home.join("settings.json"), + r#"{ + "enabledPlugins": { + "core-helpers@builtin": true + }, + "plugins": { + "externalDirectories": ["./external-plugins"], + "installRoot": "plugin-cache/installed", + "registryPath": "plugin-cache/installed.json", + "bundledRoot": "./bundled-plugins" + } + }"#, + ) + .expect("write plugin settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + assert_eq!( + loaded + .plugins() + .enabled_plugins() + .get("core-helpers@builtin"), + Some(&true) + ); + 
assert_eq!( + loaded.plugins().external_directories(), + &["./external-plugins".to_string()] + ); + assert_eq!( + loaded.plugins().install_root(), + Some("plugin-cache/installed") + ); + assert_eq!( + loaded.plugins().registry_path(), + Some("plugin-cache/installed.json") + ); + assert_eq!(loaded.plugins().bundled_root(), Some("./bundled-plugins")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn rejects_invalid_mcp_server_shapes() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); + let home = root.join("home").join(".claw"); fs::create_dir_all(&home).expect("home config dir"); fs::create_dir_all(&cwd).expect("project dir"); fs::write( diff --git a/crates/runtime/src/conversation.rs b/crates/runtime/src/conversation.rs index 625fb25..8411b8d 100644 --- a/crates/runtime/src/conversation.rs +++ b/crates/runtime/src/conversation.rs @@ -4,6 +4,8 @@ use std::fmt::{Display, Formatter}; use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; +use crate::config::RuntimeFeatureConfig; +use crate::hooks::{HookRunResult, HookRunner}; use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter}; use crate::session::{ContentBlock, ConversationMessage, Session}; use crate::usage::{TokenUsage, UsageTracker}; @@ -94,6 +96,7 @@ pub struct ConversationRuntime<C, T> { system_prompt: Vec<String>, max_iterations: usize, usage_tracker: UsageTracker, + hook_runner: HookRunner, } impl<C, T> ConversationRuntime<C, T> @@ -108,6 +111,25 @@ where tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec<String>, + ) -> Self { + Self::new_with_features( + session, + api_client, + tool_executor, + permission_policy, + system_prompt, + RuntimeFeatureConfig::default(), + ) + } + + #[must_use] + pub fn new_with_features( + session: Session, + api_client: C, + tool_executor: T, + permission_policy: PermissionPolicy, + system_prompt: 
Vec<String>, + feature_config: RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -116,8 +138,9 @@ where tool_executor, permission_policy, system_prompt, - max_iterations: 16, + max_iterations: usize::MAX, usage_tracker, + hook_runner: HookRunner::from_feature_config(&feature_config), } } @@ -185,19 +208,41 @@ where let result_message = match permission_outcome { PermissionOutcome::Allow => { - match self.tool_executor.execute(&tool_name, &input) { - Ok(output) => ConversationMessage::tool_result( + let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input); + if pre_hook_result.is_denied() { + let deny_message = format!("PreToolUse hook denied tool `{tool_name}`"); + ConversationMessage::tool_result( + tool_use_id, + tool_name, + format_hook_message(&pre_hook_result, &deny_message), + true, + ) + } else { + let (mut output, mut is_error) = + match self.tool_executor.execute(&tool_name, &input) { + Ok(output) => (output, false), + Err(error) => (error.to_string(), true), + }; + output = merge_hook_feedback(pre_hook_result.messages(), output, false); + + let post_hook_result = self + .hook_runner + .run_post_tool_use(&tool_name, &input, &output, is_error); + if post_hook_result.is_denied() { + is_error = true; + } + output = merge_hook_feedback( + post_hook_result.messages(), + output, + post_hook_result.is_denied(), + ); + + ConversationMessage::tool_result( tool_use_id, tool_name, output, - false, - ), - Err(error) => ConversationMessage::tool_result( - tool_use_id, - tool_name, - error.to_string(), - true, - ), + is_error, + ) } } PermissionOutcome::Deny { reason } => { @@ -290,6 +335,32 @@ fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) { } } +fn format_hook_message(result: &HookRunResult, fallback: &str) -> String { + if result.messages().is_empty() { + fallback.to_string() + } else { + result.messages().join("\n") + } +} + +fn merge_hook_feedback(messages: &[String], 
output: String, denied: bool) -> String { + if messages.is_empty() { + return output; + } + + let mut sections = Vec::new(); + if !output.trim().is_empty() { + sections.push(output); + } + let label = if denied { + "Hook feedback (denied)" + } else { + "Hook feedback" + }; + sections.push(format!("{label}:\n{}", messages.join("\n"))); + sections.join("\n\n") +} + type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>; #[derive(Default)] @@ -329,6 +400,7 @@ mod tests { StaticToolExecutor, }; use crate::compact::CompactionConfig; + use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::permissions::{ PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter, PermissionRequest, @@ -503,6 +575,141 @@ mod tests { )); } + #[test] + fn denies_tool_use_when_pre_tool_hook_blocks() { + struct SingleCallApiClient; + impl ApiClient for SingleCallApiClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + if request + .messages + .iter() + .any(|message| message.role == MessageRole::Tool) + { + return Ok(vec![ + AssistantEvent::TextDelta("blocked".to_string()), + AssistantEvent::MessageStop, + ]); + } + Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "blocked".to_string(), + input: r#"{"path":"secret.txt"}"#.to_string(), + }, + AssistantEvent::MessageStop, + ]) + } + } + + let mut runtime = ConversationRuntime::new_with_features( + Session::new(), + SingleCallApiClient, + StaticToolExecutor::new().register("blocked", |_input| { + panic!("tool should not execute when hook denies") + }), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + vec![shell_snippet("printf 'blocked by hook'; exit 2")], + Vec::new(), + )), + ); + + let summary = runtime + .run_turn("use the tool", None) + .expect("conversation should continue after hook denial"); + + 
assert_eq!(summary.tool_results.len(), 1); + let ContentBlock::ToolResult { + is_error, output, .. + } = &summary.tool_results[0].blocks[0] + else { + panic!("expected tool result block"); + }; + assert!( + *is_error, + "hook denial should produce an error result: {output}" + ); + assert!( + output.contains("denied tool") || output.contains("blocked by hook"), + "unexpected hook denial output: {output:?}" + ); + } + + #[test] + fn appends_post_tool_hook_feedback_to_tool_result() { + struct TwoCallApiClient { + calls: usize, + } + + impl ApiClient for TwoCallApiClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + self.calls += 1; + match self.calls { + 1 => Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "add".to_string(), + input: r#"{"lhs":2,"rhs":2}"#.to_string(), + }, + AssistantEvent::MessageStop, + ]), + 2 => { + assert!(request + .messages + .iter() + .any(|message| message.role == MessageRole::Tool)); + Ok(vec![ + AssistantEvent::TextDelta("done".to_string()), + AssistantEvent::MessageStop, + ]) + } + _ => Err(RuntimeError::new("unexpected extra API call")), + } + } + } + + let mut runtime = ConversationRuntime::new_with_features( + Session::new(), + TwoCallApiClient { calls: 0 }, + StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + vec![shell_snippet("printf 'pre hook ran'")], + vec![shell_snippet("printf 'post hook ran'")], + )), + ); + + let summary = runtime + .run_turn("use add", None) + .expect("tool loop succeeds"); + + assert_eq!(summary.tool_results.len(), 1); + let ContentBlock::ToolResult { + is_error, output, .. 
+ } = &summary.tool_results[0].blocks[0] + else { + panic!("expected tool result block"); + }; + assert!( + !*is_error, + "post hook should preserve non-error result: {output:?}" + ); + assert!( + output.contains('4'), + "tool output missing value: {output:?}" + ); + assert!( + output.contains("pre hook ran"), + "tool output missing pre hook feedback: {output:?}" + ); + assert!( + output.contains("post hook ran"), + "tool output missing post hook feedback: {output:?}" + ); + } + #[test] fn reconstructs_usage_tracker_from_restored_session() { struct SimpleApi; @@ -581,4 +788,14 @@ mod tests { MessageRole::System ); } + + #[cfg(windows)] + fn shell_snippet(script: &str) -> String { + script.replace('\'', "\"") + } + + #[cfg(not(windows))] + fn shell_snippet(script: &str) -> String { + script.to_string() + } } diff --git a/crates/runtime/src/file_ops.rs b/crates/runtime/src/file_ops.rs index a647b85..1faf9ab 100644 --- a/crates/runtime/src/file_ops.rs +++ b/crates/runtime/src/file_ops.rs @@ -488,7 +488,7 @@ mod tests { .duration_since(UNIX_EPOCH) .expect("time should move forward") .as_nanos(); - std::env::temp_dir().join(format!("clawd-native-{name}-{unique}")) + std::env::temp_dir().join(format!("claw-native-{name}-{unique}")) } #[test] diff --git a/crates/runtime/src/hooks.rs b/crates/runtime/src/hooks.rs new file mode 100644 index 0000000..63ef9ff --- /dev/null +++ b/crates/runtime/src/hooks.rs @@ -0,0 +1,357 @@ +use std::ffi::OsStr; +use std::process::Command; + +use serde_json::json; + +use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HookEvent { + PreToolUse, + PostToolUse, +} + +impl HookEvent { + fn as_str(self) -> &'static str { + match self { + Self::PreToolUse => "PreToolUse", + Self::PostToolUse => "PostToolUse", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HookRunResult { + denied: bool, + messages: Vec<String>, +} + +impl HookRunResult { + #[must_use] + pub fn 
allow(messages: Vec<String>) -> Self { + Self { + denied: false, + messages, + } + } + + #[must_use] + pub fn is_denied(&self) -> bool { + self.denied + } + + #[must_use] + pub fn messages(&self) -> &[String] { + &self.messages + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct HookRunner { + config: RuntimeHookConfig, +} + +#[derive(Debug, Clone, Copy)] +struct HookCommandRequest<'a> { + event: HookEvent, + tool_name: &'a str, + tool_input: &'a str, + tool_output: Option<&'a str>, + is_error: bool, + payload: &'a str, +} + +impl HookRunner { + #[must_use] + pub fn new(config: RuntimeHookConfig) -> Self { + Self { config } + } + + #[must_use] + pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self { + Self::new(feature_config.hooks().clone()) + } + + #[must_use] + pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { + self.run_commands( + HookEvent::PreToolUse, + self.config.pre_tool_use(), + tool_name, + tool_input, + None, + false, + ) + } + + #[must_use] + pub fn run_post_tool_use( + &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + ) -> HookRunResult { + self.run_commands( + HookEvent::PostToolUse, + self.config.post_tool_use(), + tool_name, + tool_input, + Some(tool_output), + is_error, + ) + } + + fn run_commands( + &self, + event: HookEvent, + commands: &[String], + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, + ) -> HookRunResult { + if commands.is_empty() { + return HookRunResult::allow(Vec::new()); + } + + let payload = json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }) + .to_string(); + + let mut messages = Vec::new(); + + for command in commands { + match Self::run_command( + command, + HookCommandRequest { + event, + tool_name, + tool_input, 
+ tool_output, + is_error, + payload: &payload, + }, + ) { + HookCommandOutcome::Allow { message } => { + if let Some(message) = message { + messages.push(message); + } + } + HookCommandOutcome::Deny { message } => { + let message = message.unwrap_or_else(|| { + format!("{} hook denied tool `{tool_name}`", event.as_str()) + }); + messages.push(message); + return HookRunResult { + denied: true, + messages, + }; + } + HookCommandOutcome::Warn { message } => messages.push(message), + } + } + + HookRunResult::allow(messages) + } + + fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome { + let mut child = shell_command(command); + child.stdin(std::process::Stdio::piped()); + child.stdout(std::process::Stdio::piped()); + child.stderr(std::process::Stdio::piped()); + child.env("HOOK_EVENT", request.event.as_str()); + child.env("HOOK_TOOL_NAME", request.tool_name); + child.env("HOOK_TOOL_INPUT", request.tool_input); + child.env( + "HOOK_TOOL_IS_ERROR", + if request.is_error { "1" } else { "0" }, + ); + if let Some(tool_output) = request.tool_output { + child.env("HOOK_TOOL_OUTPUT", tool_output); + } + + match child.output_with_stdin(request.payload.as_bytes()) { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let message = (!stdout.is_empty()).then_some(stdout); + match output.status.code() { + Some(0) => HookCommandOutcome::Allow { message }, + Some(2) => HookCommandOutcome::Deny { message }, + Some(code) => HookCommandOutcome::Warn { + message: format_hook_warning( + command, + code, + message.as_deref(), + stderr.as_str(), + ), + }, + None => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` terminated by signal while handling `{}`", + request.event.as_str(), + request.tool_name + ), + }, + } + } + Err(error) => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` failed to start 
for `{}`: {error}", + request.event.as_str(), + request.tool_name + ), + }, + } + } +} + +enum HookCommandOutcome { + Allow { message: Option<String> }, + Deny { message: Option<String> }, + Warn { message: String }, +} + +fn parse_tool_input(tool_input: &str) -> serde_json::Value { + serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) +} + +fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { + let mut message = + format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); + if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { + message.push_str(": "); + message.push_str(stdout); + } else if !stderr.is_empty() { + message.push_str(": "); + message.push_str(stderr); + } + message +} + +fn shell_command(command: &str) -> CommandWithStdin { + #[cfg(windows)] + let mut command_builder = { + let mut command_builder = Command::new("cmd"); + command_builder.arg("/C").arg(command); + CommandWithStdin::new(command_builder) + }; + + #[cfg(not(windows))] + let command_builder = { + let mut command_builder = Command::new("sh"); + command_builder.arg("-lc").arg(command); + CommandWithStdin::new(command_builder) + }; + + command_builder +} + +struct CommandWithStdin { + command: Command, +} + +impl CommandWithStdin { + fn new(command: Command) -> Self { + Self { command } + } + + fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdin(cfg); + self + } + + fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdout(cfg); + self + } + + fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stderr(cfg); + self + } + + fn env<K, V>(&mut self, key: K, value: V) -> &mut Self + where + K: AsRef<OsStr>, + V: AsRef<OsStr>, + { + self.command.env(key, value); + self + } + + fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> { + let mut child = 
self.command.spawn()?; + if let Some(mut child_stdin) = child.stdin.take() { + use std::io::Write; + child_stdin.write_all(stdin)?; + } + child.wait_with_output() + } +} + +#[cfg(test)] +mod tests { + use super::{HookRunResult, HookRunner}; + use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; + + #[test] + fn allows_exit_code_zero_and_captures_stdout() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet("printf 'pre ok'")], + Vec::new(), + )); + + let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#); + + assert_eq!(result, HookRunResult::allow(vec!["pre ok".to_string()])); + } + + #[test] + fn denies_exit_code_two() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet("printf 'blocked by hook'; exit 2")], + Vec::new(), + )); + + let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + + assert!(result.is_denied()); + assert_eq!(result.messages(), &["blocked by hook".to_string()]); + } + + #[test] + fn warns_for_other_non_zero_statuses() { + let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks( + RuntimeHookConfig::new( + vec![shell_snippet("printf 'warning hook'; exit 1")], + Vec::new(), + ), + )); + + let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#); + + assert!(!result.is_denied()); + assert!(result + .messages() + .iter() + .any(|message| message.contains("allowing tool execution to continue"))); + } + + #[cfg(windows)] + fn shell_snippet(script: &str) -> String { + script.replace('\'', "\"") + } + + #[cfg(not(windows))] + fn shell_snippet(script: &str) -> String { + script.to_string() + } +} diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 2861d47..c714f95 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -4,6 +4,7 @@ mod compact; mod config; mod conversation; mod file_ops; +mod hooks; mod json; mod mcp; mod mcp_client; @@ -16,6 +17,10 @@ pub mod sandbox; mod 
session; mod usage; +pub use lsp::{ + FileDiagnostics, LspContextEnrichment, LspError, LspManager, LspServerConfig, + SymbolLocation, WorkspaceDiagnostics, +}; pub use bash::{execute_bash, BashCommandInput, BashCommandOutput}; pub use bootstrap::{BootstrapPhase, BootstrapPlan}; pub use compact::{ @@ -23,11 +28,11 @@ pub use compact::{ get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult, }; pub use config::{ - ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig, + ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpManagedProxyServerConfig, McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, - ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, - CLAUDE_CODE_SETTINGS_SCHEMA_NAME, + ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig, + RuntimePluginConfig, ScopedMcpServerConfig, CLAW_SETTINGS_SCHEMA_NAME, }; pub use conversation::{ ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, @@ -38,12 +43,13 @@ pub use file_ops::{ GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload, WriteFileOutput, }; +pub use hooks::{HookEvent, HookRunResult, HookRunner}; pub use mcp::{ mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp, scoped_mcp_config_hash, unwrap_ccr_proxy_url, }; pub use mcp_client::{ - McpClaudeAiProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport, + McpManagedProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport, McpRemoteTransport, McpSdkTransport, McpStdioTransport, }; pub use mcp_stdio::{ diff --git a/crates/runtime/src/mcp.rs b/crates/runtime/src/mcp.rs index 103fbe4..b37ea33 100644 --- a/crates/runtime/src/mcp.rs +++ b/crates/runtime/src/mcp.rs @@ -73,7 +73,7 @@ pub fn 
mcp_server_signature(config: &McpServerConfig) -> Option<String> { Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))) } McpServerConfig::Ws(config) => Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))), - McpServerConfig::ClaudeAiProxy(config) => { + McpServerConfig::ManagedProxy(config) => { Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))) } McpServerConfig::Sdk(_) => None, @@ -110,7 +110,7 @@ pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String { ws.headers_helper.as_deref().unwrap_or("") ), McpServerConfig::Sdk(sdk) => format!("sdk|{}", sdk.name), - McpServerConfig::ClaudeAiProxy(proxy) => { + McpServerConfig::ManagedProxy(proxy) => { format!("claudeai-proxy|{}|{}", proxy.url, proxy.id) } }; diff --git a/crates/runtime/src/mcp_client.rs b/crates/runtime/src/mcp_client.rs index 23ccb95..e0e1f2c 100644 --- a/crates/runtime/src/mcp_client.rs +++ b/crates/runtime/src/mcp_client.rs @@ -10,7 +10,7 @@ pub enum McpClientTransport { Http(McpRemoteTransport), WebSocket(McpRemoteTransport), Sdk(McpSdkTransport), - ClaudeAiProxy(McpClaudeAiProxyTransport), + ManagedProxy(McpManagedProxyTransport), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -34,7 +34,7 @@ pub struct McpSdkTransport { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct McpClaudeAiProxyTransport { +pub struct McpManagedProxyTransport { pub url: String, pub id: String, } @@ -97,12 +97,10 @@ impl McpClientTransport { McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport { name: config.name.clone(), }), - McpServerConfig::ClaudeAiProxy(config) => { - Self::ClaudeAiProxy(McpClaudeAiProxyTransport { - url: config.url.clone(), - id: config.id.clone(), - }) - } + McpServerConfig::ManagedProxy(config) => Self::ManagedProxy(McpManagedProxyTransport { + url: config.url.clone(), + id: config.id.clone(), + }), } } } diff --git a/crates/runtime/src/mcp_stdio.rs b/crates/runtime/src/mcp_stdio.rs index 7e67d5d..27402d6 100644 --- a/crates/runtime/src/mcp_stdio.rs +++ 
b/crates/runtime/src/mcp_stdio.rs @@ -809,6 +809,7 @@ mod tests { use std::io::ErrorKind; use std::os::unix::fs::PermissionsExt; use std::path::{Path, PathBuf}; + use std::process::Command; use std::time::{SystemTime, UNIX_EPOCH}; use serde_json::json; @@ -1137,15 +1138,37 @@ mod tests { fn script_transport(script_path: &Path) -> crate::mcp_client::McpStdioTransport { crate::mcp_client::McpStdioTransport { - command: "python3".to_string(), + command: python_command(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::new(), } } + fn python_command() -> String { + for key in ["MCP_TEST_PYTHON", "PYTHON3", "PYTHON"] { + if let Ok(value) = std::env::var(key) { + if !value.trim().is_empty() { + return value; + } + } + } + + for candidate in ["python3", "python"] { + if Command::new(candidate).arg("--version").output().is_ok() { + return candidate.to_string(); + } + } + + panic!("expected a Python interpreter for MCP stdio tests") + } + fn cleanup_script(script_path: &Path) { - fs::remove_file(script_path).expect("cleanup script"); - fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir"); + if let Err(error) = fs::remove_file(script_path) { + assert_eq!(error.kind(), std::io::ErrorKind::NotFound, "cleanup script"); + } + if let Err(error) = fs::remove_dir_all(script_path.parent().expect("script parent")) { + assert_eq!(error.kind(), std::io::ErrorKind::NotFound, "cleanup dir"); + } } fn manager_server_config( @@ -1156,7 +1179,7 @@ mod tests { ScopedMcpServerConfig { scope: ConfigSource::Local, config: McpServerConfig::Stdio(McpStdioServerConfig { - command: "python3".to_string(), + command: python_command(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::from([ ("MCP_SERVER_LABEL".to_string(), label.to_string()), diff --git a/crates/runtime/src/oauth.rs b/crates/runtime/src/oauth.rs index 837bdf2..e4756c1 100644 --- a/crates/runtime/src/oauth.rs +++ b/crates/runtime/src/oauth.rs @@ -324,15 
+324,15 @@ fn generate_random_token(bytes: usize) -> io::Result<String> { } fn credentials_home_dir() -> io::Result<PathBuf> { - if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") { + if let Some(path) = std::env::var_os("CLAW_CONFIG_HOME") { return Ok(PathBuf::from(path)); } if let Some(path) = std::env::var_os("HOME") { - return Ok(PathBuf::from(path).join(".claude")); + return Ok(PathBuf::from(path).join(".claw")); } if cfg!(target_os = "windows") { if let Some(path) = std::env::var_os("USERPROFILE") { - return Ok(PathBuf::from(path).join(".claude")); + return Ok(PathBuf::from(path).join(".claw")); } } Err(io::Error::new(io::ErrorKind::NotFound, "HOME or USERPROFILE is not set")) @@ -547,7 +547,7 @@ mod tests { fn oauth_credentials_round_trip_and_clear_preserves_other_fields() { let _guard = env_lock(); let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); let path = credentials_path().expect("credentials path"); std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials"); @@ -573,7 +573,7 @@ mod tests { assert!(cleared.contains("\"other\": \"value\"")); assert!(!cleared.contains("\"oauth\"")); - std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::env::remove_var("CLAW_CONFIG_HOME"); std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); } diff --git a/crates/runtime/src/prompt.rs b/crates/runtime/src/prompt.rs index 7192412..d3b09e3 100644 --- a/crates/runtime/src/prompt.rs +++ b/crates/runtime/src/prompt.rs @@ -4,6 +4,7 @@ use std::path::{Path, PathBuf}; use std::process::Command; use crate::config::{ConfigError, ConfigLoader, RuntimeConfig}; +use lsp::LspContextEnrichment; #[derive(Debug)] pub enum PromptBuildError { @@ -35,7 +36,7 @@ impl From<ConfigError> for PromptBuildError { } pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = 
"__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__"; -pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6"; +pub const FRONTIER_MODEL_NAME: &str = "Opus 4.6"; const MAX_INSTRUCTION_FILE_CHARS: usize = 4_000; const MAX_TOTAL_INSTRUCTION_CHARS: usize = 12_000; @@ -130,6 +131,15 @@ impl SystemPromptBuilder { self } + #[must_use] + pub fn with_lsp_context(mut self, enrichment: &LspContextEnrichment) -> Self { + if !enrichment.is_empty() { + self.append_sections + .push(enrichment.render_prompt_section()); + } + self + } + #[must_use] pub fn build(&self) -> Vec<String> { let mut sections = Vec::new(); @@ -201,10 +211,10 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> { let mut files = Vec::new(); for dir in directories { for candidate in [ - dir.join("CLAUDE.md"), - dir.join("CLAUDE.local.md"), - dir.join(".claude").join("CLAUDE.md"), - dir.join(".claude").join("instructions.md"), + dir.join("CLAW.md"), + dir.join("CLAW.local.md"), + dir.join(".claw").join("CLAW.md"), + dir.join(".claw").join("instructions.md"), ] { push_context_file(&mut files, candidate)?; } @@ -282,7 +292,7 @@ fn render_project_context(project_context: &ProjectContext) -> String { ]; if !project_context.instruction_files.is_empty() { bullets.push(format!( - "Claude instruction files discovered: {}.", + "Claw instruction files discovered: {}.", project_context.instruction_files.len() )); } @@ -301,7 +311,7 @@ fn render_project_context(project_context: &ProjectContext) -> String { } fn render_instruction_files(files: &[ContextFile]) -> String { - let mut sections = vec!["# Claude instructions".to_string()]; + let mut sections = vec!["# Claw instructions".to_string()]; let mut remaining_chars = MAX_TOTAL_INSTRUCTION_CHARS; for file in files { if remaining_chars == 0 { @@ -421,7 +431,7 @@ fn render_config_section(config: &RuntimeConfig) -> String { let mut lines = vec!["# Runtime config".to_string()]; if config.loaded_entries().is_empty() { lines.extend(prepend_bullets(vec![ - "No 
Claude Code settings files loaded.".to_string(), + "No Claw Code settings files loaded.".to_string() ])); return lines.join("\n"); } @@ -517,23 +527,23 @@ mod tests { fn discovers_instruction_files_from_ancestor_chain() { let root = temp_dir(); let nested = root.join("apps").join("api"); - fs::create_dir_all(nested.join(".claude")).expect("nested claude dir"); - fs::write(root.join("CLAUDE.md"), "root instructions").expect("write root instructions"); - fs::write(root.join("CLAUDE.local.md"), "local instructions") + fs::create_dir_all(nested.join(".claw")).expect("nested claw dir"); + fs::write(root.join("CLAW.md"), "root instructions").expect("write root instructions"); + fs::write(root.join("CLAW.local.md"), "local instructions") .expect("write local instructions"); fs::create_dir_all(root.join("apps")).expect("apps dir"); - fs::create_dir_all(root.join("apps").join(".claude")).expect("apps claude dir"); - fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions") + fs::create_dir_all(root.join("apps").join(".claw")).expect("apps claw dir"); + fs::write(root.join("apps").join("CLAW.md"), "apps instructions") .expect("write apps instructions"); fs::write( - root.join("apps").join(".claude").join("instructions.md"), - "apps dot claude instructions", + root.join("apps").join(".claw").join("instructions.md"), + "apps dot claw instructions", ) - .expect("write apps dot claude instructions"); - fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules") + .expect("write apps dot claw instructions"); + fs::write(nested.join(".claw").join("CLAW.md"), "nested rules") .expect("write nested rules"); fs::write( - nested.join(".claude").join("instructions.md"), + nested.join(".claw").join("instructions.md"), "nested instructions", ) .expect("write nested instructions"); @@ -551,7 +561,7 @@ mod tests { "root instructions", "local instructions", "apps instructions", - "apps dot claude instructions", + "apps dot claw instructions", "nested rules", "nested 
instructions" ] @@ -564,8 +574,8 @@ mod tests { let root = temp_dir(); let nested = root.join("apps").join("api"); fs::create_dir_all(&nested).expect("nested dir"); - fs::write(root.join("CLAUDE.md"), "same rules\n\n").expect("write root"); - fs::write(nested.join("CLAUDE.md"), "same rules\n").expect("write nested"); + fs::write(root.join("CLAW.md"), "same rules\n\n").expect("write root"); + fs::write(nested.join("CLAW.md"), "same rules\n").expect("write nested"); let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load"); assert_eq!(context.instruction_files.len(), 1); @@ -593,13 +603,14 @@ mod tests { #[test] fn displays_context_paths_compactly() { assert_eq!( - display_context_path(Path::new("/tmp/project/.claude/CLAUDE.md")), - "CLAUDE.md" + display_context_path(Path::new("/tmp/project/.claw/CLAW.md")), + "CLAW.md" ); } #[test] fn discover_with_git_includes_status_snapshot() { + let _guard = env_lock(); let root = temp_dir(); fs::create_dir_all(&root).expect("root dir"); std::process::Command::new("git") @@ -607,7 +618,7 @@ mod tests { .current_dir(&root) .status() .expect("git init should run"); - fs::write(root.join("CLAUDE.md"), "rules").expect("write instructions"); + fs::write(root.join("CLAW.md"), "rules").expect("write instructions"); fs::write(root.join("tracked.txt"), "hello").expect("write tracked file"); let context = @@ -615,7 +626,7 @@ mod tests { let status = context.git_status.expect("git status should be present"); assert!(status.contains("## No commits yet on") || status.contains("## ")); - assert!(status.contains("?? CLAUDE.md")); + assert!(status.contains("?? CLAW.md")); assert!(status.contains("?? 
tracked.txt")); assert!(context.git_diff.is_none()); @@ -624,6 +635,7 @@ mod tests { #[test] fn discover_with_git_includes_diff_snapshot_for_tracked_changes() { + let _guard = env_lock(); let root = temp_dir(); fs::create_dir_all(&root).expect("root dir"); std::process::Command::new("git") @@ -665,12 +677,12 @@ mod tests { } #[test] - fn load_system_prompt_reads_claude_files_and_config() { + fn load_system_prompt_reads_claw_files_and_config() { let root = temp_dir(); - fs::create_dir_all(root.join(".claude")).expect("claude dir"); - fs::write(root.join("CLAUDE.md"), "Project rules").expect("write instructions"); + fs::create_dir_all(root.join(".claw")).expect("claw dir"); + fs::write(root.join("CLAW.md"), "Project rules").expect("write instructions"); fs::write( - root.join(".claude").join("settings.json"), + root.join(".claw").join("settings.json"), r#"{"permissionMode":"acceptEdits"}"#, ) .expect("write settings"); @@ -678,9 +690,9 @@ mod tests { let _guard = env_lock(); let previous = std::env::current_dir().expect("cwd"); let original_home = std::env::var("HOME").ok(); - let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok(); + let original_claw_home = std::env::var("CLAW_CONFIG_HOME").ok(); std::env::set_var("HOME", &root); - std::env::set_var("CLAUDE_CONFIG_HOME", root.join("missing-home")); + std::env::set_var("CLAW_CONFIG_HOME", root.join("missing-home")); std::env::set_current_dir(&root).expect("change cwd"); let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8") .expect("system prompt should load") @@ -695,10 +707,10 @@ mod tests { } else { std::env::remove_var("HOME"); } - if let Some(value) = original_claude_home { - std::env::set_var("CLAUDE_CONFIG_HOME", value); + if let Some(value) = original_claw_home { + std::env::set_var("CLAW_CONFIG_HOME", value); } else { - std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::env::remove_var("CLAW_CONFIG_HOME"); } assert!(prompt.contains("Project rules")); @@ -707,12 +719,12 @@ 
mod tests { } #[test] - fn renders_claude_code_style_sections_with_project_context() { + fn renders_claw_code_style_sections_with_project_context() { let root = temp_dir(); - fs::create_dir_all(root.join(".claude")).expect("claude dir"); - fs::write(root.join("CLAUDE.md"), "Project rules").expect("write CLAUDE.md"); + fs::create_dir_all(root.join(".claw")).expect("claw dir"); + fs::write(root.join("CLAW.md"), "Project rules").expect("write CLAW.md"); fs::write( - root.join(".claude").join("settings.json"), + root.join(".claw").join("settings.json"), r#"{"permissionMode":"acceptEdits"}"#, ) .expect("write settings"); @@ -731,7 +743,7 @@ mod tests { assert!(prompt.contains("# System")); assert!(prompt.contains("# Project context")); - assert!(prompt.contains("# Claude instructions")); + assert!(prompt.contains("# Claw instructions")); assert!(prompt.contains("Project rules")); assert!(prompt.contains("permissionMode")); assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY)); @@ -748,12 +760,12 @@ mod tests { } #[test] - fn discovers_dot_claude_instructions_markdown() { + fn discovers_dot_claw_instructions_markdown() { let root = temp_dir(); let nested = root.join("apps").join("api"); - fs::create_dir_all(nested.join(".claude")).expect("nested claude dir"); + fs::create_dir_all(nested.join(".claw")).expect("nested claw dir"); fs::write( - nested.join(".claude").join("instructions.md"), + nested.join(".claw").join("instructions.md"), "instruction markdown", ) .expect("write instructions.md"); @@ -762,7 +774,7 @@ mod tests { assert!(context .instruction_files .iter() - .any(|file| file.path.ends_with(".claude/instructions.md"))); + .any(|file| file.path.ends_with(".claw/instructions.md"))); assert!( render_instruction_files(&context.instruction_files).contains("instruction markdown") ); @@ -773,10 +785,10 @@ mod tests { #[test] fn renders_instruction_file_metadata() { let rendered = render_instruction_files(&[ContextFile { - path: 
PathBuf::from("/tmp/project/CLAUDE.md"), + path: PathBuf::from("/tmp/project/CLAW.md"), content: "Project rules".to_string(), }]); - assert!(rendered.contains("# Claude instructions")); + assert!(rendered.contains("# Claw instructions")); assert!(rendered.contains("scope: /tmp/project")); assert!(rendered.contains("Project rules")); } diff --git a/crates/runtime/src/remote.rs b/crates/runtime/src/remote.rs index 24ee780..5fe59a0 100644 --- a/crates/runtime/src/remote.rs +++ b/crates/runtime/src/remote.rs @@ -72,9 +72,9 @@ impl RemoteSessionContext { #[must_use] pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self { Self { - enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")), + enabled: env_truthy(env_map.get("CLAW_CODE_REMOTE")), session_id: env_map - .get("CLAUDE_CODE_REMOTE_SESSION_ID") + .get("CLAW_CODE_REMOTE_SESSION_ID") .filter(|value| !value.is_empty()) .cloned(), base_url: env_map @@ -272,9 +272,9 @@ mod tests { #[test] fn remote_context_reads_env_state() { let env = BTreeMap::from([ - ("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()), + ("CLAW_CODE_REMOTE".to_string(), "true".to_string()), ( - "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), + "CLAW_CODE_REMOTE_SESSION_ID".to_string(), "session-123".to_string(), ), ( @@ -291,7 +291,7 @@ mod tests { #[test] fn bootstrap_fails_open_when_token_or_session_is_missing() { let env = BTreeMap::from([ - ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), + ("CLAW_CODE_REMOTE".to_string(), "1".to_string()), ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), ]); let bootstrap = UpstreamProxyBootstrap::from_env_map(&env); @@ -307,10 +307,10 @@ mod tests { fs::write(&token_path, "secret-token\n").expect("write token"); let env = BTreeMap::from([ - ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), + ("CLAW_CODE_REMOTE".to_string(), "1".to_string()), ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), ( - "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), + 
"CLAW_CODE_REMOTE_SESSION_ID".to_string(), "session-123".to_string(), ), ( diff --git a/crates/runtime/src/session.rs b/crates/runtime/src/session.rs index beaa435..ec37070 100644 --- a/crates/runtime/src/session.rs +++ b/crates/runtime/src/session.rs @@ -3,10 +3,13 @@ use std::fmt::{Display, Formatter}; use std::fs; use std::path::Path; +use serde::{Deserialize, Serialize}; + use crate::json::{JsonError, JsonValue}; use crate::usage::TokenUsage; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] pub enum MessageRole { System, User, @@ -14,7 +17,8 @@ pub enum MessageRole { Tool, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] pub enum ContentBlock { Text { text: String, @@ -32,14 +36,14 @@ pub enum ContentBlock { }, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct ConversationMessage { pub role: MessageRole, pub blocks: Vec<ContentBlock>, pub usage: Option<TokenUsage>, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Session { pub version: u32, pub messages: Vec<ConversationMessage>, diff --git a/crates/runtime/src/usage.rs b/crates/runtime/src/usage.rs index 04e28df..0570bc1 100644 --- a/crates/runtime/src/usage.rs +++ b/crates/runtime/src/usage.rs @@ -1,4 +1,5 @@ use crate::session::Session; +use serde::{Deserialize, Serialize}; const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0; const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0; @@ -25,7 +26,7 @@ impl ModelPricing { } } -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)] pub struct TokenUsage { pub input_tokens: u32, pub output_tokens: u32, @@ -249,9 +250,9 @@ mod tests { let cost = 
usage.estimate_cost_usd(); assert_eq!(format_usd(cost.input_cost_usd), "$15.0000"); assert_eq!(format_usd(cost.output_cost_usd), "$37.5000"); - let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514")); + let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6")); assert!(lines[0].contains("estimated_cost=$54.6750")); - assert!(lines[0].contains("model=claude-sonnet-4-20250514")); + assert!(lines[0].contains("model=claude-sonnet-4-6")); assert!(lines[1].contains("cache_read=$0.3000")); } @@ -264,7 +265,7 @@ mod tests { cache_read_input_tokens: 0, }; - let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing"); + let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing"); let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing"); let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku); let opus_cost = usage.estimate_cost_usd_with_pricing(opus); diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml new file mode 100644 index 0000000..9151aef --- /dev/null +++ b/crates/server/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "server" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +async-stream = "0.3" +axum = "0.8" +runtime = { path = "../runtime" } +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "net", "time"] } + +[dev-dependencies] +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] } + +[lints] +workspace = true diff --git a/crates/server/src/lib.rs b/crates/server/src/lib.rs new file mode 100644 index 0000000..b3386ea --- /dev/null +++ b/crates/server/src/lib.rs @@ -0,0 +1,442 @@ +use std::collections::HashMap; +use std::convert::Infallible; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use 
std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use async_stream::stream; +use axum::extract::{Path, State}; +use axum::http::StatusCode; +use axum::response::sse::{Event, KeepAlive, Sse}; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use runtime::{ConversationMessage, Session as RuntimeSession}; +use serde::{Deserialize, Serialize}; +use tokio::sync::{broadcast, RwLock}; + +pub type SessionId = String; +pub type SessionStore = Arc<RwLock<HashMap<SessionId, Session>>>; + +const BROADCAST_CAPACITY: usize = 64; + +#[derive(Clone)] +pub struct AppState { + sessions: SessionStore, + next_session_id: Arc<AtomicU64>, +} + +impl AppState { + #[must_use] + pub fn new() -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + next_session_id: Arc::new(AtomicU64::new(1)), + } + } + + fn allocate_session_id(&self) -> SessionId { + let id = self.next_session_id.fetch_add(1, Ordering::Relaxed); + format!("session-{id}") + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone)] +pub struct Session { + pub id: SessionId, + pub created_at: u64, + pub conversation: RuntimeSession, + events: broadcast::Sender<SessionEvent>, +} + +impl Session { + fn new(id: SessionId) -> Self { + let (events, _) = broadcast::channel(BROADCAST_CAPACITY); + Self { + id, + created_at: unix_timestamp_millis(), + conversation: RuntimeSession::new(), + events, + } + } + + fn subscribe(&self) -> broadcast::Receiver<SessionEvent> { + self.events.subscribe() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +enum SessionEvent { + Snapshot { + session_id: SessionId, + session: RuntimeSession, + }, + Message { + session_id: SessionId, + message: ConversationMessage, + }, +} + +impl SessionEvent { + fn event_name(&self) -> &'static str { + match self { + Self::Snapshot { .. } => "snapshot", + Self::Message { .. 
} => "message", + } + } + + fn to_sse_event(&self) -> Result<Event, serde_json::Error> { + Ok(Event::default() + .event(self.event_name()) + .data(serde_json::to_string(self)?)) + } +} + +#[derive(Debug, Serialize)] +struct ErrorResponse { + error: String, +} + +type ApiError = (StatusCode, Json<ErrorResponse>); +type ApiResult<T> = Result<T, ApiError>; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct CreateSessionResponse { + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionSummary { + pub id: SessionId, + pub created_at: u64, + pub message_count: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ListSessionsResponse { + pub sessions: Vec<SessionSummary>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionDetailsResponse { + pub id: SessionId, + pub created_at: u64, + pub session: RuntimeSession, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SendMessageRequest { + pub message: String, +} + +#[must_use] +pub fn app(state: AppState) -> Router { + Router::new() + .route("/sessions", post(create_session).get(list_sessions)) + .route("/sessions/{id}", get(get_session)) + .route("/sessions/{id}/events", get(stream_session_events)) + .route("/sessions/{id}/message", post(send_message)) + .with_state(state) +} + +async fn create_session( + State(state): State<AppState>, +) -> (StatusCode, Json<CreateSessionResponse>) { + let session_id = state.allocate_session_id(); + let session = Session::new(session_id.clone()); + + state + .sessions + .write() + .await + .insert(session_id.clone(), session); + + ( + StatusCode::CREATED, + Json(CreateSessionResponse { session_id }), + ) +} + +async fn list_sessions(State(state): State<AppState>) -> Json<ListSessionsResponse> { + let sessions = state.sessions.read().await; + let mut summaries = sessions + .values() + .map(|session| 
SessionSummary { + id: session.id.clone(), + created_at: session.created_at, + message_count: session.conversation.messages.len(), + }) + .collect::<Vec<_>>(); + summaries.sort_by(|left, right| left.id.cmp(&right.id)); + + Json(ListSessionsResponse { + sessions: summaries, + }) +} + +async fn get_session( + State(state): State<AppState>, + Path(id): Path<SessionId>, +) -> ApiResult<Json<SessionDetailsResponse>> { + let sessions = state.sessions.read().await; + let session = sessions + .get(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + + Ok(Json(SessionDetailsResponse { + id: session.id.clone(), + created_at: session.created_at, + session: session.conversation.clone(), + })) +} + +async fn send_message( + State(state): State<AppState>, + Path(id): Path<SessionId>, + Json(payload): Json<SendMessageRequest>, +) -> ApiResult<StatusCode> { + let message = ConversationMessage::user_text(payload.message); + let broadcaster = { + let mut sessions = state.sessions.write().await; + let session = sessions + .get_mut(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + session.conversation.messages.push(message.clone()); + session.events.clone() + }; + + let _ = broadcaster.send(SessionEvent::Message { + session_id: id, + message, + }); + + Ok(StatusCode::NO_CONTENT) +} + +async fn stream_session_events( + State(state): State<AppState>, + Path(id): Path<SessionId>, +) -> ApiResult<impl IntoResponse> { + let (snapshot, mut receiver) = { + let sessions = state.sessions.read().await; + let session = sessions + .get(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + ( + SessionEvent::Snapshot { + session_id: session.id.clone(), + session: session.conversation.clone(), + }, + session.subscribe(), + ) + }; + + let stream = stream! 
{ + if let Ok(event) = snapshot.to_sse_event() { + yield Ok::<Event, Infallible>(event); + } + + loop { + match receiver.recv().await { + Ok(event) => { + if let Ok(sse_event) = event.to_sse_event() { + yield Ok::<Event, Infallible>(sse_event); + } + } + Err(broadcast::error::RecvError::Lagged(_)) => continue, + Err(broadcast::error::RecvError::Closed) => break, + } + } + }; + + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))) +} + +fn unix_timestamp_millis() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after epoch") + .as_millis() as u64 +} + +fn not_found(message: String) -> ApiError { + ( + StatusCode::NOT_FOUND, + Json(ErrorResponse { error: message }), + ) +} + +#[cfg(test)] +mod tests { + use super::{ + app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse, + }; + use reqwest::Client; + use std::net::SocketAddr; + use std::time::Duration; + use tokio::net::TcpListener; + use tokio::task::JoinHandle; + use tokio::time::timeout; + + struct TestServer { + address: SocketAddr, + handle: JoinHandle<()>, + } + + impl TestServer { + async fn spawn() -> Self { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("test listener should bind"); + let address = listener + .local_addr() + .expect("listener should report local address"); + let handle = tokio::spawn(async move { + axum::serve(listener, app(AppState::default())) + .await + .expect("server should run"); + }); + + Self { address, handle } + } + + fn url(&self, path: &str) -> String { + format!("http://{}{}", self.address, path) + } + } + + impl Drop for TestServer { + fn drop(&mut self) { + self.handle.abort(); + } + } + + async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse { + client + .post(server.url("/sessions")) + .send() + .await + .expect("create request should succeed") + .error_for_status() + .expect("create request should return success") + 
.json::<CreateSessionResponse>() + .await + .expect("create response should parse") + } + + async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut String) -> String { + loop { + if let Some(index) = buffer.find("\n\n") { + let frame = buffer[..index].to_string(); + let remainder = buffer[index + 2..].to_string(); + *buffer = remainder; + return frame; + } + + let next_chunk = timeout(Duration::from_secs(5), response.chunk()) + .await + .expect("SSE stream should yield within timeout") + .expect("SSE stream should remain readable") + .expect("SSE stream should stay open"); + buffer.push_str(&String::from_utf8_lossy(&next_chunk)); + } + } + + #[tokio::test] + async fn creates_and_lists_sessions() { + let server = TestServer::spawn().await; + let client = Client::new(); + + // given + let created = create_session(&client, &server).await; + + // when + let sessions = client + .get(server.url("/sessions")) + .send() + .await + .expect("list request should succeed") + .error_for_status() + .expect("list request should return success") + .json::<ListSessionsResponse>() + .await + .expect("list response should parse"); + let details = client + .get(server.url(&format!("/sessions/{}", created.session_id))) + .send() + .await + .expect("details request should succeed") + .error_for_status() + .expect("details request should return success") + .json::<SessionDetailsResponse>() + .await + .expect("details response should parse"); + + // then + assert_eq!(created.session_id, "session-1"); + assert_eq!(sessions.sessions.len(), 1); + assert_eq!(sessions.sessions[0].id, created.session_id); + assert_eq!(sessions.sessions[0].message_count, 0); + assert_eq!(details.id, "session-1"); + assert!(details.session.messages.is_empty()); + } + + #[tokio::test] + async fn streams_message_events_and_persists_message_flow() { + let server = TestServer::spawn().await; + let client = Client::new(); + + // given + let created = create_session(&client, &server).await; + let mut 
response = client + .get(server.url(&format!("/sessions/{}/events", created.session_id))) + .send() + .await + .expect("events request should succeed") + .error_for_status() + .expect("events request should return success"); + let mut buffer = String::new(); + let snapshot_frame = next_sse_frame(&mut response, &mut buffer).await; + + // when + let send_status = client + .post(server.url(&format!("/sessions/{}/message", created.session_id))) + .json(&super::SendMessageRequest { + message: "hello from test".to_string(), + }) + .send() + .await + .expect("message request should succeed") + .status(); + let message_frame = next_sse_frame(&mut response, &mut buffer).await; + let details = client + .get(server.url(&format!("/sessions/{}", created.session_id))) + .send() + .await + .expect("details request should succeed") + .error_for_status() + .expect("details request should return success") + .json::<SessionDetailsResponse>() + .await + .expect("details response should parse"); + + // then + assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT); + assert!(snapshot_frame.contains("event: snapshot")); + assert!(snapshot_frame.contains("\"session_id\":\"session-1\"")); + assert!(message_frame.contains("event: message")); + assert!(message_frame.contains("hello from test")); + assert_eq!(details.session.messages.len(), 1); + assert_eq!( + details.session.messages[0], + runtime::ConversationMessage::user_text("hello from test") + ); + } +} diff --git a/crates/tools/Cargo.toml b/crates/tools/Cargo.toml index 64768f4..04d738b 100644 --- a/crates/tools/Cargo.toml +++ b/crates/tools/Cargo.toml @@ -6,10 +6,13 @@ license.workspace = true publish.workspace = true [dependencies] +api = { path = "../api" } +plugins = { path = "../plugins" } runtime = { path = "../runtime" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] } serde = { version = "1", features = ["derive"] } -serde_json = "1" +serde_json.workspace = true +tokio = { 
version = "1", features = ["rt-multi-thread"] } [lints] workspace = true diff --git a/crates/tools/src/lib.rs b/crates/tools/src/lib.rs index 091b256..4b42572 100644 --- a/crates/tools/src/lib.rs +++ b/crates/tools/src/lib.rs @@ -3,10 +3,18 @@ use std::path::{Path, PathBuf}; use std::process::Command; use std::time::{Duration, Instant}; +use api::{ + max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, +}; +use plugins::PluginTool; use reqwest::blocking::Client; use runtime::{ - edit_file, execute_bash, glob_search, grep_search, read_file, write_file, BashCommandInput, - GrepSearchInput, PermissionMode, + edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, + ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, + ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, + RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -48,6 +56,161 @@ pub struct ToolSpec { pub required_permission: PermissionMode, } +#[derive(Debug, Clone, PartialEq)] +pub struct GlobalToolRegistry { + plugin_tools: Vec<PluginTool>, +} + +impl GlobalToolRegistry { + #[must_use] + pub fn builtin() -> Self { + Self { + plugin_tools: Vec::new(), + } + } + + pub fn with_plugin_tools(plugin_tools: Vec<PluginTool>) -> Result<Self, String> { + let builtin_names = mvp_tool_specs() + .into_iter() + .map(|spec| spec.name.to_string()) + .collect::<BTreeSet<_>>(); + let mut seen_plugin_names = BTreeSet::new(); + + for tool in &plugin_tools { + let name = tool.definition().name.clone(); + if builtin_names.contains(&name) { + return Err(format!( + "plugin tool `{name}` conflicts with a built-in tool name" + )); + } + if 
!seen_plugin_names.insert(name.clone()) { + return Err(format!("duplicate plugin tool name `{name}`")); + } + } + + Ok(Self { plugin_tools }) + } + + pub fn normalize_allowed_tools(&self, values: &[String]) -> Result<Option<BTreeSet<String>>, String> { + if values.is_empty() { + return Ok(None); + } + + let builtin_specs = mvp_tool_specs(); + let canonical_names = builtin_specs + .iter() + .map(|spec| spec.name.to_string()) + .chain(self.plugin_tools.iter().map(|tool| tool.definition().name.clone())) + .collect::<Vec<_>>(); + let mut name_map = canonical_names + .iter() + .map(|name| (normalize_tool_name(name), name.clone())) + .collect::<BTreeMap<_, _>>(); + + for (alias, canonical) in [ + ("read", "read_file"), + ("write", "write_file"), + ("edit", "edit_file"), + ("glob", "glob_search"), + ("grep", "grep_search"), + ] { + name_map.insert(alias.to_string(), canonical.to_string()); + } + + let mut allowed = BTreeSet::new(); + for value in values { + for token in value + .split(|ch: char| ch == ',' || ch.is_whitespace()) + .filter(|token| !token.is_empty()) + { + let normalized = normalize_tool_name(token); + let canonical = name_map.get(&normalized).ok_or_else(|| { + format!( + "unsupported tool in --allowedTools: {token} (expected one of: {})", + canonical_names.join(", ") + ) + })?; + allowed.insert(canonical.clone()); + } + } + + Ok(Some(allowed)) + } + + #[must_use] + pub fn definitions(&self, allowed_tools: Option<&BTreeSet<String>>) -> Vec<ToolDefinition> { + let builtin = mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .map(|spec| ToolDefinition { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: spec.input_schema, + }); + let plugin = self + .plugin_tools + .iter() + .filter(|tool| { + allowed_tools.is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) + }) + .map(|tool| ToolDefinition { + name: 
tool.definition().name.clone(), + description: tool.definition().description.clone(), + input_schema: tool.definition().input_schema.clone(), + }); + builtin.chain(plugin).collect() + } + + #[must_use] + pub fn permission_specs( + &self, + allowed_tools: Option<&BTreeSet<String>>, + ) -> Vec<(String, PermissionMode)> { + let builtin = mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .map(|spec| (spec.name.to_string(), spec.required_permission)); + let plugin = self + .plugin_tools + .iter() + .filter(|tool| { + allowed_tools.is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) + }) + .map(|tool| { + ( + tool.definition().name.clone(), + permission_mode_from_plugin(tool.required_permission()), + ) + }); + builtin.chain(plugin).collect() + } + + pub fn execute(&self, name: &str, input: &Value) -> Result<String, String> { + if mvp_tool_specs().iter().any(|spec| spec.name == name) { + return execute_tool(name, input); + } + self.plugin_tools + .iter() + .find(|tool| tool.definition().name == name) + .ok_or_else(|| format!("unsupported tool: {name}"))? 
+ .execute(input) + .map_err(|error| error.to_string()) + } +} + +fn normalize_tool_name(value: &str) -> String { + value.trim().replace('-', "_").to_ascii_lowercase() +} + +fn permission_mode_from_plugin(value: &str) -> PermissionMode { + match value { + "read-only" => PermissionMode::ReadOnly, + "workspace-write" => PermissionMode::WorkspaceWrite, + "danger-full-access" => PermissionMode::DangerFullAccess, + other => panic!("unsupported plugin permission: {other}"), + } +} + #[must_use] #[allow(clippy::too_many_lines)] pub fn mvp_tool_specs() -> Vec<ToolSpec> { @@ -316,7 +479,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> { }, ToolSpec { name: "Config", - description: "Get or set Claude Code settings.", + description: "Get or set Claw Code settings.", input_schema: json!({ "type": "object", "properties": { @@ -702,7 +865,7 @@ struct SkillOutput { prompt: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct AgentOutput { #[serde(rename = "agentId")] agent_id: String, @@ -718,6 +881,20 @@ struct AgentOutput { manifest_file: String, #[serde(rename = "createdAt")] created_at: String, + #[serde(rename = "startedAt", skip_serializing_if = "Option::is_none")] + started_at: Option<String>, + #[serde(rename = "completedAt", skip_serializing_if = "Option::is_none")] + completed_at: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option<String>, +} + +#[derive(Debug, Clone)] +struct AgentJob { + manifest: AgentOutput, + prompt: String, + system_prompt: Vec<String>, + allowed_tools: BTreeSet<String>, } #[derive(Debug, Serialize)] @@ -904,7 +1081,7 @@ fn build_http_client() -> Result<Client, String> { Client::builder() .timeout(Duration::from_secs(20)) .redirect(reqwest::redirect::Policy::limited(10)) - .user_agent("clawd-rust-tools/0.1") + .user_agent("claw-rust-tools/0.1") .build() .map_err(|error| error.to_string()) } @@ -925,7 +1102,7 @@ fn normalize_fetch_url(url: &str) -> Result<String, 
String> { } fn build_search_url(query: &str) -> Result<reqwest::Url, String> { - if let Ok(base) = std::env::var("CLAWD_WEB_SEARCH_BASE_URL") { + if let Ok(base) = std::env::var("CLAW_WEB_SEARCH_BASE_URL") { let mut url = reqwest::Url::parse(&base).map_err(|error| error.to_string())?; url.query_pairs_mut().append_pair("q", query); return Ok(url); @@ -1259,15 +1436,7 @@ fn validate_todos(todos: &[TodoItem]) -> Result<(), String> { if todos.is_empty() { return Err(String::from("todos must not be empty")); } - let in_progress = todos - .iter() - .filter(|todo| matches!(todo.status, TodoStatus::InProgress)) - .count(); - if in_progress > 1 { - return Err(String::from( - "exactly zero or one todo items may be in_progress", - )); - } + // Allow multiple in_progress items for parallel workflows if todos.iter().any(|todo| todo.content.trim().is_empty()) { return Err(String::from("todo content must not be empty")); } @@ -1278,11 +1447,11 @@ fn validate_todos(todos: &[TodoItem]) -> Result<(), String> { } fn todo_store_path() -> Result<std::path::PathBuf, String> { - if let Ok(path) = std::env::var("CLAWD_TODO_STORE") { + if let Ok(path) = std::env::var("CLAW_TODO_STORE") { return Ok(std::path::PathBuf::from(path)); } let cwd = std::env::current_dir().map_err(|error| error.to_string())?; - Ok(cwd.join(".clawd-todos.json")) + Ok(cwd.join(".claw-todos.json")) } fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> { @@ -1295,6 +1464,12 @@ fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> { if let Ok(codex_home) = std::env::var("CODEX_HOME") { candidates.push(std::path::PathBuf::from(codex_home).join("skills")); } + if let Ok(home) = std::env::var("HOME") { + let home = std::path::PathBuf::from(home); + candidates.push(home.join(".agents").join("skills")); + candidates.push(home.join(".config").join("opencode").join("skills")); + candidates.push(home.join(".codex").join("skills")); + } 
candidates.push(std::path::PathBuf::from("/home/bellman/.codex/skills")); for root in candidates { @@ -1323,7 +1498,18 @@ fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> { Err(format!("unknown skill: {requested}")) } +const DEFAULT_AGENT_MODEL: &str = "claude-opus-4-6"; +const DEFAULT_AGENT_SYSTEM_DATE: &str = "2026-03-31"; +const DEFAULT_AGENT_MAX_ITERATIONS: usize = 32; + fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { + execute_agent_with_spawn(input, spawn_agent_job) +} + +fn execute_agent_with_spawn<F>(input: AgentInput, spawn_fn: F) -> Result<AgentOutput, String> +where + F: FnOnce(AgentJob) -> Result<(), String>, +{ if input.description.trim().is_empty() { return Err(String::from("description must not be empty")); } @@ -1337,6 +1523,7 @@ fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { let output_file = output_dir.join(format!("{agent_id}.md")); let manifest_file = output_dir.join(format!("{agent_id}.json")); let normalized_subagent_type = normalize_subagent_type(input.subagent_type.as_deref()); + let model = resolve_agent_model(input.model.as_deref()); let agent_name = input .name .as_deref() @@ -1344,6 +1531,8 @@ fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { .filter(|name| !name.is_empty()) .unwrap_or_else(|| slugify_agent_name(&input.description)); let created_at = iso8601_now(); + let system_prompt = build_agent_system_prompt(&normalized_subagent_type)?; + let allowed_tools = allowed_tools_for_subagent(&normalized_subagent_type); let output_contents = format!( "# Agent Task @@ -1367,21 +1556,519 @@ fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { name: agent_name, description: input.description, subagent_type: Some(normalized_subagent_type), - model: input.model, - status: String::from("queued"), + model: Some(model), + status: String::from("running"), output_file: output_file.display().to_string(), manifest_file: 
manifest_file.display().to_string(), - created_at, + created_at: created_at.clone(), + started_at: Some(created_at), + completed_at: None, + error: None, }; - std::fs::write( - &manifest_file, - serde_json::to_string_pretty(&manifest).map_err(|error| error.to_string())?, - ) - .map_err(|error| error.to_string())?; + write_agent_manifest(&manifest)?; + + let manifest_for_spawn = manifest.clone(); + let job = AgentJob { + manifest: manifest_for_spawn, + prompt: input.prompt, + system_prompt, + allowed_tools, + }; + if let Err(error) = spawn_fn(job) { + let error = format!("failed to spawn sub-agent: {error}"); + persist_agent_terminal_state(&manifest, "failed", None, Some(error.clone()))?; + return Err(error); + } Ok(manifest) } +fn spawn_agent_job(job: AgentJob) -> Result<(), String> { + let thread_name = format!("claw-agent-{}", job.manifest.agent_id); + std::thread::Builder::new() + .name(thread_name) + .spawn(move || { + let result = + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| run_agent_job(&job))); + match result { + Ok(Ok(())) => {} + Ok(Err(error)) => { + let _ = + persist_agent_terminal_state(&job.manifest, "failed", None, Some(error)); + } + Err(_) => { + let _ = persist_agent_terminal_state( + &job.manifest, + "failed", + None, + Some(String::from("sub-agent thread panicked")), + ); + } + } + }) + .map(|_| ()) + .map_err(|error| error.to_string()) +} + +fn run_agent_job(job: &AgentJob) -> Result<(), String> { + let mut runtime = build_agent_runtime(job)?.with_max_iterations(DEFAULT_AGENT_MAX_ITERATIONS); + let summary = runtime + .run_turn(job.prompt.clone(), None) + .map_err(|error| error.to_string())?; + let final_text = final_assistant_text(&summary); + persist_agent_terminal_state(&job.manifest, "completed", Some(final_text.as_str()), None) +} + +fn build_agent_runtime( + job: &AgentJob, +) -> Result<ConversationRuntime<ProviderRuntimeClient, SubagentToolExecutor>, String> { + let model = job + .manifest + .model + .clone() + 
.unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); + let allowed_tools = job.allowed_tools.clone(); + let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?; + let tool_executor = SubagentToolExecutor::new(allowed_tools); + Ok(ConversationRuntime::new( + Session::new(), + api_client, + tool_executor, + agent_permission_policy(), + job.system_prompt.clone(), + )) +} + +fn build_agent_system_prompt(subagent_type: &str) -> Result<Vec<String>, String> { + let cwd = std::env::current_dir().map_err(|error| error.to_string())?; + let mut prompt = load_system_prompt( + cwd, + DEFAULT_AGENT_SYSTEM_DATE.to_string(), + std::env::consts::OS, + "unknown", + ) + .map_err(|error| error.to_string())?; + prompt.push(format!( + "You are a background sub-agent of type `{subagent_type}`. Work only on the delegated task, use only the tools available to you, do not ask the user questions, and finish with a concise result." + )); + Ok(prompt) +} + +fn resolve_agent_model(model: Option<&str>) -> String { + model + .map(str::trim) + .filter(|model| !model.is_empty()) + .unwrap_or(DEFAULT_AGENT_MODEL) + .to_string() +} + +fn allowed_tools_for_subagent(subagent_type: &str) -> BTreeSet<String> { + let tools = match subagent_type { + "Explore" => vec![ + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "Skill", + "StructuredOutput", + ], + "Plan" => vec![ + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "Skill", + "TodoWrite", + "StructuredOutput", + "SendUserMessage", + ], + "Verification" => vec![ + "bash", + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "TodoWrite", + "StructuredOutput", + "SendUserMessage", + "PowerShell", + ], + "claw-guide" => vec![ + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "Skill", + "StructuredOutput", + "SendUserMessage", + ], + "statusline-setup" => vec![ 
+ "bash", + "read_file", + "write_file", + "edit_file", + "glob_search", + "grep_search", + "ToolSearch", + ], + _ => vec![ + "bash", + "read_file", + "write_file", + "edit_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "TodoWrite", + "Skill", + "ToolSearch", + "NotebookEdit", + "Sleep", + "SendUserMessage", + "Config", + "StructuredOutput", + "REPL", + "PowerShell", + ], + }; + tools.into_iter().map(str::to_string).collect() +} + +fn agent_permission_policy() -> PermissionPolicy { + mvp_tool_specs().into_iter().fold( + PermissionPolicy::new(PermissionMode::DangerFullAccess), + |policy, spec| policy.with_tool_requirement(spec.name, spec.required_permission), + ) +} + +fn write_agent_manifest(manifest: &AgentOutput) -> Result<(), String> { + std::fs::write( + &manifest.manifest_file, + serde_json::to_string_pretty(manifest).map_err(|error| error.to_string())?, + ) + .map_err(|error| error.to_string()) +} + +fn persist_agent_terminal_state( + manifest: &AgentOutput, + status: &str, + result: Option<&str>, + error: Option<String>, +) -> Result<(), String> { + append_agent_output( + &manifest.output_file, + &format_agent_terminal_output(status, result, error.as_deref()), + )?; + let mut next_manifest = manifest.clone(); + next_manifest.status = status.to_string(); + next_manifest.completed_at = Some(iso8601_now()); + next_manifest.error = error; + write_agent_manifest(&next_manifest) +} + +fn append_agent_output(path: &str, suffix: &str) -> Result<(), String> { + use std::io::Write as _; + + let mut file = std::fs::OpenOptions::new() + .append(true) + .open(path) + .map_err(|error| error.to_string())?; + file.write_all(suffix.as_bytes()) + .map_err(|error| error.to_string()) +} + +fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Option<&str>) -> String { + let mut sections = vec![format!("\n## Result\n\n- status: {status}\n")]; + if let Some(result) = result.filter(|value| !value.trim().is_empty()) { + 
sections.push(format!("\n### Final response\n\n{}\n", result.trim())); + } + if let Some(error) = error.filter(|value| !value.trim().is_empty()) { + sections.push(format!("\n### Error\n\n{}\n", error.trim())); + } + sections.join("") +} + +struct ProviderRuntimeClient { + runtime: tokio::runtime::Runtime, + client: ProviderClient, + model: String, + allowed_tools: BTreeSet<String>, +} + +impl ProviderRuntimeClient { + fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> { + let model = resolve_model_alias(&model).to_string(); + let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?; + Ok(Self { + runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, + client, + model, + allowed_tools, + }) + } +} + +impl ApiClient for ProviderRuntimeClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) + .into_iter() + .map(|spec| ToolDefinition { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: spec.input_schema, + }) + .collect::<Vec<_>>(); + let message_request = MessageRequest { + model: self.model.clone(), + max_tokens: max_tokens_for_model(&self.model), + messages: convert_messages(&request.messages), + system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), + tools: (!tools.is_empty()).then_some(tools), + tool_choice: (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto), + stream: true, + }; + + self.runtime.block_on(async { + let mut stream = self + .client + .stream_message(&message_request) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + let mut events = Vec::new(); + let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new(); + let mut saw_stop = false; + + while let Some(event) = stream + .next_event() + .await + .map_err(|error| 
RuntimeError::new(error.to_string()))? + { + match event { + ApiStreamEvent::MessageStart(start) => { + for block in start.message.content { + push_output_block(block, 0, &mut events, &mut pending_tools, true); + } + } + ApiStreamEvent::ContentBlockStart(start) => { + push_output_block( + start.content_block, + start.index, + &mut events, + &mut pending_tools, + true, + ); + } + ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { + ContentBlockDelta::TextDelta { text } => { + if !text.is_empty() { + events.push(AssistantEvent::TextDelta(text)); + } + } + ContentBlockDelta::InputJsonDelta { partial_json } => { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { + input.push_str(&partial_json); + } + } + ContentBlockDelta::ThinkingDelta { .. } + | ContentBlockDelta::SignatureDelta { .. } => {} + }, + ApiStreamEvent::ContentBlockStop(stop) => { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + ApiStreamEvent::MessageDelta(delta) => { + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: delta.usage.input_tokens, + output_tokens: delta.usage.output_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + })); + } + ApiStreamEvent::MessageStop(_) => { + saw_stop = true; + events.push(AssistantEvent::MessageStop); + } + } + } + + if !saw_stop + && events.iter().any(|event| { + matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) + || matches!(event, AssistantEvent::ToolUse { .. 
}) + }) + { + events.push(AssistantEvent::MessageStop); + } + + if events + .iter() + .any(|event| matches!(event, AssistantEvent::MessageStop)) + { + return Ok(events); + } + + let response = self + .client + .send_message(&MessageRequest { + stream: false, + ..message_request.clone() + }) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + Ok(response_to_events(response)) + }) + } +} + +struct SubagentToolExecutor { + allowed_tools: BTreeSet<String>, +} + +impl SubagentToolExecutor { + fn new(allowed_tools: BTreeSet<String>) -> Self { + Self { allowed_tools } + } +} + +impl ToolExecutor for SubagentToolExecutor { + fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> { + if !self.allowed_tools.contains(tool_name) { + return Err(ToolError::new(format!( + "tool `{tool_name}` is not enabled for this sub-agent" + ))); + } + let value = serde_json::from_str(input) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + execute_tool(tool_name, &value).map_err(ToolError::new) + } +} + +fn tool_specs_for_allowed_tools(allowed_tools: Option<&BTreeSet<String>>) -> Vec<ToolSpec> { + mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .collect() +} + +fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> { + messages + .iter() + .filter_map(|message| { + let role = match message.role { + MessageRole::System | MessageRole::User | MessageRole::Tool => "user", + MessageRole::Assistant => "assistant", + }; + let content = message + .blocks + .iter() + .map(|block| match block { + ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() }, + ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: serde_json::from_str(input) + .unwrap_or_else(|_| serde_json::json!({ "raw": input })), + }, + ContentBlock::ToolResult { + tool_use_id, + 
output, + is_error, + .. + } => InputContentBlock::ToolResult { + tool_use_id: tool_use_id.clone(), + content: vec![ToolResultContentBlock::Text { + text: output.clone(), + }], + is_error: *is_error, + }, + }) + .collect::<Vec<_>>(); + (!content.is_empty()).then(|| InputMessage { + role: role.to_string(), + content, + }) + }) + .collect() +} + +fn push_output_block( + block: OutputContentBlock, + block_index: u32, + events: &mut Vec<AssistantEvent>, + pending_tools: &mut BTreeMap<u32, (String, String, String)>, + streaming_tool_input: bool, +) { + match block { + OutputContentBlock::Text { text } => { + if !text.is_empty() { + events.push(AssistantEvent::TextDelta(text)); + } + } + OutputContentBlock::ToolUse { id, name, input } => { + let initial_input = if streaming_tool_input + && input.is_object() + && input.as_object().is_some_and(serde_json::Map::is_empty) + { + String::new() + } else { + input.to_string() + }; + pending_tools.insert(block_index, (id, name, initial_input)); + } + OutputContentBlock::Thinking { .. } | OutputContentBlock::RedactedThinking { .. 
} => {} + } +} + +fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> { + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + for (index, block) in response.content.into_iter().enumerate() { + let index = u32::try_from(index).expect("response block index overflow"); + push_output_block(block, index, &mut events, &mut pending_tools, false); + if let Some((id, name, input)) = pending_tools.remove(&index) { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: response.usage.input_tokens, + output_tokens: response.usage.output_tokens, + cache_creation_input_tokens: response.usage.cache_creation_input_tokens, + cache_read_input_tokens: response.usage.cache_read_input_tokens, + })); + events.push(AssistantEvent::MessageStop); + events +} + +fn final_assistant_text(summary: &runtime::TurnSummary) -> String { + summary + .assistant_messages + .last() + .map(|message| { + message + .blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::<Vec<_>>() + .join("") + }) + .unwrap_or_default() +} + #[allow(clippy::needless_pass_by_value)] fn execute_tool_search(input: ToolSearchInput) -> ToolSearchOutput { let deferred = deferred_tool_specs(); @@ -1519,14 +2206,14 @@ fn canonical_tool_token(value: &str) -> String { } fn agent_store_dir() -> Result<std::path::PathBuf, String> { - if let Ok(path) = std::env::var("CLAWD_AGENT_STORE") { + if let Ok(path) = std::env::var("CLAW_AGENT_STORE") { return Ok(std::path::PathBuf::from(path)); } let cwd = std::env::current_dir().map_err(|error| error.to_string())?; if let Some(workspace_root) = cwd.ancestors().nth(2) { - return Ok(workspace_root.join(".clawd-agents")); + return Ok(workspace_root.join(".claw-agents")); } - Ok(cwd.join(".clawd-agents")) + Ok(cwd.join(".claw-agents")) } fn make_agent_id() -> String { @@ -1567,7 +2254,7 @@ fn 
normalize_subagent_type(subagent_type: Option<&str>) -> String { "verification" | "verificationagent" | "verify" | "verifier" => { String::from("Verification") } - "claudecodeguide" | "claudecodeguideagent" | "guide" => String::from("claude-code-guide"), + "clawguide" | "clawguideagent" | "guide" => String::from("claw-guide"), "statusline" | "statuslinesetup" => String::from("statusline-setup"), _ => trimmed.to_string(), } @@ -2067,16 +2754,16 @@ fn config_file_for_scope(scope: ConfigScope) -> Result<PathBuf, String> { let cwd = std::env::current_dir().map_err(|error| error.to_string())?; Ok(match scope { ConfigScope::Global => config_home_dir()?.join("settings.json"), - ConfigScope::Settings => cwd.join(".claude").join("settings.local.json"), + ConfigScope::Settings => cwd.join(".claw").join("settings.local.json"), }) } fn config_home_dir() -> Result<PathBuf, String> { - if let Ok(path) = std::env::var("CLAUDE_CONFIG_HOME") { + if let Ok(path) = std::env::var("CLAW_CONFIG_HOME") { return Ok(PathBuf::from(path)); } let home = std::env::var("HOME").map_err(|_| String::from("HOME is not set"))?; - Ok(PathBuf::from(home).join(".claude")) + Ok(PathBuf::from(home).join(".claw")) } fn read_json_object(path: &Path) -> Result<serde_json::Map<String, Value>, String> { @@ -2215,7 +2902,7 @@ fn execute_shell_command( persisted_output_path: None, persisted_output_size: None, sandbox_status: None, -}); + }); } let mut process = std::process::Command::new(shell); @@ -2284,7 +2971,7 @@ Command exceeded timeout of {timeout_ms} ms", persisted_output_path: None, persisted_output_size: None, sandbox_status: None, -}); + }); } std::thread::sleep(Duration::from_millis(10)); } @@ -2373,6 +3060,8 @@ fn parse_skill_description(contents: &str) -> Option<String> { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use std::collections::BTreeSet; use std::fs; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener}; @@ -2381,7 +3070,13 @@ mod tests { use std::thread; use 
std::time::Duration; - use super::{execute_tool, mvp_tool_specs}; + use super::{ + agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn, + execute_tool, final_assistant_text, mvp_tool_specs, persist_agent_terminal_state, + push_output_block, AgentInput, AgentJob, SubagentToolExecutor, + }; + use api::OutputContentBlock; + use runtime::{ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session}; use serde_json::json; fn env_lock() -> &'static Mutex<()> { @@ -2394,7 +3089,7 @@ mod tests { .duration_since(std::time::UNIX_EPOCH) .expect("time") .as_nanos(); - std::env::temp_dir().join(format!("clawd-tools-{unique}-{name}")) + std::env::temp_dir().join(format!("claw-tools-{unique}-{name}")) } #[test] @@ -2517,7 +3212,7 @@ mod tests { })); std::env::set_var( - "CLAWD_WEB_SEARCH_BASE_URL", + "CLAW_WEB_SEARCH_BASE_URL", format!("http://{}/search", server.addr()), ); let result = execute_tool( @@ -2529,7 +3224,7 @@ mod tests { }), ) .expect("WebSearch should succeed"); - std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); assert_eq!(output["query"], "rust web search"); @@ -2565,7 +3260,7 @@ mod tests { })); std::env::set_var( - "CLAWD_WEB_SEARCH_BASE_URL", + "CLAW_WEB_SEARCH_BASE_URL", format!("http://{}/fallback", server.addr()), ); let result = execute_tool( @@ -2575,7 +3270,7 @@ mod tests { }), ) .expect("WebSearch fallback parsing should succeed"); - std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); let results = output["results"].as_array().expect("results array"); @@ -2588,20 +3283,77 @@ mod tests { assert_eq!(content[0]["url"], "https://example.com/one"); assert_eq!(content[1]["url"], "https://docs.rs/tokio"); - std::env::set_var("CLAWD_WEB_SEARCH_BASE_URL", 
"://bad-base-url"); + std::env::set_var("CLAW_WEB_SEARCH_BASE_URL", "://bad-base-url"); let error = execute_tool("WebSearch", &json!({ "query": "generic links" })) .expect_err("invalid base URL should fail"); - std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); assert!(error.contains("relative URL without a base") || error.contains("empty host")); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut events, + &mut pending_tools, + true, + ); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut events, + &mut pending_tools, + true, + ); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn todo_write_persists_and_returns_previous_state() { let _guard = env_lock() .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let path = temp_path("todos.json"); - std::env::set_var("CLAWD_TODO_STORE", &path); + std::env::set_var("CLAW_TODO_STORE", &path); let first = execute_tool( "TodoWrite", @@ -2627,7 +3379,7 @@ mod tests { }), ) .expect("TodoWrite should succeed"); - std::env::remove_var("CLAWD_TODO_STORE"); + std::env::remove_var("CLAW_TODO_STORE"); let _ = 
std::fs::remove_file(path); let second_output: serde_json::Value = serde_json::from_str(&second).expect("valid json"); @@ -2648,13 +3400,14 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let path = temp_path("todos-errors.json"); - std::env::set_var("CLAWD_TODO_STORE", &path); + std::env::set_var("CLAW_TODO_STORE", &path); let empty = execute_tool("TodoWrite", &json!({ "todos": [] })) .expect_err("empty todos should fail"); assert!(empty.contains("todos must not be empty")); - let too_many_active = execute_tool( + // Multiple in_progress items are now allowed for parallel workflows + let _multi_active = execute_tool( "TodoWrite", &json!({ "todos": [ @@ -2663,8 +3416,7 @@ mod tests { ] }), ) - .expect_err("multiple in-progress todos should fail"); - assert!(too_many_active.contains("zero or one todo items may be in_progress")); + .expect("multiple in-progress todos should succeed"); let blank_content = execute_tool( "TodoWrite", @@ -2688,7 +3440,7 @@ mod tests { }), ) .expect("completed todos should succeed"); - std::env::remove_var("CLAWD_TODO_STORE"); + std::env::remove_var("CLAW_TODO_STORE"); let _ = fs::remove_file(path); let output: serde_json::Value = serde_json::from_str(&nudge).expect("valid json"); @@ -2697,6 +3449,9 @@ mod tests { #[test] fn skill_loads_local_skill_prompt() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); let result = execute_tool( "Skill", &json!({ @@ -2772,33 +3527,49 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let dir = temp_path("agent-store"); - std::env::set_var("CLAWD_AGENT_STORE", &dir); + std::env::set_var("CLAW_AGENT_STORE", &dir); + let captured = Arc::new(Mutex::new(None::<AgentJob>)); + let captured_for_spawn = Arc::clone(&captured); - let result = execute_tool( - "Agent", - &json!({ - "description": "Audit the branch", - "prompt": "Check tests and outstanding work.", - "subagent_type": "Explore", - "name": "ship-audit" - 
}), + let manifest = execute_agent_with_spawn( + AgentInput { + description: "Audit the branch".to_string(), + prompt: "Check tests and outstanding work.".to_string(), + subagent_type: Some("Explore".to_string()), + name: Some("ship-audit".to_string()), + model: None, + }, + move |job| { + *captured_for_spawn + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(job); + Ok(()) + }, ) .expect("Agent should succeed"); - std::env::remove_var("CLAWD_AGENT_STORE"); + std::env::remove_var("CLAW_AGENT_STORE"); - let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); - assert_eq!(output["name"], "ship-audit"); - assert_eq!(output["subagentType"], "Explore"); - assert_eq!(output["status"], "queued"); - assert!(output["createdAt"].as_str().is_some()); - let manifest_file = output["manifestFile"].as_str().expect("manifest file"); - let output_file = output["outputFile"].as_str().expect("output file"); - let contents = std::fs::read_to_string(output_file).expect("agent file exists"); + assert_eq!(manifest.name, "ship-audit"); + assert_eq!(manifest.subagent_type.as_deref(), Some("Explore")); + assert_eq!(manifest.status, "running"); + assert!(!manifest.created_at.is_empty()); + assert!(manifest.started_at.is_some()); + assert!(manifest.completed_at.is_none()); + let contents = std::fs::read_to_string(&manifest.output_file).expect("agent file exists"); let manifest_contents = - std::fs::read_to_string(manifest_file).expect("manifest file exists"); + std::fs::read_to_string(&manifest.manifest_file).expect("manifest file exists"); assert!(contents.contains("Audit the branch")); assert!(contents.contains("Check tests and outstanding work.")); assert!(manifest_contents.contains("\"subagentType\": \"Explore\"")); + assert!(manifest_contents.contains("\"status\": \"running\"")); + let captured_job = captured + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + .expect("spawn job should be captured"); + 
assert_eq!(captured_job.prompt, "Check tests and outstanding work."); + assert!(captured_job.allowed_tools.contains("read_file")); + assert!(!captured_job.allowed_tools.contains("Agent")); let normalized = execute_tool( "Agent", @@ -2827,6 +3598,195 @@ mod tests { let _ = std::fs::remove_dir_all(dir); } + #[test] + fn agent_fake_runner_can_persist_completion_and_failure() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let dir = temp_path("agent-runner"); + std::env::set_var("CLAW_AGENT_STORE", &dir); + + let completed = execute_agent_with_spawn( + AgentInput { + description: "Complete the task".to_string(), + prompt: "Do the work".to_string(), + subagent_type: Some("Explore".to_string()), + name: Some("complete-task".to_string()), + model: Some("claude-sonnet-4-6".to_string()), + }, + |job| { + persist_agent_terminal_state( + &job.manifest, + "completed", + Some("Finished successfully"), + None, + ) + }, + ) + .expect("completed agent should succeed"); + + let completed_manifest = std::fs::read_to_string(&completed.manifest_file) + .expect("completed manifest should exist"); + let completed_output = + std::fs::read_to_string(&completed.output_file).expect("completed output should exist"); + assert!(completed_manifest.contains("\"status\": \"completed\"")); + assert!(completed_output.contains("Finished successfully")); + + let failed = execute_agent_with_spawn( + AgentInput { + description: "Fail the task".to_string(), + prompt: "Do the failing work".to_string(), + subagent_type: Some("Verification".to_string()), + name: Some("fail-task".to_string()), + model: None, + }, + |job| { + persist_agent_terminal_state( + &job.manifest, + "failed", + None, + Some(String::from("simulated failure")), + ) + }, + ) + .expect("failed agent should still spawn"); + + let failed_manifest = + std::fs::read_to_string(&failed.manifest_file).expect("failed manifest should exist"); + let failed_output = + 
std::fs::read_to_string(&failed.output_file).expect("failed output should exist"); + assert!(failed_manifest.contains("\"status\": \"failed\"")); + assert!(failed_manifest.contains("simulated failure")); + assert!(failed_output.contains("simulated failure")); + + let spawn_error = execute_agent_with_spawn( + AgentInput { + description: "Spawn error task".to_string(), + prompt: "Never starts".to_string(), + subagent_type: None, + name: Some("spawn-error".to_string()), + model: None, + }, + |_| Err(String::from("thread creation failed")), + ) + .expect_err("spawn errors should surface"); + assert!(spawn_error.contains("failed to spawn sub-agent")); + let spawn_error_manifest = std::fs::read_dir(&dir) + .expect("agent dir should exist") + .filter_map(Result::ok) + .map(|entry| entry.path()) + .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json")) + .find_map(|path| { + let contents = std::fs::read_to_string(&path).ok()?; + contents + .contains("\"name\": \"spawn-error\"") + .then_some(contents) + }) + .expect("failed manifest should still be written"); + assert!(spawn_error_manifest.contains("\"status\": \"failed\"")); + assert!(spawn_error_manifest.contains("thread creation failed")); + + std::env::remove_var("CLAW_AGENT_STORE"); + let _ = std::fs::remove_dir_all(dir); + } + + #[test] + fn agent_tool_subset_mapping_is_expected() { + let general = allowed_tools_for_subagent("general-purpose"); + assert!(general.contains("bash")); + assert!(general.contains("write_file")); + assert!(!general.contains("Agent")); + + let explore = allowed_tools_for_subagent("Explore"); + assert!(explore.contains("read_file")); + assert!(explore.contains("grep_search")); + assert!(!explore.contains("bash")); + + let plan = allowed_tools_for_subagent("Plan"); + assert!(plan.contains("TodoWrite")); + assert!(plan.contains("StructuredOutput")); + assert!(!plan.contains("Agent")); + + let verification = allowed_tools_for_subagent("Verification"); + 
assert!(verification.contains("bash")); + assert!(verification.contains("PowerShell")); + assert!(!verification.contains("write_file")); + } + + #[derive(Debug)] + struct MockSubagentApiClient { + calls: usize, + input_path: String, + } + + impl runtime::ApiClient for MockSubagentApiClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + self.calls += 1; + match self.calls { + 1 => { + assert_eq!(request.messages.len(), 1); + Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({ "path": self.input_path }).to_string(), + }, + AssistantEvent::MessageStop, + ]) + } + 2 => { + assert!(request.messages.len() >= 3); + Ok(vec![ + AssistantEvent::TextDelta("Scope: completed mock review".to_string()), + AssistantEvent::MessageStop, + ]) + } + _ => panic!("unexpected mock stream call"), + } + } + } + + #[test] + fn subagent_runtime_executes_tool_loop_with_isolated_session() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let path = temp_path("subagent-input.txt"); + std::fs::write(&path, "hello from child").expect("write input file"); + + let mut runtime = ConversationRuntime::new( + Session::new(), + MockSubagentApiClient { + calls: 0, + input_path: path.display().to_string(), + }, + SubagentToolExecutor::new(BTreeSet::from([String::from("read_file")])), + agent_permission_policy(), + vec![String::from("system prompt")], + ); + + let summary = runtime + .run_turn("Inspect the delegated file", None) + .expect("subagent loop should succeed"); + + assert_eq!( + final_assistant_text(&summary), + "Scope: completed mock review" + ); + assert!(runtime + .session() + .messages + .iter() + .flat_map(|message| message.blocks.iter()) + .any(|block| matches!( + block, + runtime::ContentBlock::ToolResult { output, .. 
} + if output.contains("hello from child") + ))); + + let _ = std::fs::remove_file(path); + } + #[test] fn agent_rejects_blank_required_fields() { let missing_description = execute_tool( @@ -3212,7 +4172,7 @@ mod tests { #[test] fn brief_returns_sent_message_and_attachment_metadata() { let attachment = std::env::temp_dir().join(format!( - "clawd-brief-{}.png", + "claw-brief-{}.png", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -3243,7 +4203,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let root = std::env::temp_dir().join(format!( - "clawd-config-{}", + "claw-config-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -3251,19 +4211,19 @@ mod tests { )); let home = root.join("home"); let cwd = root.join("cwd"); - std::fs::create_dir_all(home.join(".claude")).expect("home dir"); - std::fs::create_dir_all(cwd.join(".claude")).expect("cwd dir"); + std::fs::create_dir_all(home.join(".claw")).expect("home dir"); + std::fs::create_dir_all(cwd.join(".claw")).expect("cwd dir"); std::fs::write( - home.join(".claude").join("settings.json"), + home.join(".claw").join("settings.json"), r#"{"verbose":false}"#, ) .expect("write global settings"); let original_home = std::env::var("HOME").ok(); - let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); let original_dir = std::env::current_dir().expect("cwd"); std::env::set_var("HOME", &home); - std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::env::remove_var("CLAW_CONFIG_HOME"); std::env::set_current_dir(&cwd).expect("set cwd"); let get = execute_tool("Config", &json!({"setting": "verbose"})).expect("get config"); @@ -3296,9 +4256,9 @@ mod tests { Some(value) => std::env::set_var("HOME", value), None => std::env::remove_var("HOME"), } - match original_claude_home { - Some(value) => std::env::set_var("CLAUDE_CONFIG_HOME", value), - None 
=> std::env::remove_var("CLAUDE_CONFIG_HOME"), + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), } let _ = std::fs::remove_dir_all(root); } @@ -3332,7 +4292,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let dir = std::env::temp_dir().join(format!( - "clawd-pwsh-bin-{}", + "claw-pwsh-bin-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -3389,7 +4349,7 @@ printf 'pwsh:%s' "$1" .unwrap_or_else(std::sync::PoisonError::into_inner); let original_path = std::env::var("PATH").unwrap_or_default(); let empty_dir = std::env::temp_dir().join(format!( - "clawd-empty-bin-{}", + "claw-empty-bin-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") diff --git a/docs/releases/0.1.0.md b/docs/releases/0.1.0.md new file mode 100644 index 0000000..5254475 --- /dev/null +++ b/docs/releases/0.1.0.md @@ -0,0 +1,51 @@ +# Claw Code 0.1.0 发行说明(草案) + +## 摘要 + +Claw Code `0.1.0` 是当前 Rust 实现的第一个公开发布准备里程碑。Claw Code 的灵感来自 Claude Code,并作为一个净室(clean-room)Rust 实现构建;它不是直接的移植或复制。此版本专注于可用的本地 CLI 体验:交互式会话、非交互式提示词、工作区工具、配置加载、会话、插件以及本地代理/技能发现。 + +## 亮点 + +- Claw Code 的首个公开 `0.1.0` 发行候选版本 +- 作为当前主要产品界面的安全 Rust 实现 +- 用于交互式和单次编码代理工作流的 `claw` CLI +- 内置工作区工具:用于 shell、文件操作、搜索、网页获取/搜索、待办事项跟踪和笔记本更新 +- 斜杠命令界面:用于状态、压缩、配置检查、会话、差异/导出以及版本信息 +- 本地插件、代理和技能的发现/管理界面 +- OAuth 登录/注销以及模型/提供商选择 + +## 安装与运行 + +此版本目前旨在通过源码构建: + +```bash +cargo install --path crates/claw-cli --locked +# 或者 +cargo build --release -p claw-cli +``` + +运行: + +```bash +claw +claw prompt "总结此仓库" +``` + +## 已知限制 + +- 仅限源码构建分发;尚未发布打包好的发行构件 +- CI 目前覆盖 Ubuntu 和 macOS 的发布构建、检查和测试 +- Windows 的发布就绪性尚未建立 +- 部分集成覆盖是可选的,因为需要实时提供商凭据和网络访问 +- 公开接口可能会在 `0.x` 版本系列期间继续演进 + +## 推荐的发行定位 + +将 `0.1.0` 定位为 Claw Code 当前 Rust 实现的首个公开发布版本,面向习惯于从源码构建的早期采用者。功能表面已足够广泛以支持实际使用,而打包和发布自动化可以在后续版本中继续改进。 + +## 用于此草案的验证 + +- 通过 `Cargo.toml` 验证了工作区版本 +- 通过 `cargo metadata` 验证了 `claw` 
二进制文件/包路径 +- 通过 `cargo run --quiet --bin claw -- --help` 验证了 CLI 命令表面 +- 通过 `.github/workflows/ci.yml` 验证了 CI 覆盖范围