diff --git a/.gitignore b/.gitignore index 324ae1d..cfa4cf3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,6 @@ __pycache__/ archive/ .omx/ .clawd-agents/ +# Claude Code local artifacts +.claude/settings.local.json +.claude/sessions/ diff --git a/CLAW.md b/CLAW.md new file mode 100644 index 0000000..676f2e3 --- /dev/null +++ b/CLAW.md @@ -0,0 +1,21 @@ +# CLAW.md + +此文件为 Claw Code 在此仓库中处理代码时提供指导。 + +## 检测到的技术栈 +- 语言:Rust。 +- 框架:未从支持的启动标记中检测到任何框架。 + +## 验证 +- 在 `rust/` 目录下运行 Rust 验证:`cargo fmt`、`cargo clippy --workspace --all-targets -- -D warnings`、`cargo test --workspace` +- `src/` 和 `tests/` 均已存在;当行为发生变化时,请同时更新这两个部分。 + +## 仓库结构 +- `rust/` 包含 Rust 工作区以及活跃的 CLI/运行时实现。 +- `src/` 包含应与生成的指导和测试保持一致的源文件。 +- `tests/` 包含应随代码更改而一同审查的验证部分。 + +## 工作协议 +- 优先采用小而易于审查的更改,并保持生成的引导文件与实际仓库工作流对齐。 +- 将共享的默认值保留在 `.claw.json` 中;`.claw/settings.local.json` 预留给机器本地的覆盖设置。 +- 不要自动覆盖现有的 `CLAW.md` 内容;当仓库工作流更改时,请有目的地更新它。 diff --git a/PARITY.md b/PARITY.md new file mode 100644 index 0000000..15f6bbd --- /dev/null +++ b/PARITY.md @@ -0,0 +1,214 @@ +# 差异化缺口分析 (PARITY GAP ANALYSIS) + +范围:对位于 `/home/bellman/Workspace/claw-code/src/` 的原始 TypeScript 源码与 `rust/crates/` 下的 Rust 移植版进行只读对比。 + +方法:仅对比功能表面、注册表、入口点和运行时流水线。未复制任何 TypeScript 源码。 + +## 主要摘要 + +Rust 移植版在以下方面具有良好的基础: +- Anthropic API/OAuth 基础 +- 本地对话/会话状态 +- 核心工具循环 +- MCP stdio/引导支持 +- CLAW.md 发现 +- 一套小巧但可用的内置工具 + +它与 TypeScript CLI **尚未实现功能对等**。 + +最大的缺口: +- **插件 (plugins)**:在 Rust 中实际上不存在 +- **挂钩 (hooks)**:Rust 中已解析但未执行 +- **CLI 广度**:Rust 中的功能范围要窄得多 +- **技能 (skills)**:Rust 中仅支持本地文件,没有 TS 的注册表/捆绑流水线 +- **助手编排 (assistant orchestration)**:缺乏 TS 中感知挂钩的编排以及远程/结构化传输 +- **服务 (services)**:除了核心 API/OAuth/MCP 之外,Rust 中大部分服务都缺失 + +--- + +## tools/ (工具) + +### TS 已存在 +证据: +- `src/tools/` 包含广泛的工具家族,包括 `AgentTool`、`AskUserQuestionTool`、`BashTool`、`ConfigTool`、`FileReadTool`、`FileWriteTool`、`GlobTool`、`GrepTool`、`LSPTool`、`ListMcpResourcesTool`、`MCPTool`、`McpAuthTool`、`ReadMcpResourceTool`、`RemoteTriggerTool`、`ScheduleCronTool`、`SkillTool`、`Task*`、`Team*`、`TodoWriteTool`、`ToolSearchTool`、`WebFetchTool`、`WebSearchTool`。 +- 工具执行/编排分散在 `src/services/tools/StreamingToolExecutor.ts`、`src/services/tools/toolExecution.ts`、`src/services/tools/toolHooks.ts` 和 `src/services/tools/toolOrchestration.ts` 中。 + +### Rust 已存在 +证据: +- 工具注册表通过 `mvp_tool_specs()` 集中在 `rust/crates/tools/src/lib.rs` 中。 +- 当前内置工具包括 shell/文件/搜索/网页/待办事项/技能/代理/配置/笔记本/REPL/PowerShell 原语。 +- 运行时执行通过 `rust/crates/tools/src/lib.rs` 和 `rust/crates/runtime/src/conversation.rs` 连接。 + +### Rust 中缺失或损坏 +- 重大 TS 工具在 Rust 中没有等效项,例如 `AskUserQuestionTool`、`LSPTool`、`ListMcpResourcesTool`、`MCPTool`、`McpAuthTool`、`ReadMcpResourceTool`、`RemoteTriggerTool`、`ScheduleCronTool`、`Task*`、`Team*` 以及多个工作流/系统工具。 +- Rust 工具表面仍明确是一个 MVP(最小可行产品)注册表,而非对等注册表。 +- Rust 缺乏 TS 的分层工具编排拆分。 + +**状态:** 仅部分核心功能。 + +--- + +## hooks/ (挂钩) + +### TS 已存在 +证据: +- `src/commands/hooks/` 下的挂钩命令表面。 +- `src/services/tools/toolHooks.ts` 和 `src/services/tools/toolExecution.ts` 中的运行时挂钩机制。 +- TS 支持 `PreToolUse`、`PostToolUse` 以及通过设置配置并记录在 `src/skills/bundled/updateConfig.ts` 中的更广泛的挂钩驱动行为。 + +### Rust 已存在 +证据: +- 挂钩配置在 `rust/crates/runtime/src/config.rs` 中被解析和合并。 +- 挂钩配置可以通过 `rust/crates/commands/src/lib.rs` 和 `rust/crates/claw-cli/src/main.rs` 中的 Rust 配置报告进行检查。 +- `rust/crates/runtime/src/prompt.rs` 中的提示指南提到了挂钩。 + +### Rust 中缺失或损坏 +- `rust/crates/runtime/src/conversation.rs` 中没有实际的挂钩执行流水线。 +- 没有 `PreToolUse`/`PostToolUse` 的变异/拒绝/重写/结果挂钩行为。 +- 没有 Rust 的 `/hooks` 对等命令。 + +**状态:** 仅配置支持;运行时行为缺失。 + +--- + +## plugins/ (插件) + +### TS 已存在 +证据: +- `src/plugins/builtinPlugins.ts` 和 `src/plugins/bundled/index.ts` 中的内置插件脚手架。 +- `src/services/plugins/PluginInstallationManager.ts` 和 `src/services/plugins/pluginOperations.ts` 中的插件生命周期/服务。 +- `src/commands/plugin/` 和 `src/commands/reload-plugins/` 下的 CLI/插件命令表面。 + +### Rust 已存在 +证据: +- `rust/crates/` 下没有出现专门的插件子系统。 +- 仓库范围内对插件的 Rust 引用除了文本/帮助提到外,实际上不存在。 + +### Rust 中缺失或损坏 +- 没有插件加载器。 +- 没有市场安装/更新/启用/禁用流程。 +- 没有 `/plugin` 或 `/reload-plugins` 对等命令。 +- 没有插件提供的挂钩/工具/命令/MCP 扩展路径。 + +**状态:** 缺失。 + +--- + +## skills/ 与 CLAW.md 发现 + +### TS 已存在 +证据: +- `src/skills/loadSkillsDir.ts`、`src/skills/bundledSkills.ts` 和 `src/skills/mcpSkillBuilders.ts` 中的技能加载/注册流水线。 +- `src/skills/bundled/` 下的捆绑技能。 +- `src/commands/skills/` 下的技能命令表面。 + +### Rust 已存在 +证据: +- `rust/crates/tools/src/lib.rs` 中的 `Skill` 工具可以解析并读取本地 `SKILL.md` 文件。 +- `rust/crates/runtime/src/prompt.rs` 中实现了 `CLAW.md` 发现。 +- Rust 通过 `rust/crates/commands/src/lib.rs` 和 `rust/crates/claw-cli/src/main.rs` 支持 `/memory` 和 `/init`。 + +### Rust 中缺失或损坏 +- 没有等效的捆绑技能注册表。 +- 没有 `/skills` 命令。 +- 没有 MCP 技能构建器流水线。 +- 没有类似 TS 的实时技能发现/重载/更改处理。 +- 没有围绕技能的会话记忆 (session-memory) / 团队记忆 (team-memory) 集成。 + +**状态:** 仅支持基础的本地技能加载。 + +--- + +## cli/ (命令行界面) + +### TS 已存在 +证据: +- `src/commands/` 下庞大的命令表面,包括 `agents`、`hooks`、`mcp`、`memory`、`model`、`permissions`、`plan`、`plugin`、`resume`、`review`、`skills`、`tasks` 等。 +- `src/cli/structuredIO.ts`、`src/cli/remoteIO.ts` 和 `src/cli/transports/*` 中的结构化/远程传输栈。 +- `src/cli/handlers/*` 中的 CLI 处理程序拆分。 + +### Rust 已存在 +证据: +- `rust/crates/commands/src/lib.rs` 中共享的斜杠命令注册表。 +- Rust 斜杠命令目前涵盖 `help`、`status`、`compact`、`model`、`permissions`、`clear`、`cost`、`resume`、`config`、`memory`、`init`、`diff`、`version`、`export`、`session`。 +- 主要的 CLI/REPL/提示词处理位于 `rust/crates/claw-cli/src/main.rs`。 + +### Rust 中缺失或损坏 +- 缺失重大的 TS 命令家族:`/agents`、`/hooks`、`/mcp`、`/plugin`、`/skills`、`/plan`、`/review`、`/tasks` 等等。 +- Rust 没有等效于 TS 结构化 IO / 远程传输层的实现。 +- 没有针对 auth/plugins/MCP/agents 的 TS 风格处理程序分解。 +- JSON 提示模式在此分支上有所改进,但仍未达到纯净的传输对等:实证验证显示,支持工具的 JSON 输出在最终 JSON 对象之前可能会发出人类可读的工具结果行。 + +**状态:** 功能性本地 CLI 核心,范围比 TS 窄得多。 + +--- + +## assistant/ (助手/代理循环、流式、工具调用) + +### TS 已存在 +证据: +- `src/assistant/sessionHistory.ts` 中的助手/会话表面。 +- `src/services/tools/StreamingToolExecutor.ts`、`src/services/tools/toolExecution.ts`、`src/services/tools/toolOrchestration.ts` 中的工具编排。 +- `src/cli/structuredIO.ts` 和 `src/cli/remoteIO.ts` 中的远程/结构化流式层。 + +### Rust 已存在 +证据: +- `rust/crates/runtime/src/conversation.rs` 中的核心循环。 +- `rust/crates/claw-cli/src/main.rs` 中的流式/工具事件转换。 +- `rust/crates/runtime/src/session.rs` 中的会话持久化。 + +### Rust 中缺失或损坏 +- 没有类似 TS 的挂钩感知编排层。 +- 没有 TS 结构化/远程助手传输栈。 +- 没有更丰富的 TS 助手/会话历史/后台任务集成。 +- JSON 输出路径在此分支上不再仅限单次交互,但输出纯净度仍落后于 TS 传输预期。 + +**状态:** 强大的核心循环,缺失编排层。 + +--- + +## services/ (服务:API 客户端、认证、模型、MCP) + +### TS 已存在 +证据: +- `src/services/api/*` 下的 API 服务。 +- `src/services/oauth/*` 下的 OAuth 服务。 +- `src/services/mcp/*` 下的 MCP 服务。 +- `src/services/*` 下的分析、提示建议、会话记忆、插件操作、设置同步、策略限制、团队记忆同步、通知器、语音等附加服务层。 + +### Rust 已存在 +证据: +- `rust/crates/api/src/{client,error,sse,types}.rs` 中的核心 Anthropic API 客户端。 +- `rust/crates/runtime/src/oauth.rs` 中的 OAuth 支持。 +- `rust/crates/runtime/src/{config,mcp,mcp_client,mcp_stdio}.rs` 中的 MCP 配置/引导/客户端支持。 +- `rust/crates/runtime/src/usage.rs` 中的用量统计。 +- `rust/crates/runtime/src/remote.rs` 中的远程上游代理支持。 + +### Rust 中缺失或损坏 +- 除了核心消息传递/认证/MCP 之外,大部分 TS 服务生态系统都缺失。 +- 没有等效于 TS 的插件服务层。 +- 没有等效于 TS 的分析/设置同步/策略限制/团队记忆子系统。 +- 没有 TS 风格的 MCP 连接器/UI 层。 +- 模型/提供商的用户体验(ergonomics)仍比 TS 薄弱。 + +**状态:** 核心基础已存在;更广泛的服务生态系统缺失。 + +--- + +## 正在处理的分支中的关键缺陷状态 + +### 已修复 +- **已启用提示词模式工具 (Prompt mode tools)** + - `rust/crates/claw-cli/src/main.rs` 现在使用 `LiveCli::new(model, true, ...)` 构建提示词模式。 +- **默认权限模式 = DangerFullAccess** + - `rust/crates/claw-cli/src/main.rs` 中的运行时默认值现在解析为 `DangerFullAccess`。 + - `rust/crates/claw-cli/src/args.rs` 中的 Clap 默认值也使用 `DangerFullAccess`。 + - `rust/crates/claw-cli/src/init.rs` 中的初始化模板写入了 `dontAsk`。 +- **流式 `{}` 工具输入前缀 Bug** + - `rust/crates/claw-cli/src/main.rs` 现在仅针对流式工具输入剥离初始空对象,同时保留非流式响应中的合法 `{}`。 +- **无限最大迭代次数 (max_iterations)** + - 在 `rust/crates/runtime/src/conversation.rs` 中通过 `usize::MAX` 进行了验证。 + +### 剩余的显著对等问题 +- **JSON 提示输出纯净度** + - 支持工具的 JSON 模式现在可以循环,但实证验证显示,当工具启动时,在 JSON 之前仍会出现人类可读的工具结果输出。 diff --git a/README.md b/README.md index a74ce3b..934d2a6 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ -# Rewriting Project Claw Code +# 重写项目:Claw Code

- ⭐ The fastest repo in history to surpass 50K stars, reaching the milestone in just 2 hours after publication ⭐ + ⭐ 历史上最快超过 5 万颗星的仓库,发布后仅 2 小时即达成里程碑 ⭐

@@ -19,7 +19,7 @@

- Better Harness Tools, not merely storing the archive of leaked Claude Code + 更好的 Harness 工具,而不仅仅是存储泄露的 Claw Code 存档

@@ -27,63 +27,86 @@

> [!IMPORTANT] -> **Rust port is now in progress** on the [`dev/rust`](https://github.com/instructkr/claw-code/tree/dev/rust) branch and is expected to be merged into main today. The Rust implementation aims to deliver a faster, memory-safe harness runtime. Stay tuned — this will be the definitive version of the project. +> **Rust 移植工作目前正在 [`dev/rust`](https://github.com/instructkr/claw-code/tree/dev/rust) 分支上进行**,预计今天将合并到主分支。Rust 实现旨在提供更快、内存安全的 harness 运行时。敬请期待——这将是该项目的最终版本。 -> If you find this work useful, consider [sponsoring @instructkr on GitHub](https://github.com/sponsors/instructkr) to support continued open-source harness engineering research. +> 如果你觉得这项工作有用,请考虑在 GitHub 上 [赞助 @instructkr](https://github.com/sponsors/instructkr) 以支持持续的开源 harness 工程研究。 --- -## Backstory +## Rust 移植 -At 4 AM on March 31, 2026, I woke up to my phone blowing up with notifications. The Claude Code source had been exposed, and the entire dev community was in a frenzy. My girlfriend in Korea was genuinely worried I might face legal action from Anthropic just for having the code on my machine — so I did what any engineer would do under pressure: I sat down, ported the core features to Python from scratch, and pushed it before the sun came up. +`rust/` 目录下的 Rust 工作区是该项目的当前系统语言移植版本。 -The whole thing was orchestrated end-to-end using [oh-my-codex (OmX)](https://github.com/Yeachan-Heo/oh-my-codex) by [@bellman_ych](https://x.com/bellman_ych) — a workflow layer built on top of OpenAI's Codex ([@OpenAIDevs](https://x.com/OpenAIDevs)). I used `$team` mode for parallel code review and `$ralph` mode for persistent execution loops with architect-level verification. The entire porting session — from reading the original harness structure to producing a working Python tree with tests — was driven through OmX orchestration. +它目前包括: -The result is a clean-room Python rewrite that captures the architectural patterns of Claude Code's agent harness without copying any proprietary source. I'm now actively collaborating with [@bellman_ych](https://x.com/bellman_ych) — the creator of OmX himself — to push this further. The basic Python foundation is already in place and functional, but we're just getting started. **Stay tuned — a much more capable version is on the way.** +- `crates/api-client` — 具有提供商抽象、OAuth 和流式支持的 API 客户端 +- `crates/runtime` — 会话状态、压缩、MCP 编排、提示词构建 +- `crates/tools` — 工具清单定义和执行框架 +- `crates/commands` — 斜杠命令、技能发现和配置检查 +- `crates/plugins` — 插件模型、挂钩管道和内置插件 +- `crates/compat-harness` — 用于上游编辑器集成的兼容层 +- `crates/claw-cli` — 交互式 REPL、Markdown 渲染以及项目引导/初始化流程 + +运行 Rust 构建: + +```bash +cd rust +cargo build --release +``` + +## 背景故事 + +2026年3月31日凌晨4点,我被手机弹出的漫天通知震醒。Claw Code 的源代码被曝光了,整个开发者社区都陷入了疯狂。我在韩国的女朋友真的很担心我仅仅因为机器上有这些代码而面临原作者的法律诉讼——于是我在压力之下做了任何工程师都会做的事:我坐下来,从头开始将核心功能移植到 Python,并在日出前提交了代码。 + +整个过程是使用 [@bellman_ych](https://x.com/bellman_ych) 开发的 [oh-my-codex (OmX)](https://github.com/Yeachan-Heo/oh-my-codex) 端到端编排的——这是一个构建在 OpenAI Codex ([@OpenAIDevs](https://x.com/OpenAIDevs)) 之上的工作流层。我使用了 `$team` 模式进行并行代码审查,并使用了 `$ralph` 模式进行带有架构级验证的持久执行循环。整个移植过程——从阅读原始 harness 结构到生成具有测试的可用 Python 树——都是通过 OmX 编排驱动的。 + +结果是一个净室(clean-room)Python 重写版本,它捕捉了 Claw Code 代理 harness 的架构模式,而没有复制任何专有源代码。我现在正与 OmX 的创始人 [@bellman_ych](https://x.com/bellman_ych) 本人积极合作,以进一步推进这项工作。基本的 Python 基础已经就绪并可以运行,但我们才刚刚开始。 **敬请期待——一个更强大的版本正在路上。** + +Rust 移植版本是同时使用 [oh-my-codex (OmX)](https://github.com/Yeachan-Heo/oh-my-codex) 和 [oh-my-opencode (OmO)](https://github.com/code-yeongyu/oh-my-openagent) 开发的:OmX 驱动了脚手架搭建、编排和架构方向,而 OmO 则用于后期的实现加速和验证支持。 https://github.com/instructkr/claw-code ![Tweet screenshot](assets/tweet-screenshot.png) -## The Creators Featured in Wall Street Journal For Avid Claude Code Fans +## 《华尔街日报》对 Claw Code 狂热粉丝创建者的报道 -I've been deeply interested in **harness engineering** — studying how agent systems wire tools, orchestrate tasks, and manage runtime context. This isn't a sudden thing. The Wall Street Journal featured my work earlier this month, documenting how I've been one of the most active power users exploring these systems: +我对 **harness 工程** 深感兴趣——研究代理系统如何连接工具、编排任务以及管理运行时上下文。这并非突发奇想。《华尔街日报》在本月早些时候报道了我的工作,记录了我如何成为探索这些系统最活跃的超级用户之一: -> AI startup worker Sigrid Jin, who attended the Seoul dinner, single-handedly used 25 billion of Claude Code tokens last year. At the time, usage limits were looser, allowing early enthusiasts to reach tens of billions of tokens at a very low cost. +> AI 创业公司员工 Sigrid Jin,去年凭一己之力使用了 250 亿个 Claw Code token。当时,使用限制较为宽松,使得早期发烧友能够以极低的成本触及数百亿个 token。 > -> Despite his countless hours with Claude Code, Jin isn't faithful to any one AI lab. The tools available have different strengths and weaknesses, he said. Codex is better at reasoning, while Claude Code generates cleaner, more shareable code. +> 尽管 Jin 在 Claw Code 上花费了无数个小时,但他并不忠于任何一家 AI 实验室。他说,现有的工具各有所长。Codex 擅长推理,而 Claw Code 生成的代码更干净、更易于分享。 > -> Jin flew to San Francisco in February for Claude Code's first birthday party, where attendees waited in line to compare notes with Cherny. The crowd included a practicing cardiologist from Belgium who had built an app to help patients navigate care, and a California lawyer who made a tool for automating building permit approvals using Claude Code. +> Jin 于 2 月份飞往旧金山参加 Claw Code 的一周年派对,与会者排队与 Cherny 交流心得。人群中包括一名来自比利时的执业心脏病专家(他开发了一个帮助患者就医的应用程序),以及一名加州律师(他使用 Claw Code 制作了一个自动批准建筑许可的工具)。 > -> "It was basically like a sharing party," Jin said. "There were lawyers, there were doctors, there were dentists. They did not have software engineering backgrounds." +> “这基本上就像是一个分享派对,”Jin 说。“有律师,有医生,有牙医。他们并没有软件工程背景。” > -> — *The Wall Street Journal*, March 21, 2026, [*"The Trillion Dollar Race to Automate Our Entire Lives"*](https://lnkd.in/gs9td3qd) +> — *《华尔街日报》*,2026年3月21日,[*“这场价值万亿美元的竞赛,旨在自动化我们的整个生活”*](https://lnkd.in/gs9td3qd) ![WSJ Feature](assets/wsj-feature.png) --- -## Porting Status +## 移植状态 -The main source tree is now Python-first. +主源码树现在以 Python 为主。 -- `src/` contains the active Python porting workspace -- `tests/` verifies the current Python workspace -- the exposed snapshot is no longer part of the tracked repository state +- `src/` 包含活跃的 Python 移植工作区 +- `tests/` 验证当前的 Python 工作区 +- 曝光的快照不再是受追踪的仓库状态的一部分 -The current Python workspace is not yet a complete one-to-one replacement for the original system, but the primary implementation surface is now Python. +目前的 Python 工作区尚未完全替代原始系统,但主要的实现界面现在是 Python。 -## Why this rewrite exists +## 为什么进行这次重写 -I originally studied the exposed codebase to understand its harness, tool wiring, and agent workflow. After spending more time with the legal and ethical questions—and after reading the essay linked below—I did not want the exposed snapshot itself to remain the main tracked source tree. +我最初研究曝光的代码库是为了了解其 harness、工具连接和代理工作流。在深入思考法律和伦理问题——并在阅读了下面链接的论文后——我不希望曝光的快照本身继续作为主要追踪的源码树。 -This repository now focuses on Python porting work instead. +因此,本仓库现在专注于 Python 移植工作。 -## Repository Layout +## 仓库布局 ```text . -├── src/ # Python porting workspace +├── src/ # Python 移植工作区 │ ├── __init__.py │ ├── commands.py │ ├── main.py @@ -92,100 +115,114 @@ This repository now focuses on Python porting work instead. │ ├── query_engine.py │ ├── task.py │ └── tools.py -├── tests/ # Python verification -├── assets/omx/ # OmX workflow screenshots +├── rust/ # Rust 移植 (claw CLI) +│ ├── crates/api/ # API 客户端 + 流式处理 +│ ├── crates/runtime/ # 会话、工具、MCP、配置 +│ ├── crates/claw-cli/ # 交互式 CLI 二进制文件 +│ ├── crates/plugins/ # 插件系统 +│ ├── crates/commands/ # 斜杠命令 +│ ├── crates/server/ # HTTP/SSE 服务器 (axum) +│ ├── crates/lsp/ # LSP 客户端集成 +│ └── crates/tools/ # 工具规范 +├── tests/ # Python 验证 +├── assets/omx/ # OmX 工作流截图 ├── 2026-03-09-is-legal-the-same-as-legitimate-ai-reimplementation-and-the-erosion-of-copyleft.md └── README.md ``` -## Python Workspace Overview +## Python 工作区概览 -The new Python `src/` tree currently provides: +新的 Python `src/` 树目前提供: -- **`port_manifest.py`** — summarizes the current Python workspace structure -- **`models.py`** — dataclasses for subsystems, modules, and backlog state -- **`commands.py`** — Python-side command port metadata -- **`tools.py`** — Python-side tool port metadata -- **`query_engine.py`** — renders a Python porting summary from the active workspace -- **`main.py`** — a CLI entrypoint for manifest and summary output +- **`port_manifest.py`** — 总结当前 Python 工作区的结构 +- **`models.py`** — 用于子系统、模块和积压状态的数据类 +- **`commands.py`** — Python 侧的命令移植元数据 +- **`tools.py`** — Python 侧的工具移植元数据 +- **`query_engine.py`** — 从活跃工作区渲染 Python 移植摘要 +- **`main.py`** — 用于清单和摘要输出的 CLI 入口点 -## Quickstart +## 快速开始 -Render the Python porting summary: +渲染 Python 移植摘要: ```bash python3 -m src.main summary ``` -Print the current Python workspace manifest: +打印当前 Python 工作区清单: ```bash python3 -m src.main manifest ``` -List the current Python modules: +列出当前的 Python 模块: ```bash python3 -m src.main subsystems --limit 16 ``` -Run verification: +运行验证: ```bash python3 -m unittest discover -s tests -v ``` -Run the parity audit against the local ignored archive (when present): +对本地忽略的存档运行一致性审计(如果存在): ```bash python3 -m src.main parity-audit ``` -Inspect mirrored command/tool inventories: +检查镜像的命令/工具库: ```bash python3 -m src.main commands --limit 10 python3 -m src.main tools --limit 10 ``` -## Current Parity Checkpoint +## 当前对比检查点 -The port now mirrors the archived root-entry file surface, top-level subsystem names, and command/tool inventories much more closely than before. However, it is **not yet** a full runtime-equivalent replacement for the original TypeScript system; the Python tree still contains fewer executable runtime slices than the archived source. +该移植版本现在比以前更紧密地镜像了存档的根入口文件表面、顶层子系统名称以及命令/工具清单。然而,它**尚未**完全替代原始的 TypeScript 系统;Python 树中可执行的运行时切片仍少于存档源码。 +## 使用 `oh-my-codex` 和 `oh-my-opencode` 构建 -## Built with `oh-my-codex` +本仓库的移植、净室加固和验证工作流是在 Yeachan Heo 工具栈的 AI 辅助下完成的,其中 **oh-my-codex (OmX)** 是主要的脚手架和编排层。 -The restructuring and documentation work on this repository was AI-assisted and orchestrated with Yeachan Heo's [oh-my-codex (OmX)](https://github.com/Yeachan-Heo/oh-my-codex), layered on top of Codex. +- [**oh-my-codex (OmX)**](https://github.com/Yeachan-Heo/oh-my-codex) — 脚手架、编排、架构方向和核心移植工作流 +- [**oh-my-opencode (OmO)**](https://github.com/code-yeongyu/oh-my-openagent) — 实现加速、清理和验证支持 -- **`$team` mode:** used for coordinated parallel review and architectural feedback -- **`$ralph` mode:** used for persistent execution, verification, and completion discipline -- **Codex-driven workflow:** used to turn the main `src/` tree into a Python-first porting workspace +移植过程中使用的关键工作流模式: -### OmX workflow screenshots +- **`$team` 模式:** 协调的并行审查和架构反馈 +- **`$ralph` 模式:** 持久执行、验证和完成纪律 +- **净室通过:** 跨 Rust 工作区的命名/品牌清理、QA 和发布验证 +- **手动和现场验证:** 构建、测试、手动 QA 以及发布前的真实 API 路径验证 + +### OmX 工作流截图 ![OmX workflow screenshot 1](assets/omx/omx-readme-review-1.png) -*Ralph/team orchestration view while the README and essay context were being reviewed in terminal panes.* +*Ralph/team 编排视图,终端面板中正在审查 README 和论文背景。* ![OmX workflow screenshot 2](assets/omx/omx-readme-review-2.png) -*Split-pane review and verification flow during the final README wording pass.* +*最后一次 README 措辞审核期间的分屏审查和验证流程。* -## Community +## 社区

instructkr

-Join the [**instructkr Discord**](https://instruct.kr/) — the best Korean language model community. Come chat about LLMs, harness engineering, agent workflows, and everything in between. +加入 [**instructkr Discord**](https://instruct.kr/) —— 最佳韩国语言模型社区。来这里聊聊 LLM、harness 工程、代理工作流以及其中的一切。 [![Discord](https://img.shields.io/badge/Join%20Discord-instruct.kr-5865F2?logo=discord&style=for-the-badge)](https://instruct.kr/) ## Star History -See the chart at the top of this README. +见 README 顶部的图表。 -## Ownership / Affiliation Disclaimer +## 所有权 / 附属免责声明 -- This repository does **not** claim ownership of the original Claude Code source material. -- This repository is **not affiliated with, endorsed by, or maintained by Anthropic**. +- 本仓库**不**主张对原始 Claw Code 源码材料的所有权。 +- 本仓库**不隶属于、不被背书、也不由原作者维护**。 diff --git a/rust/.claude/sessions/session-1775007453382.json b/rust/.claude/sessions/session-1775007453382.json deleted file mode 100644 index d45e491..0000000 --- a/rust/.claude/sessions/session-1775007453382.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/rust/.claude/sessions/session-1775007484031.json b/rust/.claude/sessions/session-1775007484031.json deleted file mode 100644 index d45e491..0000000 --- a/rust/.claude/sessions/session-1775007484031.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/rust/.claude/sessions/session-1775007490104.json b/rust/.claude/sessions/session-1775007490104.json deleted file mode 100644 index d45e491..0000000 --- a/rust/.claude/sessions/session-1775007490104.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/rust/.claude/sessions/session-1775007981374.json b/rust/.claude/sessions/session-1775007981374.json deleted file mode 100644 index d45e491..0000000 --- a/rust/.claude/sessions/session-1775007981374.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/rust/.claude/sessions/session-1775008007069.json b/rust/.claude/sessions/session-1775008007069.json deleted file mode 100644 index d45e491..0000000 --- a/rust/.claude/sessions/session-1775008007069.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/rust/.claude/sessions/session-1775008071886.json b/rust/.claude/sessions/session-1775008071886.json deleted file mode 100644 index d45e491..0000000 --- a/rust/.claude/sessions/session-1775008071886.json +++ /dev/null @@ -1 +0,0 @@ -{"messages":[],"version":1} \ No newline at end of file diff --git a/rust/.github/workflows/ci.yml b/rust/.github/workflows/ci.yml new file mode 100644 index 0000000..73459b8 --- /dev/null +++ b/rust/.github/workflows/ci.yml @@ -0,0 +1,36 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + rust: + name: ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + - macos-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Run cargo check + run: cargo check --workspace + + - name: Run cargo test + run: cargo test --workspace + + - name: Run release build + run: cargo build --release diff --git a/rust/CONTRIBUTING.md b/rust/CONTRIBUTING.md new file mode 100644 index 0000000..759fb9e --- /dev/null +++ b/rust/CONTRIBUTING.md @@ -0,0 +1,43 @@ +# 贡献指南 + +感谢你为 Claw Code 做出贡献。 + +## 开发设置 + +- 安装稳定的 Rust 工具链。 +- 在此 Rust 工作区的仓库根目录下进行开发。如果你从父仓库根目录开始,请先执行 `cd rust/`。 + +## 构建 + +```bash +cargo build +cargo build --release +``` + +## 测试与验证 + +在开启 Pull Request 之前,请运行完整的 Rust 验证集: + +```bash +cargo fmt --all --check +cargo clippy --workspace --all-targets -- -D warnings +cargo check --workspace +cargo test --workspace +``` + +如果你更改了行为,请在同一个 Pull Request 中添加或更新相关的测试。 + +## 代码风格 + +- 遵循所修改 crate 中的现有模式,而不是引入新的风格。 +- 使用 `rustfmt` 格式化代码。 +- 确保你修改的工作区目标的 `clippy` 检查通过。 +- 优先采用针对性的 diff,而不是顺便进行的重构。 + +## Pull Request + +- 从 `main` 分支拉取新分支。 +- 确保每个 Pull Request 的范围仅限于一个明确的更改。 +- 说明更改动机、实现摘要以及你运行的验证。 +- 在请求审查之前,确保本地检查已通过。 +- 如果审查反馈导致行为更改,请重新运行相关的验证命令。 diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 9030127..443b79d 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -28,12 +28,86 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atomic-waker" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.22.1" @@ -49,6 +123,12 @@ dependencies = [ "serde", ] +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + [[package]] name = "bitflags" version = "2.11.0" @@ -98,11 +178,40 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "claw-cli" +version = "0.1.0" +dependencies = [ + "api", + "commands", + "compat-harness", + "crossterm", + "plugins", + "pulldown-cmark", + "runtime", + "rustyline", + "serde_json", + "syntect", + "tokio", + "tools", +] + +[[package]] +name = "clipboard-win" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bde03770d3df201d4fb868f2c9c59e66a3e4e2bd06692a0fe701e7103c7e84d4" +dependencies = [ + "error-code", +] + [[package]] name = "commands" version = "0.1.0" dependencies = [ + "plugins", "runtime", + "serde_json", ] [[package]] @@ -138,11 +247,11 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags", + "bitflags 2.11.0", "crossterm_winapi", "mio", "parking_lot", - "rustix", + "rustix 0.38.44", "signal-hook", "signal-hook-mio", "winapi", @@ -197,6 +306,12 @@ dependencies = [ "syn", ] +[[package]] +name = "endian-type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" + [[package]] name = "equivalent" version = "1.0.2" @@ -213,6 +328,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "error-code" +version = "3.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" + +[[package]] +name = "fd-lock" +version = "4.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" +dependencies = [ + "cfg-if", + "rustix 1.1.4", + "windows-sys 0.59.0", +] + [[package]] name = "find-msvc-tools" version = "0.1.9" @@ -229,6 +361,15 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "fluent-uri" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17c704e9dbe1ddd863da1e6ff3567795087b1eb201ce80d8fa81162e1516500d" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "fnv" version = "1.0.7" @@ -266,6 +407,17 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.32" @@ -286,6 +438,7 @@ checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -351,6 +504,15 @@ version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "http" version = "1.4.0" @@ -390,6 +552,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.9.0" @@ -403,6 +571,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -614,6 +783,12 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + [[package]] name = "litemap" version = "0.8.1" @@ -641,12 +816,48 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lsp" +version = "0.1.0" +dependencies = [ + "lsp-types", + "serde", + "serde_json", + "tokio", + "url", +] + +[[package]] +name = "lsp-types" +version = "0.97.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53353550a17c04ac46c585feb189c2db82154fc84b79c7a66c96c2c644f66071" +dependencies = [ + "bitflags 1.3.2", + "fluent-uri", + "serde", + "serde_json", + "serde_repr", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.8.9" @@ -669,6 +880,27 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nibble_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a5d83df9f36fe23f0c3648c6bbb8b0298bb5f1939c8f2704431371f4b84d43" +dependencies = [ + "smallvec", +] + +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "num-conv" version = "0.2.1" @@ -687,7 +919,7 @@ version = "6.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" dependencies = [ - "bitflags", + "bitflags 2.11.0", "libc", "once_cell", "onig_sys", @@ -757,6 +989,14 @@ dependencies = [ "time", ] +[[package]] +name = "plugins" +version = "0.1.0" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -796,7 +1036,7 @@ version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c3a14896dfa883796f1cb410461aef38810ea05f2b2c33c5aded3649095fdad" dependencies = [ - "bitflags", + "bitflags 2.11.0", "getopts", "memchr", "pulldown-cmark-escape", @@ -888,6 +1128,16 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "radix_trie" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c069c179fcdc6a2fe24d8d18305cf085fdbd4f922c041943e203685d6a1c58fd" +dependencies = [ + "endian-type", + "nibble_vec", +] + [[package]] name = "rand" version = "0.9.2" @@ -923,7 +1173,7 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags", + "bitflags 2.11.0", ] [[package]] @@ -985,12 +1235,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-rustls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", ] @@ -1014,6 +1266,8 @@ name = "runtime" version = "0.1.0" dependencies = [ "glob", + "lsp", + "plugins", "regex", "serde", "serde_json", @@ -1034,11 +1288,24 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags", + "bitflags 2.11.0", "errno", "libc", - "linux-raw-sys", - "windows-sys 0.52.0", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustix" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags 2.11.0", + "errno", + "libc", + "linux-raw-sys 0.12.1", + "windows-sys 0.61.2", ] [[package]] @@ -1098,6 +1365,28 @@ dependencies = [ "tools", ] +[[package]] +name = "rustyline" +version = "15.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee1e066dc922e513bda599c6ccb5f3bb2b0ea5870a579448f2622993f0a9a2f" +dependencies = [ + "bitflags 2.11.0", + "cfg-if", + "clipboard-win", + "fd-lock", + "home", + "libc", + "log", + "memchr", + "nix", + "radix_trie", + "unicode-segmentation", + "unicode-width", + "utf8parse", + "windows-sys 0.59.0", +] + [[package]] name = "ryu" version = "1.0.23" @@ -1162,6 +1451,28 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1174,6 +1485,19 @@ dependencies = [ "serde", ] +[[package]] +name = "server" +version = "0.1.0" +dependencies = [ + "async-stream", + "axum", + "reqwest", + "runtime", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "sha2" version = "0.10.9" @@ -1427,14 +1751,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tools" version = "0.1.0" dependencies = [ + "api", + "plugins", "reqwest", "runtime", "serde", "serde_json", + "tokio", ] [[package]] @@ -1450,6 +1790,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -1458,7 +1799,7 @@ version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ - "bitflags", + "bitflags 2.11.0", "bytes", "futures-util", "http", @@ -1488,6 +1829,7 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-core", ] @@ -1525,6 +1867,12 @@ version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +[[package]] +name = "unicode-segmentation" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" + [[package]] name = "unicode-width" version = "0.2.2" @@ -1555,6 +1903,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "version_check" version = "0.9.5" @@ -1650,6 +2004,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.93" @@ -1725,6 +2092,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 4a2f4d4..aa2f4ea 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -8,6 +8,10 @@ edition = "2021" license = "MIT" publish = false +[workspace.dependencies] +lsp-types = "0.97" +serde_json = "1" + [workspace.lints.rust] unsafe_code = "forbid" diff --git a/rust/README.md b/rust/README.md index 26a0d4a..fb4ef6b 100644 --- a/rust/README.md +++ b/rust/README.md @@ -1,230 +1,122 @@ -# Rusty Claude CLI +# Claw Code -`rust/` contains the Rust workspace for the integrated `rusty-claude-cli` deliverable. -It is intended to be something you can clone, build, and run directly. +Claw Code 是一个使用安全 Rust 实现的本地编程代理(coding-agent)命令行工具。它的设计灵感来自 **Claude Code**,并作为一个**净室实现(clean-room implementation)**开发:旨在提供强大的本地代理体验,但它**不是** Claude Code 的直接移植或复制。 -## Workspace layout +Rust 工作区是当前主要的产品界面。`claw` 二进制文件在单个工作区内提供交互式会话、单次提示、工作区感知工具、本地代理工作流以及支持插件的操作。 -```text -rust/ -├── Cargo.toml -├── Cargo.lock -├── README.md -└── crates/ - ├── api/ # Anthropic API client + SSE streaming support - ├── commands/ # Shared slash-command metadata/help surfaces - ├── compat-harness/ # Upstream TS manifest extraction harness - ├── runtime/ # Session/runtime/config/prompt orchestration - ├── rusty-claude-cli/ # Main CLI binary - └── tools/ # Built-in tool implementations -``` +## 当前状态 -## Prerequisites +- **版本:** `0.1.0` +- **发布阶段:** 初始公开发布,源码编译分发 +- **主要实现:** 本仓库中的 Rust 工作区 +- **平台焦点:** macOS 和 Linux 开发工作站 -- Rust toolchain installed (`rustup`, stable toolchain) -- Network access and Anthropic credentials for live prompt/REPL usage +## 安装、构建与运行 -## Build +### 准备工作 -From the repository root: +- Rust 稳定版工具链 +- Cargo +- 你想使用的模型的提供商凭据 + +### 身份验证 + +兼容 Anthropic 的模型: ```bash -cd rust -cargo build --release -p rusty-claude-cli +export ANTHROPIC_API_KEY="..." +# 使用兼容的端点时可选 +export ANTHROPIC_BASE_URL="https://api.anthropic.com" ``` -The optimized binary will be written to: +Grok 模型: ```bash -./target/release/rusty-claude-cli +export XAI_API_KEY="..." +# 使用兼容的端点时可选 +export XAI_BASE_URL="https://api.x.ai" ``` -## Test - -Run the verified workspace test suite used for release-readiness: +也可以使用 OAuth 登录: ```bash -cd rust -cargo test --workspace --exclude compat-harness +cargo run --bin claw -- login ``` -## Quick start - -### Show help +### 本地安装 ```bash -cd rust -cargo run -p rusty-claude-cli -- --help +cargo install --path crates/claw-cli --locked ``` -### Print version +### 从源码构建 ```bash -cd rust -cargo run -p rusty-claude-cli -- --version +cargo build --release -p claw-cli ``` -### Login with OAuth +### 运行 -Configure `settings.json` with an `oauth` block containing `clientId`, `authorizeUrl`, `tokenUrl`, optional `callbackPort`, and optional `scopes`, then run: +在工作区内运行: ```bash -cd rust -cargo run -p rusty-claude-cli -- login +cargo run --bin claw -- --help +cargo run --bin claw -- +cargo run --bin claw -- prompt "总结此工作区" +cargo run --bin claw -- --model sonnet "审查最新更改" ``` -This opens the browser, listens on the configured localhost callback, exchanges the auth code for tokens, and stores OAuth credentials in `~/.claude/credentials.json` (or `$CLAUDE_CONFIG_HOME/credentials.json`). - -### Logout +运行发布版本: ```bash -cd rust -cargo run -p rusty-claude-cli -- logout +./target/release/claw +./target/release/claw prompt "解释 crates/runtime" ``` -This removes only the stored OAuth credentials and preserves unrelated JSON fields in `credentials.json`. +## 支持的功能 -### Self-update +- 交互式 REPL 和单次提示执行 +- 已保存会话的检查和恢复流程 +- 内置工作区工具:shell、文件读/写/编辑、搜索、网页获取/搜索、待办事项和笔记本更新 +- 斜杠命令:状态、压缩、配置检查、差异(diff)、导出、会话管理和版本报告 +- 本地代理和技能发现:通过 `claw agents` 和 `claw skills` +- 通过命令行和斜杠命令界面发现并管理插件 +- OAuth 登录/注销,以及从命令行选择模型/提供商 +- 工作区感知的指令/配置加载(`CLAW.md`、配置文件、权限、插件设置) -```bash -cd rust -cargo run -p rusty-claude-cli -- self-update -``` +## 当前限制 -The command checks the latest GitHub release for `instructkr/clawd-code`, compares it to the current binary version, downloads the matching binary asset plus checksum manifest, verifies SHA-256, replaces the current executable, and prints the release changelog. If no published release or matching asset exists, it exits safely with an explanatory message. +- 目前公开发布**仅限源码构建**;此工作区尚未设置 crates.io 发布 +- GitHub CI 验证 `cargo check`、`cargo test` 和发布构建,但尚未提供自动化的发布打包 +- 当前 CI 目标为 Ubuntu 和 macOS;Windows 的发布就绪性仍待建立 +- 一些实时提供商集成覆盖是可选的,因为它们需要外部凭据 and 网络访问 +- 命令界面可能会在 `0.x` 系列期间继续演进 -## Usage examples +## 实现现状 -### 1) Prompt mode +Rust 工作区是当前的产品实现。目前包含以下 crate: -Send one prompt, stream the answer, then exit: +- `claw-cli` — 面向用户的二进制文件 +- `api` — 提供商客户端和流式处理 +- `runtime` — 会话、配置、权限、提示词和运行时循环 +- `tools` — 内置工具实现 +- `commands` — 斜杠命令注册和处理程序 +- `plugins` — 插件发现、注册和生命周期支持 +- `lsp` — 语言服务器协议支持类型和进程助手 +- `server` 和 `compat-harness` — 支持服务和兼容性工具 -```bash -cd rust -cargo run -p rusty-claude-cli -- prompt "Summarize the architecture of this repository" -``` +## 路线图 -Use a specific model: +- 发布打包好的构件,用于公共安装 +- 添加可重复的发布工作流和长期维护的变更日志(changelog)规范 +- 将平台验证扩展到当前 CI 矩阵之外 +- 添加更多以任务为中心的示例和操作员文档 +- 继续加强 Rust 实现的功能覆盖并磨炼用户体验(UX) -```bash -cd rust -cargo run -p rusty-claude-cli -- --model claude-sonnet-4-20250514 prompt "List the key crates in this workspace" -``` +## 发行版本说明 -Restrict enabled tools in an interactive session: +- 0.1.0 发行说明草案:[`docs/releases/0.1.0.md`](docs/releases/0.1.0.md) -```bash -cd rust -cargo run -p rusty-claude-cli -- --allowedTools read,glob -``` +## 许可 -Bootstrap Claude project files for the current repo: - -```bash -cd rust -cargo run -p rusty-claude-cli -- init -``` - -### 2) REPL mode - -Start the interactive shell: - -```bash -cd rust -cargo run -p rusty-claude-cli -- -``` - -Inside the REPL, useful commands include: - -```text -/help -/status -/model claude-sonnet-4-20250514 -/permissions workspace-write -/cost -/compact -/memory -/config -/init -/diff -/version -/export notes.txt -/sessions -/session list -/exit -``` - -### 3) Resume an existing session - -Inspect or maintain a saved session file without entering the REPL: - -```bash -cd rust -cargo run -p rusty-claude-cli -- --resume session-123456 /status /compact /cost -``` - -You can also inspect memory/config state for a restored session: - -```bash -cd rust -cargo run -p rusty-claude-cli -- --resume ~/.claude/sessions/session-123456.json /memory /config -``` - -## Available commands - -### Top-level CLI commands - -- `prompt ` — run one prompt non-interactively -- `--resume [/commands...]` — inspect or maintain a saved session stored under `~/.claude/sessions/` -- `dump-manifests` — print extracted upstream manifest counts -- `bootstrap-plan` — print the current bootstrap skeleton -- `system-prompt [--cwd PATH] [--date YYYY-MM-DD]` — render the synthesized system prompt -- `self-update` — update the installed binary from the latest GitHub release when a matching asset is available -- `--help` / `-h` — show CLI help -- `--version` / `-V` — print the CLI version and build info locally (no API call) -- `--output-format text|json` — choose non-interactive prompt output rendering -- `--allowedTools ` — restrict enabled tools for interactive sessions and prompt-mode tool use - -### Interactive slash commands - -- `/help` — show command help -- `/status` — show current session status -- `/compact` — compact local session history -- `/model [model]` — inspect or switch the active model -- `/permissions [read-only|workspace-write|danger-full-access]` — inspect or switch permissions -- `/clear [--confirm]` — clear the current local session -- `/cost` — show token usage totals -- `/resume ` — load a saved session into the REPL -- `/config [env|hooks|model]` — inspect discovered Claude config -- `/memory` — inspect loaded instruction memory files -- `/init` — bootstrap `.claude.json`, `.claude/`, `CLAUDE.md`, and local ignore rules -- `/diff` — show the current git diff for the workspace -- `/version` — print version and build metadata locally -- `/export [file]` — export the current conversation transcript -- `/sessions` — list recent managed local sessions from `~/.claude/sessions/` -- `/session [list|switch ]` — inspect or switch managed local sessions -- `/exit` — leave the REPL - -## Environment variables - -### Anthropic/API - -- `ANTHROPIC_API_KEY` — highest-precedence API credential -- `ANTHROPIC_AUTH_TOKEN` — bearer-token override used when no API key is set -- Persisted OAuth credentials in `~/.claude/credentials.json` — used when neither env var is set -- `ANTHROPIC_BASE_URL` — override the Anthropic API base URL -- `ANTHROPIC_MODEL` — default model used by selected live integration tests - -### CLI/runtime - -- `RUSTY_CLAUDE_PERMISSION_MODE` — default REPL permission mode (`read-only`, `workspace-write`, or `danger-full-access`) -- `CLAUDE_CONFIG_HOME` — override Claude config discovery root -- `CLAUDE_CODE_REMOTE` — enable remote-session bootstrap handling when supported -- `CLAUDE_CODE_REMOTE_SESSION_ID` — remote session identifier when using remote mode -- `CLAUDE_CODE_UPSTREAM` — override the upstream TS source path for compat-harness extraction -- `CLAWD_WEB_SEARCH_BASE_URL` — override the built-in web search service endpoint used by tooling - -## Notes - -- `compat-harness` exists to compare the Rust port against the upstream TypeScript codebase and is intentionally excluded from the requested release test run. -- The CLI currently focuses on a practical integrated workflow: prompt execution, REPL operation, session inspection/resume, config discovery, and tool/runtime plumbing. +有关许可详情,请参阅仓库根目录。 diff --git a/rust/crates/api/Cargo.toml b/rust/crates/api/Cargo.toml index c5e152e..b9923a8 100644 --- a/rust/crates/api/Cargo.toml +++ b/rust/crates/api/Cargo.toml @@ -9,7 +9,7 @@ publish.workspace = true reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } runtime = { path = "../runtime" } serde = { version = "1", features = ["derive"] } -serde_json = "1" +serde_json.workspace = true tokio = { version = "1", features = ["io-util", "macros", "net", "rt-multi-thread", "time"] } [lints] diff --git a/rust/crates/api/src/client.rs b/rust/crates/api/src/client.rs index 110a80b..b596777 100644 --- a/rust/crates/api/src/client.rs +++ b/rust/crates/api/src/client.rs @@ -1,1006 +1,141 @@ -use std::collections::VecDeque; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use runtime::{ - load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, - OAuthTokenExchangeRequest, -}; -use serde::Deserialize; - use crate::error::ApiError; -use crate::sse::SseParser; +use crate::providers::claw_provider::{self, AuthSource, ClawApiClient}; +use crate::providers::openai_compat::{self, OpenAiCompatClient, OpenAiCompatConfig}; +use crate::providers::{self, Provider, ProviderKind}; use crate::types::{MessageRequest, MessageResponse, StreamEvent}; -const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; -const ANTHROPIC_VERSION: &str = "2023-06-01"; -const REQUEST_ID_HEADER: &str = "request-id"; -const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; -const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); -const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); -const DEFAULT_MAX_RETRIES: u32 = 2; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum AuthSource { - None, - ApiKey(String), - BearerToken(String), - ApiKeyAndBearer { - api_key: String, - bearer_token: String, - }, +async fn send_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.send_message(request).await } -impl AuthSource { - pub fn from_env() -> Result { - let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; - let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; - match (api_key, auth_token) { - (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - (Some(api_key), None) => Ok(Self::ApiKey(api_key)), - (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), - (None, None) => Err(ApiError::MissingApiKey), - } - } - - #[must_use] - pub fn api_key(&self) -> Option<&str> { - match self { - Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), - Self::None | Self::BearerToken(_) => None, - } - } - - #[must_use] - pub fn bearer_token(&self) -> Option<&str> { - match self { - Self::BearerToken(token) - | Self::ApiKeyAndBearer { - bearer_token: token, - .. - } => Some(token), - Self::None | Self::ApiKey(_) => None, - } - } - - #[must_use] - pub fn masked_authorization_header(&self) -> &'static str { - if self.bearer_token().is_some() { - "Bearer [REDACTED]" - } else { - "" - } - } - - pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - if let Some(api_key) = self.api_key() { - request_builder = request_builder.header("x-api-key", api_key); - } - if let Some(token) = self.bearer_token() { - request_builder = request_builder.bearer_auth(token); - } - request_builder - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] -pub struct OAuthTokenSet { - pub access_token: String, - pub refresh_token: Option, - pub expires_at: Option, - #[serde(default)] - pub scopes: Vec, -} - -impl From for AuthSource { - fn from(value: OAuthTokenSet) -> Self { - Self::BearerToken(value.access_token) - } +async fn stream_via_provider( + provider: &P, + request: &MessageRequest, +) -> Result { + provider.stream_message(request).await } #[derive(Debug, Clone)] -pub struct AnthropicClient { - http: reqwest::Client, - auth: AuthSource, - base_url: String, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, +pub enum ProviderClient { + ClawApi(ClawApiClient), + Xai(OpenAiCompatClient), + OpenAi(OpenAiCompatClient), } -impl AnthropicClient { - #[must_use] - pub fn new(api_key: impl Into) -> Self { - Self { - http: reqwest::Client::new(), - auth: AuthSource::ApiKey(api_key.into()), - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, +impl ProviderClient { + pub fn from_model(model: &str) -> Result { + Self::from_model_with_default_auth(model, None) + } + + pub fn from_model_with_default_auth( + model: &str, + default_auth: Option, + ) -> Result { + let resolved_model = providers::resolve_model_alias(model); + match providers::detect_provider_kind(&resolved_model) { + ProviderKind::ClawApi => Ok(Self::ClawApi(match default_auth { + Some(auth) => ClawApiClient::from_auth(auth), + None => ClawApiClient::from_env()?, + })), + ProviderKind::Xai => Ok(Self::Xai(OpenAiCompatClient::from_env( + OpenAiCompatConfig::xai(), + )?)), + ProviderKind::OpenAi => Ok(Self::OpenAi(OpenAiCompatClient::from_env( + OpenAiCompatConfig::openai(), + )?)), } } #[must_use] - pub fn from_auth(auth: AuthSource) -> Self { - Self { - http: reqwest::Client::new(), - auth, - base_url: DEFAULT_BASE_URL.to_string(), - max_retries: DEFAULT_MAX_RETRIES, - initial_backoff: DEFAULT_INITIAL_BACKOFF, - max_backoff: DEFAULT_MAX_BACKOFF, + pub const fn provider_kind(&self) -> ProviderKind { + match self { + Self::ClawApi(_) => ProviderKind::ClawApi, + Self::Xai(_) => ProviderKind::Xai, + Self::OpenAi(_) => ProviderKind::OpenAi, } } - pub fn from_env() -> Result { - Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) - } - - #[must_use] - pub fn with_auth_source(mut self, auth: AuthSource) -> Self { - self.auth = auth; - self - } - - #[must_use] - pub fn with_auth_token(mut self, auth_token: Option) -> Self { - match ( - self.auth.api_key().map(ToOwned::to_owned), - auth_token.filter(|token| !token.is_empty()), - ) { - (Some(api_key), Some(bearer_token)) => { - self.auth = AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }; - } - (Some(api_key), None) => { - self.auth = AuthSource::ApiKey(api_key); - } - (None, Some(bearer_token)) => { - self.auth = AuthSource::BearerToken(bearer_token); - } - (None, None) => { - self.auth = AuthSource::None; - } - } - self - } - - #[must_use] - pub fn with_base_url(mut self, base_url: impl Into) -> Self { - self.base_url = base_url.into(); - self - } - - #[must_use] - pub fn with_retry_policy( - mut self, - max_retries: u32, - initial_backoff: Duration, - max_backoff: Duration, - ) -> Self { - self.max_retries = max_retries; - self.initial_backoff = initial_backoff; - self.max_backoff = max_backoff; - self - } - - #[must_use] - pub fn auth_source(&self) -> &AuthSource { - &self.auth - } - pub async fn send_message( &self, request: &MessageRequest, ) -> Result { - let request = MessageRequest { - stream: false, - ..request.clone() - }; - let response = self.send_with_retry(&request).await?; - let request_id = request_id_from_headers(response.headers()); - let mut response = response - .json::() - .await - .map_err(ApiError::from)?; - if response.request_id.is_none() { - response.request_id = request_id; + match self { + Self::ClawApi(client) => send_via_provider(client, request).await, + Self::Xai(client) | Self::OpenAi(client) => send_via_provider(client, request).await, } - Ok(response) } pub async fn stream_message( &self, request: &MessageRequest, ) -> Result { - let response = self - .send_with_retry(&request.clone().with_streaming()) - .await?; - Ok(MessageStream { - request_id: request_id_from_headers(response.headers()), - response, - parser: SseParser::new(), - pending: VecDeque::new(), - done: false, - }) - } - - pub async fn exchange_oauth_code( - &self, - config: &OAuthConfig, - request: &OAuthTokenExchangeRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - pub async fn refresh_oauth_token( - &self, - config: &OAuthConfig, - request: &OAuthRefreshRequest, - ) -> Result { - let response = self - .http - .post(&config.token_url) - .header("content-type", "application/x-www-form-urlencoded") - .form(&request.form_params()) - .send() - .await - .map_err(ApiError::from)?; - let response = expect_success(response).await?; - response - .json::() - .await - .map_err(ApiError::from) - } - - async fn send_with_retry( - &self, - request: &MessageRequest, - ) -> Result { - let mut attempts = 0; - let mut last_error: Option; - - loop { - attempts += 1; - match self.send_raw_request(request).await { - Ok(response) => match expect_success(response).await { - Ok(response) => return Ok(response), - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - }, - Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { - last_error = Some(error); - } - Err(error) => return Err(error), - } - - if attempts > self.max_retries { - break; - } - - tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; - } - - Err(ApiError::RetriesExhausted { - attempts, - last_error: Box::new(last_error.expect("retry loop must capture an error")), - }) - } - - async fn send_raw_request( - &self, - request: &MessageRequest, - ) -> Result { - let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); - let resolved_base_url = self.base_url.trim_end_matches('/'); - eprintln!("[anthropic-client] resolved_base_url={resolved_base_url}"); - eprintln!("[anthropic-client] request_url={request_url}"); - let request_builder = self - .http - .post(&request_url) - .header("anthropic-version", ANTHROPIC_VERSION) - .header("content-type", "application/json"); - let mut request_builder = self.auth.apply(request_builder); - - eprintln!( - "[anthropic-client] headers x-api-key={} authorization={} anthropic-version={ANTHROPIC_VERSION} content-type=application/json", - if self.auth.api_key().is_some() { - "[REDACTED]" - } else { - "" - }, - self.auth.masked_authorization_header() - ); - - request_builder = request_builder.json(request); - request_builder.send().await.map_err(ApiError::from) - } - - fn backoff_for_attempt(&self, attempt: u32) -> Result { - let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { - return Err(ApiError::BackoffOverflow { - attempt, - base_delay: self.initial_backoff, - }); - }; - Ok(self - .initial_backoff - .checked_mul(multiplier) - .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) - } -} - -impl AuthSource { - pub fn from_env_or_saved() -> Result { - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - Some(bearer_token) => Ok(Self::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(Self::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(Self::BearerToken(bearer_token)); - } - match load_saved_oauth_token() { - Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { - if token_set.refresh_token.is_some() { - Err(ApiError::Auth( - "saved OAuth token is expired; load runtime OAuth config to refresh it" - .to_string(), - )) - } else { - Err(ApiError::ExpiredOAuthToken) - } - } - Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), - Ok(None) => Err(ApiError::MissingApiKey), - Err(error) => Err(error), + match self { + Self::ClawApi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::ClawApi), + Self::Xai(client) | Self::OpenAi(client) => stream_via_provider(client, request) + .await + .map(MessageStream::OpenAiCompat), } } } -#[must_use] -pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { - token_set - .expires_at - .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) -} - -pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { - let Some(token_set) = load_saved_oauth_token()? else { - return Ok(None); - }; - resolve_saved_oauth_token_set(config, token_set).map(Some) -} - -pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result -where - F: FnOnce() -> Result, ApiError>, -{ - if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { - return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { - api_key, - bearer_token, - }), - None => Ok(AuthSource::ApiKey(api_key)), - }; - } - if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { - return Ok(AuthSource::BearerToken(bearer_token)); - } - - let Some(token_set) = load_saved_oauth_token()? else { - return Err(ApiError::MissingApiKey); - }; - if !oauth_token_is_expired(&token_set) { - return Ok(AuthSource::BearerToken(token_set.access_token)); - } - if token_set.refresh_token.is_none() { - return Err(ApiError::ExpiredOAuthToken); - } - - let Some(config) = load_oauth_config()? else { - return Err(ApiError::Auth( - "saved OAuth token is expired; runtime OAuth config is missing".to_string(), - )); - }; - Ok(AuthSource::from(resolve_saved_oauth_token_set( - &config, token_set, - )?)) -} - -fn resolve_saved_oauth_token_set( - config: &OAuthConfig, - token_set: OAuthTokenSet, -) -> Result { - if !oauth_token_is_expired(&token_set) { - return Ok(token_set); - } - let Some(refresh_token) = token_set.refresh_token.clone() else { - return Err(ApiError::ExpiredOAuthToken); - }; - let client = AnthropicClient::from_auth(AuthSource::None).with_base_url(read_base_url()); - let refreshed = client_runtime_block_on(async { - client - .refresh_oauth_token( - config, - &OAuthRefreshRequest::from_config( - config, - refresh_token, - Some(token_set.scopes.clone()), - ), - ) - .await - })?; - let resolved = OAuthTokenSet { - access_token: refreshed.access_token, - refresh_token: refreshed.refresh_token.or(token_set.refresh_token), - expires_at: refreshed.expires_at, - scopes: refreshed.scopes, - }; - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: resolved.access_token.clone(), - refresh_token: resolved.refresh_token.clone(), - expires_at: resolved.expires_at, - scopes: resolved.scopes.clone(), - }) - .map_err(ApiError::from)?; - Ok(resolved) -} - -fn client_runtime_block_on(future: F) -> Result -where - F: std::future::Future>, -{ - tokio::runtime::Runtime::new() - .map_err(ApiError::from)? - .block_on(future) -} - -fn load_saved_oauth_token() -> Result, ApiError> { - let token_set = load_oauth_credentials().map_err(ApiError::from)?; - Ok(token_set.map(|token_set| OAuthTokenSet { - access_token: token_set.access_token, - refresh_token: token_set.refresh_token, - expires_at: token_set.expires_at, - scopes: token_set.scopes, - })) -} - -fn now_unix_timestamp() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| duration.as_secs()) -} - -fn read_env_non_empty(key: &str) -> Result, ApiError> { - match std::env::var(key) { - Ok(value) if !value.is_empty() => Ok(Some(value)), - Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), - Err(error) => Err(ApiError::from(error)), - } -} - -#[cfg(test)] -fn read_api_key() -> Result { - let auth = AuthSource::from_env_or_saved()?; - auth.api_key() - .or_else(|| auth.bearer_token()) - .map(ToOwned::to_owned) - .ok_or(ApiError::MissingApiKey) -} - -#[cfg(test)] -fn read_auth_token() -> Option { - read_env_non_empty("ANTHROPIC_AUTH_TOKEN") - .ok() - .and_then(std::convert::identity) -} - -pub fn read_base_url() -> String { - std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) -} - -fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { - headers - .get(REQUEST_ID_HEADER) - .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) - .and_then(|value| value.to_str().ok()) - .map(ToOwned::to_owned) -} - #[derive(Debug)] -pub struct MessageStream { - request_id: Option, - response: reqwest::Response, - parser: SseParser, - pending: VecDeque, - done: bool, +pub enum MessageStream { + ClawApi(claw_provider::MessageStream), + OpenAiCompat(openai_compat::MessageStream), } impl MessageStream { #[must_use] pub fn request_id(&self) -> Option<&str> { - self.request_id.as_deref() + match self { + Self::ClawApi(stream) => stream.request_id(), + Self::OpenAiCompat(stream) => stream.request_id(), + } } pub async fn next_event(&mut self) -> Result, ApiError> { - loop { - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - - if self.done { - let remaining = self.parser.finish()?; - self.pending.extend(remaining); - if let Some(event) = self.pending.pop_front() { - return Ok(Some(event)); - } - return Ok(None); - } - - match self.response.chunk().await? { - Some(chunk) => { - self.pending.extend(self.parser.push(&chunk)?); - } - None => { - self.done = true; - } - } + match self { + Self::ClawApi(stream) => stream.next_event().await, + Self::OpenAiCompat(stream) => stream.next_event().await, } } } -async fn expect_success(response: reqwest::Response) -> Result { - let status = response.status(); - if status.is_success() { - return Ok(response); - } - - let body = response.text().await.unwrap_or_else(|_| String::new()); - let parsed_error = serde_json::from_str::(&body).ok(); - let retryable = is_retryable_status(status); - - Err(ApiError::Api { - status, - error_type: parsed_error - .as_ref() - .map(|error| error.error.error_type.clone()), - message: parsed_error - .as_ref() - .map(|error| error.error.message.clone()), - body, - retryable, - }) +pub use claw_provider::{ + oauth_token_is_expired, resolve_saved_oauth_token, resolve_startup_auth_source, OAuthTokenSet, +}; +#[must_use] +pub fn read_base_url() -> String { + claw_provider::read_base_url() } -const fn is_retryable_status(status: reqwest::StatusCode) -> bool { - matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorEnvelope { - error: AnthropicErrorBody, -} - -#[derive(Debug, Deserialize)] -struct AnthropicErrorBody { - #[serde(rename = "type")] - error_type: String, - message: String, +#[must_use] +pub fn read_xai_base_url() -> String { + openai_compat::read_base_url(OpenAiCompatConfig::xai()) } #[cfg(test)] mod tests { - use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; - use std::io::{Read, Write}; - use std::net::TcpListener; - use std::sync::{Mutex, OnceLock}; - use std::thread; - use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use crate::providers::{detect_provider_kind, resolve_model_alias, ProviderKind}; - use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; - - use crate::client::{ - now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, - resolve_startup_auth_source, AnthropicClient, AuthSource, OAuthTokenSet, - }; - use crate::types::{ContentBlockDelta, MessageRequest}; - - fn env_lock() -> std::sync::MutexGuard<'static, ()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) - .lock() - .expect("env lock") - } - - fn temp_config_home() -> std::path::PathBuf { - std::env::temp_dir().join(format!( - "api-oauth-test-{}-{}", - std::process::id(), - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time") - .as_nanos() - )) - } - - fn sample_oauth_config(token_url: String) -> OAuthConfig { - OAuthConfig { - client_id: "runtime-client".to_string(), - authorize_url: "https://console.test/oauth/authorize".to_string(), - token_url, - callback_port: Some(4545), - manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), - scopes: vec!["org:read".to_string(), "user:write".to_string()], - } - } - - fn spawn_token_server(response_body: &'static str) -> String { - let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); - let address = listener.local_addr().expect("local addr"); - thread::spawn(move || { - let (mut stream, _) = listener.accept().expect("accept connection"); - let mut buffer = [0_u8; 4096]; - let _ = stream.read(&mut buffer).expect("read request"); - let response = format!( - "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", - response_body.len(), - response_body - ); - stream - .write_all(response.as_bytes()) - .expect("write response"); - }); - format!("http://{address}/oauth/token") + #[test] + fn resolves_existing_and_grok_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); } #[test] - fn read_api_key_requires_presence() { - let _guard = env_lock(); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - let error = super::read_api_key().expect_err("missing key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - } - - #[test] - fn read_api_key_requires_non_empty_value() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); - std::env::remove_var("ANTHROPIC_API_KEY"); - let error = super::read_api_key().expect_err("empty key should error"); - assert!(matches!(error, crate::error::ApiError::MissingApiKey)); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn read_api_key_prefers_api_key_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + fn provider_detection_prefers_model_family() { + assert_eq!(detect_provider_kind("grok-3"), ProviderKind::Xai); assert_eq!( - super::read_api_key().expect("api key should load"), - "legacy-key" - ); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn read_auth_token_reads_auth_token_env() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - } - - #[test] - fn oauth_token_maps_to_bearer_auth_source() { - let auth = AuthSource::from(OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(123), - scopes: vec!["scope:a".to_string()], - }); - assert_eq!(auth.bearer_token(), Some("access-token")); - assert_eq!(auth.api_key(), None); - } - - #[test] - fn auth_source_from_env_combines_api_key_and_bearer_token() { - let _guard = env_lock(); - std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); - std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); - let auth = AuthSource::from_env().expect("env auth"); - assert_eq!(auth.api_key(), Some("legacy-key")); - assert_eq!(auth.bearer_token(), Some("auth-token")); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - } - - #[test] - fn auth_source_from_saved_oauth_when_env_absent() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = AuthSource::from_env_or_saved().expect("saved auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn oauth_token_expiry_uses_expires_at_timestamp() { - assert!(oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(1), - scopes: Vec::new(), - })); - assert!(!oauth_token_is_expired(&OAuthTokenSet { - access_token: "access-token".to_string(), - refresh_token: None, - expires_at: Some(now_unix_timestamp() + 60), - scopes: Vec::new(), - })); - } - - #[test] - fn resolve_saved_oauth_token_refreshes_expired_credentials() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "refreshed-token"); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "saved-access-token".to_string(), - refresh_token: Some("refresh".to_string()), - expires_at: Some(now_unix_timestamp() + 300), - scopes: vec!["scope:a".to_string()], - }) - .expect("save oauth credentials"); - - let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) - .expect("startup auth"); - assert_eq!(auth.bearer_token(), Some("saved-access-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let error = - resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); - assert!( - matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) - ); - - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.access_token, "expired-access-token"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { - let _guard = env_lock(); - let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); - std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); - std::env::remove_var("ANTHROPIC_API_KEY"); - save_oauth_credentials(&runtime::OAuthTokenSet { - access_token: "expired-access-token".to_string(), - refresh_token: Some("refresh-token".to_string()), - expires_at: Some(1), - scopes: vec!["scope:a".to_string()], - }) - .expect("save expired oauth credentials"); - - let token_url = spawn_token_server( - "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", - ); - let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) - .expect("resolve refreshed token") - .expect("token set present"); - assert_eq!(resolved.access_token, "refreshed-token"); - assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); - let stored = runtime::load_oauth_credentials() - .expect("load stored credentials") - .expect("stored token set"); - assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); - - clear_oauth_credentials().expect("clear credentials"); - std::env::remove_var("CLAUDE_CONFIG_HOME"); - std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); - } - - #[test] - fn message_request_stream_helper_sets_stream_true() { - let request = MessageRequest { - model: "claude-opus-4-6".to_string(), - max_tokens: 64, - messages: vec![], - system: None, - tools: None, - tool_choice: None, - stream: false, - }; - - assert!(request.with_streaming().stream); - } - - #[test] - fn backoff_doubles_until_maximum() { - let client = AnthropicClient::new("test-key").with_retry_policy( - 3, - Duration::from_millis(10), - Duration::from_millis(25), - ); - assert_eq!( - client.backoff_for_attempt(1).expect("attempt 1"), - Duration::from_millis(10) - ); - assert_eq!( - client.backoff_for_attempt(2).expect("attempt 2"), - Duration::from_millis(20) - ); - assert_eq!( - client.backoff_for_attempt(3).expect("attempt 3"), - Duration::from_millis(25) - ); - } - - #[test] - fn retryable_statuses_are_detected() { - assert!(super::is_retryable_status( - reqwest::StatusCode::TOO_MANY_REQUESTS - )); - assert!(super::is_retryable_status( - reqwest::StatusCode::INTERNAL_SERVER_ERROR - )); - assert!(!super::is_retryable_status( - reqwest::StatusCode::UNAUTHORIZED - )); - } - - #[test] - fn tool_delta_variant_round_trips() { - let delta = ContentBlockDelta::InputJsonDelta { - partial_json: "{\"city\":\"Paris\"}".to_string(), - }; - let encoded = serde_json::to_string(&delta).expect("delta should serialize"); - let decoded: ContentBlockDelta = - serde_json::from_str(&encoded).expect("delta should deserialize"); - assert_eq!(decoded, delta); - } - - #[test] - fn request_id_uses_primary_or_fallback_header() { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_primary") - ); - - headers.clear(); - headers.insert( - ALT_REQUEST_ID_HEADER, - "req_fallback".parse().expect("header"), - ); - assert_eq!( - super::request_id_from_headers(&headers).as_deref(), - Some("req_fallback") - ); - } - - #[test] - fn auth_source_applies_headers() { - let auth = AuthSource::ApiKeyAndBearer { - api_key: "test-key".to_string(), - bearer_token: "proxy-token".to_string(), - }; - let request = auth - .apply(reqwest::Client::new().post("https://example.test")) - .build() - .expect("request build"); - let headers = request.headers(); - assert_eq!( - headers.get("x-api-key").and_then(|v| v.to_str().ok()), - Some("test-key") - ); - assert_eq!( - headers.get("authorization").and_then(|v| v.to_str().ok()), - Some("Bearer proxy-token") + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::ClawApi ); } } diff --git a/rust/crates/api/src/error.rs b/rust/crates/api/src/error.rs index 2c31691..7649889 100644 --- a/rust/crates/api/src/error.rs +++ b/rust/crates/api/src/error.rs @@ -4,7 +4,10 @@ use std::time::Duration; #[derive(Debug)] pub enum ApiError { - MissingApiKey, + MissingCredentials { + provider: &'static str, + env_vars: &'static [&'static str], + }, ExpiredOAuthToken, Auth(String), InvalidApiKeyEnv(VarError), @@ -30,13 +33,21 @@ pub enum ApiError { } impl ApiError { + #[must_use] + pub const fn missing_credentials( + provider: &'static str, + env_vars: &'static [&'static str], + ) -> Self { + Self::MissingCredentials { provider, env_vars } + } + #[must_use] pub fn is_retryable(&self) -> bool { match self { Self::Http(error) => error.is_connect() || error.is_timeout() || error.is_request(), Self::Api { retryable, .. } => *retryable, Self::RetriesExhausted { last_error, .. } => last_error.is_retryable(), - Self::MissingApiKey + Self::MissingCredentials { .. } | Self::ExpiredOAuthToken | Self::Auth(_) | Self::InvalidApiKeyEnv(_) @@ -51,12 +62,11 @@ impl ApiError { impl Display for ApiError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::MissingApiKey => { - write!( - f, - "ANTHROPIC_AUTH_TOKEN or ANTHROPIC_API_KEY is not set; export one before calling the Anthropic API" - ) - } + Self::MissingCredentials { provider, env_vars } => write!( + f, + "missing {provider} credentials; export {} before calling the {provider} API", + env_vars.join(" or ") + ), Self::ExpiredOAuthToken => { write!( f, @@ -65,10 +75,7 @@ impl Display for ApiError { } Self::Auth(message) => write!(f, "auth error: {message}"), Self::InvalidApiKeyEnv(error) => { - write!( - f, - "failed to read ANTHROPIC_AUTH_TOKEN / ANTHROPIC_API_KEY: {error}" - ) + write!(f, "failed to read credential environment variable: {error}") } Self::Http(error) => write!(f, "http error: {error}"), Self::Io(error) => write!(f, "io error: {error}"), @@ -81,20 +88,14 @@ impl Display for ApiError { .. } => match (error_type, message) { (Some(error_type), Some(message)) => { - write!( - f, - "anthropic api returned {status} ({error_type}): {message}" - ) + write!(f, "api returned {status} ({error_type}): {message}") } - _ => write!(f, "anthropic api returned {status}: {body}"), + _ => write!(f, "api returned {status}: {body}"), }, Self::RetriesExhausted { attempts, last_error, - } => write!( - f, - "anthropic api failed after {attempts} attempts: {last_error}" - ), + } => write!(f, "api failed after {attempts} attempts: {last_error}"), Self::InvalidSseFrame(message) => write!(f, "invalid sse frame: {message}"), Self::BackoffOverflow { attempt, diff --git a/rust/crates/api/src/lib.rs b/rust/crates/api/src/lib.rs index a91344b..3306f53 100644 --- a/rust/crates/api/src/lib.rs +++ b/rust/crates/api/src/lib.rs @@ -1,13 +1,19 @@ mod client; mod error; +mod providers; mod sse; mod types; pub use client::{ - oauth_token_is_expired, read_base_url, resolve_saved_oauth_token, - resolve_startup_auth_source, AnthropicClient, AuthSource, MessageStream, OAuthTokenSet, + oauth_token_is_expired, read_base_url, read_xai_base_url, resolve_saved_oauth_token, + resolve_startup_auth_source, MessageStream, OAuthTokenSet, ProviderClient, }; pub use error::ApiError; +pub use providers::claw_provider::{AuthSource, ClawApiClient, ClawApiClient as ApiClient}; +pub use providers::openai_compat::{OpenAiCompatClient, OpenAiCompatConfig}; +pub use providers::{ + detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind, +}; pub use sse::{parse_frame, SseParser}; pub use types::{ ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, diff --git a/rust/crates/api/src/providers/claw_provider.rs b/rust/crates/api/src/providers/claw_provider.rs new file mode 100644 index 0000000..d9046cd --- /dev/null +++ b/rust/crates/api/src/providers/claw_provider.rs @@ -0,0 +1,1046 @@ +use std::collections::VecDeque; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use runtime::{ + load_oauth_credentials, save_oauth_credentials, OAuthConfig, OAuthRefreshRequest, + OAuthTokenExchangeRequest, +}; +use serde::Deserialize; + +use crate::error::ApiError; + +use super::{Provider, ProviderFuture}; +use crate::sse::SseParser; +use crate::types::{MessageRequest, MessageResponse, StreamEvent}; + +pub const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AuthSource { + None, + ApiKey(String), + BearerToken(String), + ApiKeyAndBearer { + api_key: String, + bearer_token: String, + }, +} + +impl AuthSource { + pub fn from_env() -> Result { + let api_key = read_env_non_empty("ANTHROPIC_API_KEY")?; + let auth_token = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?; + match (api_key, auth_token) { + (Some(api_key), Some(bearer_token)) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + (Some(api_key), None) => Ok(Self::ApiKey(api_key)), + (None, Some(bearer_token)) => Ok(Self::BearerToken(bearer_token)), + (None, None) => Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + } + } + + #[must_use] + pub fn api_key(&self) -> Option<&str> { + match self { + Self::ApiKey(api_key) | Self::ApiKeyAndBearer { api_key, .. } => Some(api_key), + Self::None | Self::BearerToken(_) => None, + } + } + + #[must_use] + pub fn bearer_token(&self) -> Option<&str> { + match self { + Self::BearerToken(token) + | Self::ApiKeyAndBearer { + bearer_token: token, + .. + } => Some(token), + Self::None | Self::ApiKey(_) => None, + } + } + + #[must_use] + pub fn masked_authorization_header(&self) -> &'static str { + if self.bearer_token().is_some() { + "Bearer [REDACTED]" + } else { + "" + } + } + + pub fn apply(&self, mut request_builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder { + if let Some(api_key) = self.api_key() { + request_builder = request_builder.header("x-api-key", api_key); + } + if let Some(token) = self.bearer_token() { + request_builder = request_builder.bearer_auth(token); + } + request_builder + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + #[serde(default)] + pub scopes: Vec, +} + +impl From for AuthSource { + fn from(value: OAuthTokenSet) -> Self { + Self::BearerToken(value.access_token) + } +} + +#[derive(Debug, Clone)] +pub struct ClawApiClient { + http: reqwest::Client, + auth: AuthSource, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl ClawApiClient { + #[must_use] + pub fn new(api_key: impl Into) -> Self { + Self { + http: reqwest::Client::new(), + auth: AuthSource::ApiKey(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + #[must_use] + pub fn from_auth(auth: AuthSource) -> Self { + Self { + http: reqwest::Client::new(), + auth, + base_url: DEFAULT_BASE_URL.to_string(), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env() -> Result { + Ok(Self::from_auth(AuthSource::from_env_or_saved()?).with_base_url(read_base_url())) + } + + #[must_use] + pub fn with_auth_source(mut self, auth: AuthSource) -> Self { + self.auth = auth; + self + } + + #[must_use] + pub fn with_auth_token(mut self, auth_token: Option) -> Self { + match ( + self.auth.api_key().map(ToOwned::to_owned), + auth_token.filter(|token| !token.is_empty()), + ) { + (Some(api_key), Some(bearer_token)) => { + self.auth = AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }; + } + (Some(api_key), None) => { + self.auth = AuthSource::ApiKey(api_key); + } + (None, Some(bearer_token)) => { + self.auth = AuthSource::BearerToken(bearer_token); + } + (None, None) => { + self.auth = AuthSource::None; + } + } + self + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + #[must_use] + pub fn auth_source(&self) -> &AuthSource { + &self.auth + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let mut response = response + .json::() + .await + .map_err(ApiError::from)?; + if response.request_id.is_none() { + response.request_id = request_id; + } + Ok(response) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: SseParser::new(), + pending: VecDeque::new(), + done: false, + }) + } + + pub async fn exchange_oauth_code( + &self, + config: &OAuthConfig, + request: &OAuthTokenExchangeRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + pub async fn refresh_oauth_token( + &self, + config: &OAuthConfig, + request: &OAuthRefreshRequest, + ) -> Result { + let response = self + .http + .post(&config.token_url) + .header("content-type", "application/x-www-form-urlencoded") + .form(&request.form_params()) + .send() + .await + .map_err(ApiError::from)?; + let response = expect_success(response).await?; + response + .json::() + .await + .map_err(ApiError::from) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + let mut last_error: Option; + + loop { + attempts += 1; + match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => { + last_error = Some(error); + } + Err(error) => return Err(error), + } + + if attempts > self.max_retries { + break; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + } + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error.expect("retry loop must capture an error")), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = format!("{}/v1/messages", self.base_url.trim_end_matches('/')); + let request_builder = self + .http + .post(&request_url) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json"); + let mut request_builder = self.auth.apply(request_builder); + + request_builder = request_builder.json(request); + request_builder.send().await.map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl AuthSource { + pub fn from_env_or_saved() -> Result { + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(Self::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(Self::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(Self::BearerToken(bearer_token)); + } + match load_saved_oauth_token() { + Ok(Some(token_set)) if oauth_token_is_expired(&token_set) => { + if token_set.refresh_token.is_some() { + Err(ApiError::Auth( + "saved OAuth token is expired; load runtime OAuth config to refresh it" + .to_string(), + )) + } else { + Err(ApiError::ExpiredOAuthToken) + } + } + Ok(Some(token_set)) => Ok(Self::BearerToken(token_set.access_token)), + Ok(None) => Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )), + Err(error) => Err(error), + } + } +} + +#[must_use] +pub fn oauth_token_is_expired(token_set: &OAuthTokenSet) -> bool { + token_set + .expires_at + .is_some_and(|expires_at| expires_at <= now_unix_timestamp()) +} + +pub fn resolve_saved_oauth_token(config: &OAuthConfig) -> Result, ApiError> { + let Some(token_set) = load_saved_oauth_token()? else { + return Ok(None); + }; + resolve_saved_oauth_token_set(config, token_set).map(Some) +} + +pub fn has_auth_from_env_or_saved() -> Result { + Ok(read_env_non_empty("ANTHROPIC_API_KEY")?.is_some() + || read_env_non_empty("ANTHROPIC_AUTH_TOKEN")?.is_some() + || load_saved_oauth_token()?.is_some()) +} + +pub fn resolve_startup_auth_source(load_oauth_config: F) -> Result +where + F: FnOnce() -> Result, ApiError>, +{ + if let Some(api_key) = read_env_non_empty("ANTHROPIC_API_KEY")? { + return match read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + Some(bearer_token) => Ok(AuthSource::ApiKeyAndBearer { + api_key, + bearer_token, + }), + None => Ok(AuthSource::ApiKey(api_key)), + }; + } + if let Some(bearer_token) = read_env_non_empty("ANTHROPIC_AUTH_TOKEN")? { + return Ok(AuthSource::BearerToken(bearer_token)); + } + + let Some(token_set) = load_saved_oauth_token()? else { + return Err(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )); + }; + if !oauth_token_is_expired(&token_set) { + return Ok(AuthSource::BearerToken(token_set.access_token)); + } + if token_set.refresh_token.is_none() { + return Err(ApiError::ExpiredOAuthToken); + } + + let Some(config) = load_oauth_config()? else { + return Err(ApiError::Auth( + "saved OAuth token is expired; runtime OAuth config is missing".to_string(), + )); + }; + Ok(AuthSource::from(resolve_saved_oauth_token_set( + &config, token_set, + )?)) +} + +fn resolve_saved_oauth_token_set( + config: &OAuthConfig, + token_set: OAuthTokenSet, +) -> Result { + if !oauth_token_is_expired(&token_set) { + return Ok(token_set); + } + let Some(refresh_token) = token_set.refresh_token.clone() else { + return Err(ApiError::ExpiredOAuthToken); + }; + let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(read_base_url()); + let refreshed = client_runtime_block_on(async { + client + .refresh_oauth_token( + config, + &OAuthRefreshRequest::from_config( + config, + refresh_token, + Some(token_set.scopes.clone()), + ), + ) + .await + })?; + let resolved = OAuthTokenSet { + access_token: refreshed.access_token, + refresh_token: refreshed.refresh_token.or(token_set.refresh_token), + expires_at: refreshed.expires_at, + scopes: refreshed.scopes, + }; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: resolved.access_token.clone(), + refresh_token: resolved.refresh_token.clone(), + expires_at: resolved.expires_at, + scopes: resolved.scopes.clone(), + }) + .map_err(ApiError::from)?; + Ok(resolved) +} + +fn client_runtime_block_on(future: F) -> Result +where + F: std::future::Future>, +{ + tokio::runtime::Runtime::new() + .map_err(ApiError::from)? + .block_on(future) +} + +fn load_saved_oauth_token() -> Result, ApiError> { + let token_set = load_oauth_credentials().map_err(ApiError::from)?; + Ok(token_set.map(|token_set| OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })) +} + +fn now_unix_timestamp() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[cfg(test)] +fn read_api_key() -> Result { + let auth = AuthSource::from_env_or_saved()?; + auth.api_key() + .or_else(|| auth.bearer_token()) + .map(ToOwned::to_owned) + .ok_or(ApiError::missing_credentials( + "Claw", + &["ANTHROPIC_AUTH_TOKEN", "ANTHROPIC_API_KEY"], + )) +} + +#[cfg(test)] +fn read_auth_token() -> Option { + read_env_non_empty("ANTHROPIC_AUTH_TOKEN") + .ok() + .and_then(std::convert::identity) +} + +#[must_use] +pub fn read_base_url() -> String { + std::env::var("ANTHROPIC_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string()) +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +impl Provider for ClawApiClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: SseParser, + pending: VecDeque, + done: bool, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + let remaining = self.parser.finish()?; + self.pending.extend(remaining); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + self.pending.extend(self.parser.push(&chunk)?); + } + None => { + self.done = true; + } + } + } + } +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_else(|_| String::new()); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .map(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .map(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +#[derive(Debug, Deserialize)] +struct ApiErrorEnvelope { + error: ApiErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ApiErrorBody { + #[serde(rename = "type")] + error_type: String, + message: String, +} + +#[cfg(test)] +mod tests { + use super::{ALT_REQUEST_ID_HEADER, REQUEST_ID_HEADER}; + use std::io::{Read, Write}; + use std::net::TcpListener; + use std::sync::{Mutex, OnceLock}; + use std::thread; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use runtime::{clear_oauth_credentials, save_oauth_credentials, OAuthConfig}; + + use super::{ + now_unix_timestamp, oauth_token_is_expired, resolve_saved_oauth_token, + resolve_startup_auth_source, AuthSource, ClawApiClient, OAuthTokenSet, + }; + use crate::types::{ContentBlockDelta, MessageRequest}; + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + } + + fn temp_config_home() -> std::path::PathBuf { + std::env::temp_dir().join(format!( + "api-oauth-test-{}-{}", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_nanos() + )) + } + + fn cleanup_temp_config_home(config_home: &std::path::Path) { + match std::fs::remove_dir_all(config_home) { + Ok(()) => {} + Err(error) if error.kind() == std::io::ErrorKind::NotFound => {} + Err(error) => panic!("cleanup temp dir: {error}"), + } + } + + fn sample_oauth_config(token_url: String) -> OAuthConfig { + OAuthConfig { + client_id: "runtime-client".to_string(), + authorize_url: "https://console.test/oauth/authorize".to_string(), + token_url, + callback_port: Some(4545), + manual_redirect_url: Some("https://console.test/oauth/callback".to_string()), + scopes: vec!["org:read".to_string(), "user:write".to_string()], + } + } + + fn spawn_token_server(response_body: &'static str) -> String { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let address = listener.local_addr().expect("local addr"); + thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept connection"); + let mut buffer = [0_u8; 4096]; + let _ = stream.read(&mut buffer).expect("read request"); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\n\r\n{}", + response_body.len(), + response_body + ); + stream + .write_all(response.as_bytes()) + .expect("write response"); + }); + format!("http://{address}/oauth/token") + } + + #[test] + fn read_api_key_requires_presence() { + let _guard = env_lock(); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + std::env::remove_var("CLAW_CONFIG_HOME"); + let error = super::read_api_key().expect_err("missing key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + } + + #[test] + fn read_api_key_requires_non_empty_value() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", ""); + std::env::remove_var("ANTHROPIC_API_KEY"); + let error = super::read_api_key().expect_err("empty key should error"); + assert!(matches!( + error, + crate::error::ApiError::MissingCredentials { .. } + )); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn read_api_key_prefers_api_key_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + assert_eq!( + super::read_api_key().expect("api key should load"), + "legacy-key" + ); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn read_auth_token_reads_auth_token_env() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + assert_eq!(super::read_auth_token().as_deref(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + } + + #[test] + fn oauth_token_maps_to_bearer_auth_source() { + let auth = AuthSource::from(OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(123), + scopes: vec!["scope:a".to_string()], + }); + assert_eq!(auth.bearer_token(), Some("access-token")); + assert_eq!(auth.api_key(), None); + } + + #[test] + fn auth_source_from_env_combines_api_key_and_bearer_token() { + let _guard = env_lock(); + std::env::set_var("ANTHROPIC_AUTH_TOKEN", "auth-token"); + std::env::set_var("ANTHROPIC_API_KEY", "legacy-key"); + let auth = AuthSource::from_env().expect("env auth"); + assert_eq!(auth.api_key(), Some("legacy-key")); + assert_eq!(auth.bearer_token(), Some("auth-token")); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + } + + #[test] + fn auth_source_from_saved_oauth_when_env_absent() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = AuthSource::from_env_or_saved().expect("saved auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn oauth_token_expiry_uses_expires_at_timestamp() { + assert!(oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + scopes: Vec::new(), + })); + assert!(!oauth_token_is_expired(&OAuthTokenSet { + access_token: "access-token".to_string(), + refresh_token: None, + expires_at: Some(now_unix_timestamp() + 60), + scopes: Vec::new(), + })); + } + + #[test] + fn resolve_saved_oauth_token_refreshes_expired_credentials() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"refresh_token\":\"fresh-refresh\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "refreshed-token"); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_startup_auth_source_uses_saved_oauth_without_loading_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "saved-access-token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(now_unix_timestamp() + 300), + scopes: vec!["scope:a".to_string()], + }) + .expect("save oauth credentials"); + + let auth = resolve_startup_auth_source(|| panic!("config should not be loaded")) + .expect("startup auth"); + assert_eq!(auth.bearer_token(), Some("saved-access-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_startup_auth_source_errors_when_refreshable_token_lacks_config() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let error = + resolve_startup_auth_source(|| Ok(None)).expect_err("missing config should error"); + assert!( + matches!(error, crate::error::ApiError::Auth(message) if message.contains("runtime OAuth config is missing")) + ); + + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.access_token, "expired-access-token"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn resolve_saved_oauth_token_preserves_refresh_token_when_refresh_response_omits_it() { + let _guard = env_lock(); + let config_home = temp_config_home(); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); + std::env::remove_var("ANTHROPIC_AUTH_TOKEN"); + std::env::remove_var("ANTHROPIC_API_KEY"); + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: "expired-access-token".to_string(), + refresh_token: Some("refresh-token".to_string()), + expires_at: Some(1), + scopes: vec!["scope:a".to_string()], + }) + .expect("save expired oauth credentials"); + + let token_url = spawn_token_server( + "{\"access_token\":\"refreshed-token\",\"expires_at\":9999999999,\"scopes\":[\"scope:a\"]}", + ); + let resolved = resolve_saved_oauth_token(&sample_oauth_config(token_url)) + .expect("resolve refreshed token") + .expect("token set present"); + assert_eq!(resolved.access_token, "refreshed-token"); + assert_eq!(resolved.refresh_token.as_deref(), Some("refresh-token")); + let stored = runtime::load_oauth_credentials() + .expect("load stored credentials") + .expect("stored token set"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + clear_oauth_credentials().expect("clear credentials"); + std::env::remove_var("CLAW_CONFIG_HOME"); + cleanup_temp_config_home(&config_home); + } + + #[test] + fn message_request_stream_helper_sets_stream_true() { + let request = MessageRequest { + model: "claude-opus-4-6".to_string(), + max_tokens: 64, + messages: vec![], + system: None, + tools: None, + tool_choice: None, + stream: false, + }; + + assert!(request.with_streaming().stream); + } + + #[test] + fn backoff_doubles_until_maximum() { + let client = ClawApiClient::new("test-key").with_retry_policy( + 3, + Duration::from_millis(10), + Duration::from_millis(25), + ); + assert_eq!( + client.backoff_for_attempt(1).expect("attempt 1"), + Duration::from_millis(10) + ); + assert_eq!( + client.backoff_for_attempt(2).expect("attempt 2"), + Duration::from_millis(20) + ); + assert_eq!( + client.backoff_for_attempt(3).expect("attempt 3"), + Duration::from_millis(25) + ); + } + + #[test] + fn retryable_statuses_are_detected() { + assert!(super::is_retryable_status( + reqwest::StatusCode::TOO_MANY_REQUESTS + )); + assert!(super::is_retryable_status( + reqwest::StatusCode::INTERNAL_SERVER_ERROR + )); + assert!(!super::is_retryable_status( + reqwest::StatusCode::UNAUTHORIZED + )); + } + + #[test] + fn tool_delta_variant_round_trips() { + let delta = ContentBlockDelta::InputJsonDelta { + partial_json: "{\"city\":\"Paris\"}".to_string(), + }; + let encoded = serde_json::to_string(&delta).expect("delta should serialize"); + let decoded: ContentBlockDelta = + serde_json::from_str(&encoded).expect("delta should deserialize"); + assert_eq!(decoded, delta); + } + + #[test] + fn request_id_uses_primary_or_fallback_header() { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert(REQUEST_ID_HEADER, "req_primary".parse().expect("header")); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_primary") + ); + + headers.clear(); + headers.insert( + ALT_REQUEST_ID_HEADER, + "req_fallback".parse().expect("header"), + ); + assert_eq!( + super::request_id_from_headers(&headers).as_deref(), + Some("req_fallback") + ); + } + + #[test] + fn auth_source_applies_headers() { + let auth = AuthSource::ApiKeyAndBearer { + api_key: "test-key".to_string(), + bearer_token: "proxy-token".to_string(), + }; + let request = auth + .apply(reqwest::Client::new().post("https://example.test")) + .build() + .expect("request build"); + let headers = request.headers(); + assert_eq!( + headers.get("x-api-key").and_then(|v| v.to_str().ok()), + Some("test-key") + ); + assert_eq!( + headers.get("authorization").and_then(|v| v.to_str().ok()), + Some("Bearer proxy-token") + ); + } +} diff --git a/rust/crates/api/src/providers/mod.rs b/rust/crates/api/src/providers/mod.rs new file mode 100644 index 0000000..192afd6 --- /dev/null +++ b/rust/crates/api/src/providers/mod.rs @@ -0,0 +1,239 @@ +use std::future::Future; +use std::pin::Pin; + +use crate::error::ApiError; +use crate::types::{MessageRequest, MessageResponse}; + +pub mod claw_provider; +pub mod openai_compat; + +pub type ProviderFuture<'a, T> = Pin> + Send + 'a>>; + +pub trait Provider { + type Stream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse>; + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream>; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProviderKind { + ClawApi, + Xai, + OpenAi, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ProviderMetadata { + pub provider: ProviderKind, + pub auth_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const MODEL_REGISTRY: &[(&str, ProviderMetadata)] = &[ + ( + "opus", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "sonnet", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "haiku", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-opus-4-6", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-sonnet-4-6", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "claude-haiku-4-5-20251213", + ProviderMetadata { + provider: ProviderKind::ClawApi, + auth_env: "ANTHROPIC_API_KEY", + base_url_env: "ANTHROPIC_BASE_URL", + default_base_url: claw_provider::DEFAULT_BASE_URL, + }, + ), + ( + "grok", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-3-mini", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), + ( + "grok-2", + ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }, + ), +]; + +#[must_use] +pub fn resolve_model_alias(model: &str) -> String { + let trimmed = model.trim(); + let lower = trimmed.to_ascii_lowercase(); + MODEL_REGISTRY + .iter() + .find_map(|(alias, metadata)| { + (*alias == lower).then_some(match metadata.provider { + ProviderKind::ClawApi => match *alias { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => trimmed, + }, + ProviderKind::Xai => match *alias { + "grok" | "grok-3" => "grok-3", + "grok-mini" | "grok-3-mini" => "grok-3-mini", + "grok-2" => "grok-2", + _ => trimmed, + }, + ProviderKind::OpenAi => trimmed, + }) + }) + .map_or_else(|| trimmed.to_string(), ToOwned::to_owned) +} + +#[must_use] +pub fn metadata_for_model(model: &str) -> Option { + let canonical = resolve_model_alias(model); + let lower = canonical.to_ascii_lowercase(); + if let Some((_, metadata)) = MODEL_REGISTRY.iter().find(|(alias, _)| *alias == lower) { + return Some(*metadata); + } + if lower.starts_with("grok") { + return Some(ProviderMetadata { + provider: ProviderKind::Xai, + auth_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: openai_compat::DEFAULT_XAI_BASE_URL, + }); + } + None +} + +#[must_use] +pub fn detect_provider_kind(model: &str) -> ProviderKind { + if let Some(metadata) = metadata_for_model(model) { + return metadata.provider; + } + if claw_provider::has_auth_from_env_or_saved().unwrap_or(false) { + return ProviderKind::ClawApi; + } + if openai_compat::has_api_key("OPENAI_API_KEY") { + return ProviderKind::OpenAi; + } + if openai_compat::has_api_key("XAI_API_KEY") { + return ProviderKind::Xai; + } + ProviderKind::ClawApi +} + +#[must_use] +pub fn max_tokens_for_model(model: &str) -> u32 { + let canonical = resolve_model_alias(model); + if canonical.contains("opus") { + 32_000 + } else { + 64_000 + } +} + +#[cfg(test)] +mod tests { + use super::{detect_provider_kind, max_tokens_for_model, resolve_model_alias, ProviderKind}; + + #[test] + fn resolves_grok_aliases() { + assert_eq!(resolve_model_alias("grok"), "grok-3"); + assert_eq!(resolve_model_alias("grok-mini"), "grok-3-mini"); + assert_eq!(resolve_model_alias("grok-2"), "grok-2"); + } + + #[test] + fn detects_provider_from_model_name_first() { + assert_eq!(detect_provider_kind("grok"), ProviderKind::Xai); + assert_eq!( + detect_provider_kind("claude-sonnet-4-6"), + ProviderKind::ClawApi + ); + } + + #[test] + fn keeps_existing_max_token_heuristic() { + assert_eq!(max_tokens_for_model("opus"), 32_000); + assert_eq!(max_tokens_for_model("grok-3"), 64_000); + } +} diff --git a/rust/crates/api/src/providers/openai_compat.rs b/rust/crates/api/src/providers/openai_compat.rs new file mode 100644 index 0000000..e8210ae --- /dev/null +++ b/rust/crates/api/src/providers/openai_compat.rs @@ -0,0 +1,1050 @@ +use std::collections::{BTreeMap, VecDeque}; +use std::time::Duration; + +use serde::Deserialize; +use serde_json::{json, Value}; + +use crate::error::ApiError; +use crate::types::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageDelta, MessageDeltaEvent, MessageRequest, + MessageResponse, MessageStartEvent, MessageStopEvent, OutputContentBlock, StreamEvent, + ToolChoice, ToolDefinition, ToolResultContentBlock, Usage, +}; + +use super::{Provider, ProviderFuture}; + +pub const DEFAULT_XAI_BASE_URL: &str = "https://api.x.ai/v1"; +pub const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1"; +const REQUEST_ID_HEADER: &str = "request-id"; +const ALT_REQUEST_ID_HEADER: &str = "x-request-id"; +const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200); +const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(2); +const DEFAULT_MAX_RETRIES: u32 = 2; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OpenAiCompatConfig { + pub provider_name: &'static str, + pub api_key_env: &'static str, + pub base_url_env: &'static str, + pub default_base_url: &'static str, +} + +const XAI_ENV_VARS: &[&str] = &["XAI_API_KEY"]; +const OPENAI_ENV_VARS: &[&str] = &["OPENAI_API_KEY"]; + +impl OpenAiCompatConfig { + #[must_use] + pub const fn xai() -> Self { + Self { + provider_name: "xAI", + api_key_env: "XAI_API_KEY", + base_url_env: "XAI_BASE_URL", + default_base_url: DEFAULT_XAI_BASE_URL, + } + } + + #[must_use] + pub const fn openai() -> Self { + Self { + provider_name: "OpenAI", + api_key_env: "OPENAI_API_KEY", + base_url_env: "OPENAI_BASE_URL", + default_base_url: DEFAULT_OPENAI_BASE_URL, + } + } + #[must_use] + pub fn credential_env_vars(self) -> &'static [&'static str] { + match self.provider_name { + "xAI" => XAI_ENV_VARS, + "OpenAI" => OPENAI_ENV_VARS, + _ => &[], + } + } +} + +#[derive(Debug, Clone)] +pub struct OpenAiCompatClient { + http: reqwest::Client, + api_key: String, + base_url: String, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, +} + +impl OpenAiCompatClient { + #[must_use] + pub fn new(api_key: impl Into, config: OpenAiCompatConfig) -> Self { + Self { + http: reqwest::Client::new(), + api_key: api_key.into(), + base_url: read_base_url(config), + max_retries: DEFAULT_MAX_RETRIES, + initial_backoff: DEFAULT_INITIAL_BACKOFF, + max_backoff: DEFAULT_MAX_BACKOFF, + } + } + + pub fn from_env(config: OpenAiCompatConfig) -> Result { + let Some(api_key) = read_env_non_empty(config.api_key_env)? else { + return Err(ApiError::missing_credentials( + config.provider_name, + config.credential_env_vars(), + )); + }; + Ok(Self::new(api_key, config)) + } + + #[must_use] + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + #[must_use] + pub fn with_retry_policy( + mut self, + max_retries: u32, + initial_backoff: Duration, + max_backoff: Duration, + ) -> Self { + self.max_retries = max_retries; + self.initial_backoff = initial_backoff; + self.max_backoff = max_backoff; + self + } + + pub async fn send_message( + &self, + request: &MessageRequest, + ) -> Result { + let request = MessageRequest { + stream: false, + ..request.clone() + }; + let response = self.send_with_retry(&request).await?; + let request_id = request_id_from_headers(response.headers()); + let payload = response.json::().await?; + let mut normalized = normalize_response(&request.model, payload)?; + if normalized.request_id.is_none() { + normalized.request_id = request_id; + } + Ok(normalized) + } + + pub async fn stream_message( + &self, + request: &MessageRequest, + ) -> Result { + let response = self + .send_with_retry(&request.clone().with_streaming()) + .await?; + Ok(MessageStream { + request_id: request_id_from_headers(response.headers()), + response, + parser: OpenAiSseParser::new(), + pending: VecDeque::new(), + done: false, + state: StreamState::new(request.model.clone()), + }) + } + + async fn send_with_retry( + &self, + request: &MessageRequest, + ) -> Result { + let mut attempts = 0; + + let last_error = loop { + attempts += 1; + let retryable_error = match self.send_raw_request(request).await { + Ok(response) => match expect_success(response).await { + Ok(response) => return Ok(response), + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }, + Err(error) if error.is_retryable() && attempts <= self.max_retries + 1 => error, + Err(error) => return Err(error), + }; + + if attempts > self.max_retries { + break retryable_error; + } + + tokio::time::sleep(self.backoff_for_attempt(attempts)?).await; + }; + + Err(ApiError::RetriesExhausted { + attempts, + last_error: Box::new(last_error), + }) + } + + async fn send_raw_request( + &self, + request: &MessageRequest, + ) -> Result { + let request_url = chat_completions_endpoint(&self.base_url); + self.http + .post(&request_url) + .header("content-type", "application/json") + .bearer_auth(&self.api_key) + .json(&build_chat_completion_request(request)) + .send() + .await + .map_err(ApiError::from) + } + + fn backoff_for_attempt(&self, attempt: u32) -> Result { + let Some(multiplier) = 1_u32.checked_shl(attempt.saturating_sub(1)) else { + return Err(ApiError::BackoffOverflow { + attempt, + base_delay: self.initial_backoff, + }); + }; + Ok(self + .initial_backoff + .checked_mul(multiplier) + .map_or(self.max_backoff, |delay| delay.min(self.max_backoff))) + } +} + +impl Provider for OpenAiCompatClient { + type Stream = MessageStream; + + fn send_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, MessageResponse> { + Box::pin(async move { self.send_message(request).await }) + } + + fn stream_message<'a>( + &'a self, + request: &'a MessageRequest, + ) -> ProviderFuture<'a, Self::Stream> { + Box::pin(async move { self.stream_message(request).await }) + } +} + +#[derive(Debug)] +pub struct MessageStream { + request_id: Option, + response: reqwest::Response, + parser: OpenAiSseParser, + pending: VecDeque, + done: bool, + state: StreamState, +} + +impl MessageStream { + #[must_use] + pub fn request_id(&self) -> Option<&str> { + self.request_id.as_deref() + } + + pub async fn next_event(&mut self) -> Result, ApiError> { + loop { + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + + if self.done { + self.pending.extend(self.state.finish()?); + if let Some(event) = self.pending.pop_front() { + return Ok(Some(event)); + } + return Ok(None); + } + + match self.response.chunk().await? { + Some(chunk) => { + for parsed in self.parser.push(&chunk)? { + self.pending.extend(self.state.ingest_chunk(parsed)?); + } + } + None => { + self.done = true; + } + } + } + } +} + +#[derive(Debug, Default)] +struct OpenAiSseParser { + buffer: Vec, +} + +impl OpenAiSseParser { + fn new() -> Self { + Self::default() + } + + fn push(&mut self, chunk: &[u8]) -> Result, ApiError> { + self.buffer.extend_from_slice(chunk); + let mut events = Vec::new(); + + while let Some(frame) = next_sse_frame(&mut self.buffer) { + if let Some(event) = parse_sse_frame(&frame)? { + events.push(event); + } + } + + Ok(events) + } +} + +#[derive(Debug)] +struct StreamState { + model: String, + message_started: bool, + text_started: bool, + text_finished: bool, + finished: bool, + stop_reason: Option, + usage: Option, + tool_calls: BTreeMap, +} + +impl StreamState { + fn new(model: String) -> Self { + Self { + model, + message_started: false, + text_started: false, + text_finished: false, + finished: false, + stop_reason: None, + usage: None, + tool_calls: BTreeMap::new(), + } + } + + fn ingest_chunk(&mut self, chunk: ChatCompletionChunk) -> Result, ApiError> { + let mut events = Vec::new(); + if !self.message_started { + self.message_started = true; + events.push(StreamEvent::MessageStart(MessageStartEvent { + message: MessageResponse { + id: chunk.id.clone(), + kind: "message".to_string(), + role: "assistant".to_string(), + content: Vec::new(), + model: chunk.model.clone().unwrap_or_else(|| self.model.clone()), + stop_reason: None, + stop_sequence: None, + usage: Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }, + request_id: None, + }, + })); + } + + if let Some(usage) = chunk.usage { + self.usage = Some(Usage { + input_tokens: usage.prompt_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: usage.completion_tokens, + }); + } + + for choice in chunk.choices { + if let Some(content) = choice.delta.content.filter(|value| !value.is_empty()) { + if !self.text_started { + self.text_started = true; + events.push(StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Text { + text: String::new(), + }, + })); + } + events.push(StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::TextDelta { text: content }, + })); + } + + for tool_call in choice.delta.tool_calls { + let state = self.tool_calls.entry(tool_call.index).or_default(); + state.apply(tool_call); + let block_index = state.block_index(); + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + } else { + continue; + } + } + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + if choice.finish_reason.as_deref() == Some("tool_calls") && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: block_index, + })); + } + } + + if let Some(finish_reason) = choice.finish_reason { + self.stop_reason = Some(normalize_finish_reason(&finish_reason)); + if finish_reason == "tool_calls" { + for state in self.tool_calls.values_mut() { + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + } + } + } + + Ok(events) + } + + fn finish(&mut self) -> Result, ApiError> { + if self.finished { + return Ok(Vec::new()); + } + self.finished = true; + + let mut events = Vec::new(); + if self.text_started && !self.text_finished { + self.text_finished = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: 0, + })); + } + + for state in self.tool_calls.values_mut() { + if !state.started { + if let Some(start_event) = state.start_event()? { + state.started = true; + events.push(StreamEvent::ContentBlockStart(start_event)); + if let Some(delta_event) = state.delta_event() { + events.push(StreamEvent::ContentBlockDelta(delta_event)); + } + } + } + if state.started && !state.stopped { + state.stopped = true; + events.push(StreamEvent::ContentBlockStop(ContentBlockStopEvent { + index: state.block_index(), + })); + } + } + + if self.message_started { + events.push(StreamEvent::MessageDelta(MessageDeltaEvent { + delta: MessageDelta { + stop_reason: Some( + self.stop_reason + .clone() + .unwrap_or_else(|| "end_turn".to_string()), + ), + stop_sequence: None, + }, + usage: self.usage.clone().unwrap_or(Usage { + input_tokens: 0, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: 0, + }), + })); + events.push(StreamEvent::MessageStop(MessageStopEvent {})); + } + Ok(events) + } +} + +#[derive(Debug, Default)] +struct ToolCallState { + openai_index: u32, + id: Option, + name: Option, + arguments: String, + emitted_len: usize, + started: bool, + stopped: bool, +} + +impl ToolCallState { + fn apply(&mut self, tool_call: DeltaToolCall) { + self.openai_index = tool_call.index; + if let Some(id) = tool_call.id { + self.id = Some(id); + } + if let Some(name) = tool_call.function.name { + self.name = Some(name); + } + if let Some(arguments) = tool_call.function.arguments { + self.arguments.push_str(&arguments); + } + } + + const fn block_index(&self) -> u32 { + self.openai_index + 1 + } + + fn start_event(&self) -> Result, ApiError> { + let Some(name) = self.name.clone() else { + return Ok(None); + }; + let id = self + .id + .clone() + .unwrap_or_else(|| format!("tool_call_{}", self.openai_index)); + Ok(Some(ContentBlockStartEvent { + index: self.block_index(), + content_block: OutputContentBlock::ToolUse { + id, + name, + input: json!({}), + }, + })) + } + + fn delta_event(&mut self) -> Option { + if self.emitted_len >= self.arguments.len() { + return None; + } + let delta = self.arguments[self.emitted_len..].to_string(); + self.emitted_len = self.arguments.len(); + Some(ContentBlockDeltaEvent { + index: self.block_index(), + delta: ContentBlockDelta::InputJsonDelta { + partial_json: delta, + }, + }) + } +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + id: String, + model: String, + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatChoice { + message: ChatMessage, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatMessage { + role: String, + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolCall { + id: String, + function: ResponseToolFunction, +} + +#[derive(Debug, Deserialize)] +struct ResponseToolFunction { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct OpenAiUsage { + #[serde(default)] + prompt_tokens: u32, + #[serde(default)] + completion_tokens: u32, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionChunk { + id: String, + #[serde(default)] + model: Option, + #[serde(default)] + choices: Vec, + #[serde(default)] + usage: Option, +} + +#[derive(Debug, Deserialize)] +struct ChunkChoice { + delta: ChunkDelta, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Default, Deserialize)] +struct ChunkDelta { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Vec, +} + +#[derive(Debug, Deserialize)] +struct DeltaToolCall { + #[serde(default)] + index: u32, + #[serde(default)] + id: Option, + #[serde(default)] + function: DeltaFunction, +} + +#[derive(Debug, Default, Deserialize)] +struct DeltaFunction { + #[serde(default)] + name: Option, + #[serde(default)] + arguments: Option, +} + +#[derive(Debug, Deserialize)] +struct ErrorEnvelope { + error: ErrorBody, +} + +#[derive(Debug, Deserialize)] +struct ErrorBody { + #[serde(rename = "type")] + error_type: Option, + message: Option, +} + +fn build_chat_completion_request(request: &MessageRequest) -> Value { + let mut messages = Vec::new(); + if let Some(system) = request.system.as_ref().filter(|value| !value.is_empty()) { + messages.push(json!({ + "role": "system", + "content": system, + })); + } + for message in &request.messages { + messages.extend(translate_message(message)); + } + + let mut payload = json!({ + "model": request.model, + "max_tokens": request.max_tokens, + "messages": messages, + "stream": request.stream, + }); + + if let Some(tools) = &request.tools { + payload["tools"] = + Value::Array(tools.iter().map(openai_tool_definition).collect::>()); + } + if let Some(tool_choice) = &request.tool_choice { + payload["tool_choice"] = openai_tool_choice(tool_choice); + } + + payload +} + +fn translate_message(message: &InputMessage) -> Vec { + match message.role.as_str() { + "assistant" => { + let mut text = String::new(); + let mut tool_calls = Vec::new(); + for block in &message.content { + match block { + InputContentBlock::Text { text: value } => text.push_str(value), + InputContentBlock::ToolUse { id, name, input } => tool_calls.push(json!({ + "id": id, + "type": "function", + "function": { + "name": name, + "arguments": input.to_string(), + } + })), + InputContentBlock::ToolResult { .. } => {} + } + } + if text.is_empty() && tool_calls.is_empty() { + Vec::new() + } else { + vec![json!({ + "role": "assistant", + "content": (!text.is_empty()).then_some(text), + "tool_calls": tool_calls, + })] + } + } + _ => message + .content + .iter() + .filter_map(|block| match block { + InputContentBlock::Text { text } => Some(json!({ + "role": "user", + "content": text, + })), + InputContentBlock::ToolResult { + tool_use_id, + content, + is_error, + } => Some(json!({ + "role": "tool", + "tool_call_id": tool_use_id, + "content": flatten_tool_result_content(content), + "is_error": is_error, + })), + InputContentBlock::ToolUse { .. } => None, + }) + .collect(), + } +} + +fn flatten_tool_result_content(content: &[ToolResultContentBlock]) -> String { + content + .iter() + .map(|block| match block { + ToolResultContentBlock::Text { text } => text.clone(), + ToolResultContentBlock::Json { value } => value.to_string(), + }) + .collect::>() + .join("\n") +} + +fn openai_tool_definition(tool: &ToolDefinition) -> Value { + json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.input_schema, + } + }) +} + +fn openai_tool_choice(tool_choice: &ToolChoice) -> Value { + match tool_choice { + ToolChoice::Auto => Value::String("auto".to_string()), + ToolChoice::Any => Value::String("required".to_string()), + ToolChoice::Tool { name } => json!({ + "type": "function", + "function": { "name": name }, + }), + } +} + +fn normalize_response( + model: &str, + response: ChatCompletionResponse, +) -> Result { + let choice = response + .choices + .into_iter() + .next() + .ok_or(ApiError::InvalidSseFrame( + "chat completion response missing choices", + ))?; + let mut content = Vec::new(); + if let Some(text) = choice.message.content.filter(|value| !value.is_empty()) { + content.push(OutputContentBlock::Text { text }); + } + for tool_call in choice.message.tool_calls { + content.push(OutputContentBlock::ToolUse { + id: tool_call.id, + name: tool_call.function.name, + input: parse_tool_arguments(&tool_call.function.arguments), + }); + } + + Ok(MessageResponse { + id: response.id, + kind: "message".to_string(), + role: choice.message.role, + content, + model: response.model.if_empty_then(model.to_string()), + stop_reason: choice + .finish_reason + .map(|value| normalize_finish_reason(&value)), + stop_sequence: None, + usage: Usage { + input_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.prompt_tokens), + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + output_tokens: response + .usage + .as_ref() + .map_or(0, |usage| usage.completion_tokens), + }, + request_id: None, + }) +} + +fn parse_tool_arguments(arguments: &str) -> Value { + serde_json::from_str(arguments).unwrap_or_else(|_| json!({ "raw": arguments })) +} + +fn next_sse_frame(buffer: &mut Vec) -> Option { + let separator = buffer + .windows(2) + .position(|window| window == b"\n\n") + .map(|position| (position, 2)) + .or_else(|| { + buffer + .windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|position| (position, 4)) + })?; + + let (position, separator_len) = separator; + let frame = buffer.drain(..position + separator_len).collect::>(); + let frame_len = frame.len().saturating_sub(separator_len); + Some(String::from_utf8_lossy(&frame[..frame_len]).into_owned()) +} + +fn parse_sse_frame(frame: &str) -> Result, ApiError> { + let trimmed = frame.trim(); + if trimmed.is_empty() { + return Ok(None); + } + + let mut data_lines = Vec::new(); + for line in trimmed.lines() { + if line.starts_with(':') { + continue; + } + if let Some(data) = line.strip_prefix("data:") { + data_lines.push(data.trim_start()); + } + } + if data_lines.is_empty() { + return Ok(None); + } + let payload = data_lines.join("\n"); + if payload == "[DONE]" { + return Ok(None); + } + serde_json::from_str(&payload) + .map(Some) + .map_err(ApiError::from) +} + +fn read_env_non_empty(key: &str) -> Result, ApiError> { + match std::env::var(key) { + Ok(value) if !value.is_empty() => Ok(Some(value)), + Ok(_) | Err(std::env::VarError::NotPresent) => Ok(None), + Err(error) => Err(ApiError::from(error)), + } +} + +#[must_use] +pub fn has_api_key(key: &str) -> bool { + read_env_non_empty(key) + .ok() + .and_then(std::convert::identity) + .is_some() +} + +#[must_use] +pub fn read_base_url(config: OpenAiCompatConfig) -> String { + std::env::var(config.base_url_env).unwrap_or_else(|_| config.default_base_url.to_string()) +} + +fn chat_completions_endpoint(base_url: &str) -> String { + let trimmed = base_url.trim_end_matches('/'); + if trimmed.ends_with("/chat/completions") { + trimmed.to_string() + } else { + format!("{trimmed}/chat/completions") + } +} + +fn request_id_from_headers(headers: &reqwest::header::HeaderMap) -> Option { + headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(ALT_REQUEST_ID_HEADER)) + .and_then(|value| value.to_str().ok()) + .map(ToOwned::to_owned) +} + +async fn expect_success(response: reqwest::Response) -> Result { + let status = response.status(); + if status.is_success() { + return Ok(response); + } + + let body = response.text().await.unwrap_or_default(); + let parsed_error = serde_json::from_str::(&body).ok(); + let retryable = is_retryable_status(status); + + Err(ApiError::Api { + status, + error_type: parsed_error + .as_ref() + .and_then(|error| error.error.error_type.clone()), + message: parsed_error + .as_ref() + .and_then(|error| error.error.message.clone()), + body, + retryable, + }) +} + +const fn is_retryable_status(status: reqwest::StatusCode) -> bool { + matches!(status.as_u16(), 408 | 409 | 429 | 500 | 502 | 503 | 504) +} + +fn normalize_finish_reason(value: &str) -> String { + match value { + "stop" => "end_turn", + "tool_calls" => "tool_use", + other => other, + } + .to_string() +} + +trait StringExt { + fn if_empty_then(self, fallback: String) -> String; +} + +impl StringExt for String { + fn if_empty_then(self, fallback: String) -> String { + if self.is_empty() { + fallback + } else { + self + } + } +} + +#[cfg(test)] +mod tests { + use super::{ + build_chat_completion_request, chat_completions_endpoint, normalize_finish_reason, + openai_tool_choice, parse_tool_arguments, OpenAiCompatClient, OpenAiCompatConfig, + }; + use crate::error::ApiError; + use crate::types::{ + InputContentBlock, InputMessage, MessageRequest, ToolChoice, ToolDefinition, + ToolResultContentBlock, + }; + use serde_json::json; + use std::sync::{Mutex, OnceLock}; + + #[test] + fn request_translation_uses_openai_compatible_shape() { + let payload = build_chat_completion_request(&MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![ + InputContentBlock::Text { + text: "hello".to_string(), + }, + InputContentBlock::ToolResult { + tool_use_id: "tool_1".to_string(), + content: vec![ToolResultContentBlock::Json { + value: json!({"ok": true}), + }], + is_error: false, + }, + ], + }], + system: Some("be helpful".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Get weather".to_string()), + input_schema: json!({"type": "object"}), + }]), + tool_choice: Some(ToolChoice::Auto), + stream: false, + }); + + assert_eq!(payload["messages"][0]["role"], json!("system")); + assert_eq!(payload["messages"][1]["role"], json!("user")); + assert_eq!(payload["messages"][2]["role"], json!("tool")); + assert_eq!(payload["tools"][0]["type"], json!("function")); + assert_eq!(payload["tool_choice"], json!("auto")); + } + + #[test] + fn tool_choice_translation_supports_required_function() { + assert_eq!(openai_tool_choice(&ToolChoice::Any), json!("required")); + assert_eq!( + openai_tool_choice(&ToolChoice::Tool { + name: "weather".to_string(), + }), + json!({"type": "function", "function": {"name": "weather"}}) + ); + } + + #[test] + fn parses_tool_arguments_fallback() { + assert_eq!( + parse_tool_arguments("{\"city\":\"Paris\"}"), + json!({"city": "Paris"}) + ); + assert_eq!(parse_tool_arguments("not-json"), json!({"raw": "not-json"})); + } + + #[test] + fn missing_xai_api_key_is_provider_specific() { + let _lock = env_lock(); + std::env::remove_var("XAI_API_KEY"); + let error = OpenAiCompatClient::from_env(OpenAiCompatConfig::xai()) + .expect_err("missing key should error"); + assert!(matches!( + error, + ApiError::MissingCredentials { + provider: "xAI", + .. + } + )); + } + + #[test] + fn endpoint_builder_accepts_base_urls_and_full_endpoints() { + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/"), + "https://api.x.ai/v1/chat/completions" + ); + assert_eq!( + chat_completions_endpoint("https://api.x.ai/v1/chat/completions"), + "https://api.x.ai/v1/chat/completions" + ); + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + #[test] + fn normalizes_stop_reasons() { + assert_eq!(normalize_finish_reason("stop"), "end_turn"); + assert_eq!(normalize_finish_reason("tool_calls"), "tool_use"); + } +} diff --git a/rust/crates/api/src/sse.rs b/rust/crates/api/src/sse.rs index d7334cd..5f54e50 100644 --- a/rust/crates/api/src/sse.rs +++ b/rust/crates/api/src/sse.rs @@ -216,4 +216,64 @@ mod tests { )) ); } + + #[test] + fn parses_thinking_content_block_start() { + let frame = concat!( + "event: content_block_start\n", + "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\",\"signature\":null}}\n\n" + ); + + let event = parse_frame(frame).expect("frame should parse"); + assert_eq!( + event, + Some(StreamEvent::ContentBlockStart( + crate::types::ContentBlockStartEvent { + index: 0, + content_block: OutputContentBlock::Thinking { + thinking: String::new(), + signature: None, + }, + }, + )) + ); + } + + #[test] + fn parses_thinking_related_deltas() { + let thinking = concat!( + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"step 1\"}}\n\n" + ); + let signature = concat!( + "event: content_block_delta\n", + "data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"signature_delta\",\"signature\":\"sig_123\"}}\n\n" + ); + + let thinking_event = parse_frame(thinking).expect("thinking delta should parse"); + let signature_event = parse_frame(signature).expect("signature delta should parse"); + + assert_eq!( + thinking_event, + Some(StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::ThinkingDelta { + thinking: "step 1".to_string(), + }, + } + )) + ); + assert_eq!( + signature_event, + Some(StreamEvent::ContentBlockDelta( + crate::types::ContentBlockDeltaEvent { + index: 0, + delta: ContentBlockDelta::SignatureDelta { + signature: "sig_123".to_string(), + }, + } + )) + ); + } } diff --git a/rust/crates/api/src/types.rs b/rust/crates/api/src/types.rs index 45d5c08..c060be6 100644 --- a/rust/crates/api/src/types.rs +++ b/rust/crates/api/src/types.rs @@ -135,6 +135,15 @@ pub enum OutputContentBlock { name: String, input: Value, }, + Thinking { + #[serde(default)] + thinking: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + signature: Option, + }, + RedactedThinking { + data: Value, + }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -190,6 +199,8 @@ pub struct ContentBlockDeltaEvent { pub enum ContentBlockDelta { TextDelta { text: String }, InputJsonDelta { partial_json: String }, + ThinkingDelta { thinking: String }, + SignatureDelta { signature: String }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] diff --git a/rust/crates/api/tests/client_integration.rs b/rust/crates/api/tests/client_integration.rs index c37fa99..3b6a3c3 100644 --- a/rust/crates/api/tests/client_integration.rs +++ b/rust/crates/api/tests/client_integration.rs @@ -3,9 +3,9 @@ use std::sync::Arc; use std::time::Duration; use api::{ - AnthropicClient, ApiError, ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, - InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, OutputContentBlock, - StreamEvent, ToolChoice, ToolDefinition, + ApiClient, ApiError, AuthSource, ContentBlockDelta, ContentBlockDeltaEvent, + ContentBlockStartEvent, InputContentBlock, InputMessage, MessageDeltaEvent, MessageRequest, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, }; use serde_json::json; use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -20,8 +20,8 @@ async fn send_message_posts_json_and_parses_response() { "\"id\":\"msg_test\",", "\"type\":\"message\",", "\"role\":\"assistant\",", - "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claude\"}],", - "\"model\":\"claude-3-7-sonnet-latest\",", + "\"content\":[{\"type\":\"text\",\"text\":\"Hello from Claw\"}],", + "\"model\":\"claude-sonnet-4-6\",", "\"stop_reason\":\"end_turn\",", "\"stop_sequence\":null,", "\"usage\":{\"input_tokens\":12,\"output_tokens\":4},", @@ -34,7 +34,7 @@ async fn send_message_posts_json_and_parses_response() { ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) .with_base_url(server.base_url()); let response = client @@ -48,7 +48,7 @@ async fn send_message_posts_json_and_parses_response() { assert_eq!( response.content, vec![OutputContentBlock::Text { - text: "Hello from Claude".to_string(), + text: "Hello from Claw".to_string(), }] ); @@ -68,7 +68,7 @@ async fn send_message_posts_json_and_parses_response() { serde_json::from_str(&request.body).expect("request body should be json"); assert_eq!( body.get("model").and_then(serde_json::Value::as_str), - Some("claude-3-7-sonnet-latest") + Some("claude-sonnet-4-6") ); assert!(body.get("stream").is_none()); assert_eq!(body["tools"][0]["name"], json!("get_weather")); @@ -80,7 +80,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { let state = Arc::new(Mutex::new(Vec::::new())); let sse = concat!( "event: message_start\n", - "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", + "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_stream\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":null,\"stop_sequence\":null,\"usage\":{\"input_tokens\":8,\"output_tokens\":0}}}\n\n", "event: content_block_start\n", "data: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"tool_use\",\"id\":\"toolu_123\",\"name\":\"get_weather\",\"input\":{}}}\n\n", "event: content_block_delta\n", @@ -104,7 +104,7 @@ async fn stream_message_parses_sse_events_with_tool_use() { ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_auth_token(Some("proxy-token".to_string())) .with_base_url(server.base_url()); let mut stream = client @@ -176,13 +176,13 @@ async fn retries_retryable_failures_before_succeeding() { http_response( "200 OK", "application/json", - "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-3-7-sonnet-latest\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + "{\"id\":\"msg_retry\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Recovered\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", ), ], ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_base_url(server.base_url()) .with_retry_policy(2, Duration::from_millis(1), Duration::from_millis(2)); @@ -195,6 +195,47 @@ async fn retries_retryable_failures_before_succeeding() { assert_eq!(state.lock().await.len(), 2); } +#[tokio::test] +async fn provider_client_dispatches_api_requests() { + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"msg_provider\",\"type\":\"message\",\"role\":\"assistant\",\"content\":[{\"type\":\"text\",\"text\":\"Dispatched\"}],\"model\":\"claude-sonnet-4-6\",\"stop_reason\":\"end_turn\",\"stop_sequence\":null,\"usage\":{\"input_tokens\":3,\"output_tokens\":2}}", + )], + ) + .await; + + let client = ProviderClient::from_model_with_default_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("test-key".to_string())), + ) + .expect("api provider client should be constructed"); + let client = match client { + ProviderClient::ClawApi(client) => { + ProviderClient::ClawApi(client.with_base_url(server.base_url())) + } + other => panic!("expected default provider, got {other:?}"), + }; + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 5); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/v1/messages"); + assert_eq!( + request.headers.get("x-api-key").map(String::as_str), + Some("test-key") + ); +} + #[tokio::test] async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { let state = Arc::new(Mutex::new(Vec::::new())); @@ -215,7 +256,7 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { ) .await; - let client = AnthropicClient::new("test-key") + let client = ApiClient::new("test-key") .with_base_url(server.base_url()) .with_retry_policy(1, Duration::from_millis(1), Duration::from_millis(2)); @@ -246,11 +287,10 @@ async fn surfaces_retry_exhaustion_for_persistent_retryable_errors() { #[tokio::test] #[ignore = "requires ANTHROPIC_API_KEY and network access"] async fn live_stream_smoke_test() { - let client = AnthropicClient::from_env().expect("ANTHROPIC_API_KEY must be set"); + let client = ApiClient::from_env().expect("ANTHROPIC_API_KEY must be set"); let mut stream = client .stream_message(&MessageRequest { - model: std::env::var("ANTHROPIC_MODEL") - .unwrap_or_else(|_| "claude-3-7-sonnet-latest".to_string()), + model: std::env::var("CLAW_MODEL").unwrap_or_else(|_| "claude-sonnet-4-6".to_string()), max_tokens: 32, messages: vec![InputMessage::user_text( "Reply with exactly: hello from rust", @@ -410,7 +450,7 @@ fn http_response_with_headers( fn sample_request(stream: bool) -> MessageRequest { MessageRequest { - model: "claude-3-7-sonnet-latest".to_string(), + model: "claude-sonnet-4-6".to_string(), max_tokens: 64, messages: vec![InputMessage { role: "user".to_string(), diff --git a/rust/crates/api/tests/openai_compat_integration.rs b/rust/crates/api/tests/openai_compat_integration.rs new file mode 100644 index 0000000..b345b1f --- /dev/null +++ b/rust/crates/api/tests/openai_compat_integration.rs @@ -0,0 +1,415 @@ +use std::collections::HashMap; +use std::ffi::OsString; +use std::sync::Arc; +use std::sync::{Mutex as StdMutex, OnceLock}; + +use api::{ + ContentBlockDelta, ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, + InputContentBlock, InputMessage, MessageRequest, OpenAiCompatClient, OpenAiCompatConfig, + OutputContentBlock, ProviderClient, StreamEvent, ToolChoice, ToolDefinition, +}; +use serde_json::json; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; +use tokio::sync::Mutex; + +#[tokio::test] +async fn send_message_uses_openai_compatible_endpoint_and_auth() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_test\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Hello from Grok\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":11,\"completion_tokens\":5}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.model, "grok-3"); + assert_eq!(response.total_tokens(), 16); + assert_eq!( + response.content, + vec![OutputContentBlock::Text { + text: "Hello from Grok".to_string(), + }] + ); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); + let body: serde_json::Value = serde_json::from_str(&request.body).expect("json body"); + assert_eq!(body["model"], json!("grok-3")); + assert_eq!(body["messages"][0]["role"], json!("system")); + assert_eq!(body["tools"][0]["type"], json!("function")); +} + +#[tokio::test] +async fn send_message_accepts_full_chat_completions_endpoint_override() { + let state = Arc::new(Mutex::new(Vec::::new())); + let body = concat!( + "{", + "\"id\":\"chatcmpl_full_endpoint\",", + "\"model\":\"grok-3\",", + "\"choices\":[{", + "\"message\":{\"role\":\"assistant\",\"content\":\"Endpoint override works\",\"tool_calls\":[]},", + "\"finish_reason\":\"stop\"", + "}],", + "\"usage\":{\"prompt_tokens\":7,\"completion_tokens\":3}", + "}" + ); + let server = spawn_server( + state.clone(), + vec![http_response("200 OK", "application/json", body)], + ) + .await; + + let endpoint_url = format!("{}/chat/completions", server.base_url()); + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(endpoint_url); + let response = client + .send_message(&sample_request(false)) + .await + .expect("request should succeed"); + + assert_eq!(response.total_tokens(), 10); + + let captured = state.lock().await; + let request = captured.first().expect("server should capture request"); + assert_eq!(request.path, "/chat/completions"); +} + +#[tokio::test] +async fn stream_message_normalizes_text_and_multiple_tool_calls() { + let state = Arc::new(Mutex::new(Vec::::new())); + let sse = concat!( + "data: {\"id\":\"chatcmpl_stream\",\"model\":\"grok-3\",\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"weather\",\"arguments\":\"{\\\"city\\\":\\\"Paris\\\"}\"}},{\"index\":1,\"id\":\"call_2\",\"function\":{\"name\":\"clock\",\"arguments\":\"{\\\"zone\\\":\\\"UTC\\\"}\"}}]}}]}\n\n", + "data: {\"id\":\"chatcmpl_stream\",\"choices\":[{\"delta\":{},\"finish_reason\":\"tool_calls\"}]}\n\n", + "data: [DONE]\n\n" + ); + let server = spawn_server( + state.clone(), + vec![http_response_with_headers( + "200 OK", + "text/event-stream", + sse, + &[("x-request-id", "req_grok_stream")], + )], + ) + .await; + + let client = OpenAiCompatClient::new("xai-test-key", OpenAiCompatConfig::xai()) + .with_base_url(server.base_url()); + let mut stream = client + .stream_message(&sample_request(false)) + .await + .expect("stream should start"); + + assert_eq!(stream.request_id(), Some("req_grok_stream")); + + let mut events = Vec::new(); + while let Some(event) = stream.next_event().await.expect("event should parse") { + events.push(event); + } + + assert!(matches!(events[0], StreamEvent::MessageStart(_))); + assert!(matches!( + events[1], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + content_block: OutputContentBlock::Text { .. }, + .. + }) + )); + assert!(matches!( + events[2], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + delta: ContentBlockDelta::TextDelta { .. }, + .. + }) + )); + assert!(matches!( + events[3], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 1, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[4], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 1, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[5], + StreamEvent::ContentBlockStart(ContentBlockStartEvent { + index: 2, + content_block: OutputContentBlock::ToolUse { .. }, + }) + )); + assert!(matches!( + events[6], + StreamEvent::ContentBlockDelta(ContentBlockDeltaEvent { + index: 2, + delta: ContentBlockDelta::InputJsonDelta { .. }, + }) + )); + assert!(matches!( + events[7], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 1 }) + )); + assert!(matches!( + events[8], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 2 }) + )); + assert!(matches!( + events[9], + StreamEvent::ContentBlockStop(ContentBlockStopEvent { index: 0 }) + )); + assert!(matches!(events[10], StreamEvent::MessageDelta(_))); + assert!(matches!(events[11], StreamEvent::MessageStop(_))); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert!(request.body.contains("\"stream\":true")); +} + +#[tokio::test] +async fn provider_client_dispatches_xai_requests_from_env() { + let _lock = env_lock(); + let _api_key = ScopedEnvVar::set("XAI_API_KEY", "xai-test-key"); + + let state = Arc::new(Mutex::new(Vec::::new())); + let server = spawn_server( + state.clone(), + vec![http_response( + "200 OK", + "application/json", + "{\"id\":\"chatcmpl_provider\",\"model\":\"grok-3\",\"choices\":[{\"message\":{\"role\":\"assistant\",\"content\":\"Through provider client\",\"tool_calls\":[]},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":9,\"completion_tokens\":4}}", + )], + ) + .await; + let _base_url = ScopedEnvVar::set("XAI_BASE_URL", server.base_url()); + + let client = + ProviderClient::from_model("grok").expect("xAI provider client should be constructed"); + assert!(matches!(client, ProviderClient::Xai(_))); + + let response = client + .send_message(&sample_request(false)) + .await + .expect("provider-dispatched request should succeed"); + + assert_eq!(response.total_tokens(), 13); + + let captured = state.lock().await; + let request = captured.first().expect("captured request"); + assert_eq!(request.path, "/chat/completions"); + assert_eq!( + request.headers.get("authorization").map(String::as_str), + Some("Bearer xai-test-key") + ); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CapturedRequest { + path: String, + headers: HashMap, + body: String, +} + +struct TestServer { + base_url: String, + join_handle: tokio::task::JoinHandle<()>, +} + +impl TestServer { + fn base_url(&self) -> String { + self.base_url.clone() + } +} + +impl Drop for TestServer { + fn drop(&mut self) { + self.join_handle.abort(); + } +} + +async fn spawn_server( + state: Arc>>, + responses: Vec, +) -> TestServer { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let address = listener.local_addr().expect("listener addr"); + let join_handle = tokio::spawn(async move { + for response in responses { + let (mut socket, _) = listener.accept().await.expect("accept"); + let mut buffer = Vec::new(); + let mut header_end = None; + loop { + let mut chunk = [0_u8; 1024]; + let read = socket.read(&mut chunk).await.expect("read request"); + if read == 0 { + break; + } + buffer.extend_from_slice(&chunk[..read]); + if let Some(position) = find_header_end(&buffer) { + header_end = Some(position); + break; + } + } + + let header_end = header_end.expect("headers should exist"); + let (header_bytes, remaining) = buffer.split_at(header_end); + let header_text = String::from_utf8(header_bytes.to_vec()).expect("utf8 headers"); + let mut lines = header_text.split("\r\n"); + let request_line = lines.next().expect("request line"); + let path = request_line + .split_whitespace() + .nth(1) + .expect("path") + .to_string(); + let mut headers = HashMap::new(); + let mut content_length = 0_usize; + for line in lines { + if line.is_empty() { + continue; + } + let (name, value) = line.split_once(':').expect("header"); + let value = value.trim().to_string(); + if name.eq_ignore_ascii_case("content-length") { + content_length = value.parse().expect("content length"); + } + headers.insert(name.to_ascii_lowercase(), value); + } + + let mut body = remaining[4..].to_vec(); + while body.len() < content_length { + let mut chunk = vec![0_u8; content_length - body.len()]; + let read = socket.read(&mut chunk).await.expect("read body"); + if read == 0 { + break; + } + body.extend_from_slice(&chunk[..read]); + } + + state.lock().await.push(CapturedRequest { + path, + headers, + body: String::from_utf8(body).expect("utf8 body"), + }); + + socket + .write_all(response.as_bytes()) + .await + .expect("write response"); + } + }); + + TestServer { + base_url: format!("http://{address}"), + join_handle, + } +} + +fn find_header_end(bytes: &[u8]) -> Option { + bytes.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn http_response(status: &str, content_type: &str, body: &str) -> String { + http_response_with_headers(status, content_type, body, &[]) +} + +fn http_response_with_headers( + status: &str, + content_type: &str, + body: &str, + headers: &[(&str, &str)], +) -> String { + let mut extra_headers = String::new(); + for (name, value) in headers { + use std::fmt::Write as _; + write!(&mut extra_headers, "{name}: {value}\r\n").expect("header write"); + } + format!( + "HTTP/1.1 {status}\r\ncontent-type: {content_type}\r\n{extra_headers}content-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ) +} + +fn sample_request(stream: bool) -> MessageRequest { + MessageRequest { + model: "grok-3".to_string(), + max_tokens: 64, + messages: vec![InputMessage { + role: "user".to_string(), + content: vec![InputContentBlock::Text { + text: "Say hello".to_string(), + }], + }], + system: Some("Use tools when needed".to_string()), + tools: Some(vec![ToolDefinition { + name: "weather".to_string(), + description: Some("Fetches weather".to_string()), + input_schema: json!({ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }), + }]), + tool_choice: Some(ToolChoice::Auto), + stream, + } +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| StdMutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct ScopedEnvVar { + key: &'static str, + previous: Option, +} + +impl ScopedEnvVar { + fn set(key: &'static str, value: impl AsRef) -> Self { + let previous = std::env::var_os(key); + std::env::set_var(key, value); + Self { key, previous } + } +} + +impl Drop for ScopedEnvVar { + fn drop(&mut self) { + match &self.previous { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/api/tests/provider_client_integration.rs b/rust/crates/api/tests/provider_client_integration.rs new file mode 100644 index 0000000..abeebdd --- /dev/null +++ b/rust/crates/api/tests/provider_client_integration.rs @@ -0,0 +1,86 @@ +use std::ffi::OsString; +use std::sync::{Mutex, OnceLock}; + +use api::{read_xai_base_url, ApiError, AuthSource, ProviderClient, ProviderKind}; + +#[test] +fn provider_client_routes_grok_aliases_through_xai() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", Some("xai-test-key")); + + let client = ProviderClient::from_model("grok-mini").expect("grok alias should resolve"); + + assert_eq!(client.provider_kind(), ProviderKind::Xai); +} + +#[test] +fn provider_client_reports_missing_xai_credentials_for_grok_models() { + let _lock = env_lock(); + let _xai_api_key = EnvVarGuard::set("XAI_API_KEY", None); + + let error = ProviderClient::from_model("grok-3") + .expect_err("grok requests without XAI_API_KEY should fail fast"); + + match error { + ApiError::MissingCredentials { provider, env_vars } => { + assert_eq!(provider, "xAI"); + assert_eq!(env_vars, &["XAI_API_KEY"]); + } + other => panic!("expected missing xAI credentials, got {other:?}"), + } +} + +#[test] +fn provider_client_uses_explicit_auth_without_env_lookup() { + let _lock = env_lock(); + let _api_key = EnvVarGuard::set("ANTHROPIC_API_KEY", None); + let _auth_token = EnvVarGuard::set("ANTHROPIC_AUTH_TOKEN", None); + + let client = ProviderClient::from_model_with_default_auth( + "claude-sonnet-4-6", + Some(AuthSource::ApiKey("claw-test-key".to_string())), + ) + .expect("explicit auth should avoid env lookup"); + + assert_eq!(client.provider_kind(), ProviderKind::ClawApi); +} + +#[test] +fn read_xai_base_url_prefers_env_override() { + let _lock = env_lock(); + let _xai_base_url = EnvVarGuard::set("XAI_BASE_URL", Some("https://example.xai.test/v1")); + + assert_eq!(read_xai_base_url(), "https://example.xai.test/v1"); +} + +fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +struct EnvVarGuard { + key: &'static str, + original: Option, +} + +impl EnvVarGuard { + fn set(key: &'static str, value: Option<&str>) -> Self { + let original = std::env::var_os(key); + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + Self { key, original } + } +} + +impl Drop for EnvVarGuard { + fn drop(&mut self) { + match &self.original { + Some(value) => std::env::set_var(self.key, value), + None => std::env::remove_var(self.key), + } + } +} diff --git a/rust/crates/claw-cli/Cargo.toml b/rust/crates/claw-cli/Cargo.toml new file mode 100644 index 0000000..074718a --- /dev/null +++ b/rust/crates/claw-cli/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "claw-cli" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[[bin]] +name = "claw" +path = "src/main.rs" + +[dependencies] +api = { path = "../api" } +commands = { path = "../commands" } +compat-harness = { path = "../compat-harness" } +crossterm = "0.28" +pulldown-cmark = "0.13" +rustyline = "15" +runtime = { path = "../runtime" } +plugins = { path = "../plugins" } +serde_json.workspace = true +syntect = "5" +tokio = { version = "1", features = ["rt-multi-thread", "time"] } +tools = { path = "../tools" } + +[lints] +workspace = true diff --git a/rust/crates/claw-cli/src/app.rs b/rust/crates/claw-cli/src/app.rs new file mode 100644 index 0000000..85e754f --- /dev/null +++ b/rust/crates/claw-cli/src/app.rs @@ -0,0 +1,402 @@ +use std::io::{self, Write}; +use std::path::PathBuf; + +use crate::args::{OutputFormat, PermissionMode}; +use crate::input::{LineEditor, ReadOutcome}; +use crate::render::{Spinner, TerminalRenderer}; +use runtime::{ConversationClient, ConversationMessage, RuntimeError, StreamEvent, UsageSummary}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionConfig { + pub model: String, + pub permission_mode: PermissionMode, + pub config: Option, + pub output_format: OutputFormat, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionState { + pub turns: usize, + pub compacted_messages: usize, + pub last_model: String, + pub last_usage: UsageSummary, +} + +impl SessionState { + #[must_use] + pub fn new(model: impl Into) -> Self { + Self { + turns: 0, + compacted_messages: 0, + last_model: model.into(), + last_usage: UsageSummary::default(), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CommandResult { + Continue, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SlashCommand { + Help, + Status, + Compact, + Unknown(String), +} + +impl SlashCommand { + #[must_use] + pub fn parse(input: &str) -> Option { + let trimmed = input.trim(); + if !trimmed.starts_with('/') { + return None; + } + + let command = trimmed + .trim_start_matches('/') + .split_whitespace() + .next() + .unwrap_or_default(); + Some(match command { + "help" => Self::Help, + "status" => Self::Status, + "compact" => Self::Compact, + other => Self::Unknown(other.to_string()), + }) + } +} + +struct SlashCommandHandler { + command: SlashCommand, + summary: &'static str, +} + +const SLASH_COMMAND_HANDLERS: &[SlashCommandHandler] = &[ + SlashCommandHandler { + command: SlashCommand::Help, + summary: "Show command help", + }, + SlashCommandHandler { + command: SlashCommand::Status, + summary: "Show current session status", + }, + SlashCommandHandler { + command: SlashCommand::Compact, + summary: "Compact local session history", + }, +]; + +pub struct CliApp { + config: SessionConfig, + renderer: TerminalRenderer, + state: SessionState, + conversation_client: ConversationClient, + conversation_history: Vec, +} + +impl CliApp { + pub fn new(config: SessionConfig) -> Result { + let state = SessionState::new(config.model.clone()); + let conversation_client = ConversationClient::from_env(config.model.clone())?; + Ok(Self { + config, + renderer: TerminalRenderer::new(), + state, + conversation_client, + conversation_history: Vec::new(), + }) + } + + pub fn run_repl(&mut self) -> io::Result<()> { + let mut editor = LineEditor::new("› ", Vec::new()); + println!("Claw Code interactive mode"); + println!("Type /help for commands. Shift+Enter or Ctrl+J inserts a newline."); + + loop { + match editor.read_line()? { + ReadOutcome::Submit(input) => { + if input.trim().is_empty() { + continue; + } + self.handle_submission(&input, &mut io::stdout())?; + } + ReadOutcome::Cancel => continue, + ReadOutcome::Exit => break, + } + } + + Ok(()) + } + + pub fn run_prompt(&mut self, prompt: &str, out: &mut impl Write) -> io::Result<()> { + self.render_response(prompt, out) + } + + pub fn handle_submission( + &mut self, + input: &str, + out: &mut impl Write, + ) -> io::Result { + if let Some(command) = SlashCommand::parse(input) { + return self.dispatch_slash_command(command, out); + } + + self.state.turns += 1; + self.render_response(input, out)?; + Ok(CommandResult::Continue) + } + + fn dispatch_slash_command( + &mut self, + command: SlashCommand, + out: &mut impl Write, + ) -> io::Result { + match command { + SlashCommand::Help => Self::handle_help(out), + SlashCommand::Status => self.handle_status(out), + SlashCommand::Compact => self.handle_compact(out), + SlashCommand::Unknown(name) => { + writeln!(out, "Unknown slash command: /{name}")?; + Ok(CommandResult::Continue) + } + _ => { + writeln!(out, "Slash command unavailable in this mode")?; + Ok(CommandResult::Continue) + } + } + } + + fn handle_help(out: &mut impl Write) -> io::Result { + writeln!(out, "Available commands:")?; + for handler in SLASH_COMMAND_HANDLERS { + let name = match handler.command { + SlashCommand::Help => "/help", + SlashCommand::Status => "/status", + SlashCommand::Compact => "/compact", + _ => continue, + }; + writeln!(out, " {name:<9} {}", handler.summary)?; + } + Ok(CommandResult::Continue) + } + + fn handle_status(&mut self, out: &mut impl Write) -> io::Result { + writeln!( + out, + "status: turns={} model={} permission-mode={:?} output-format={:?} last-usage={} in/{} out config={}", + self.state.turns, + self.state.last_model, + self.config.permission_mode, + self.config.output_format, + self.state.last_usage.input_tokens, + self.state.last_usage.output_tokens, + self.config + .config + .as_ref() + .map_or_else(|| String::from(""), |path| path.display().to_string()) + )?; + Ok(CommandResult::Continue) + } + + fn handle_compact(&mut self, out: &mut impl Write) -> io::Result { + self.state.compacted_messages += self.state.turns; + self.state.turns = 0; + self.conversation_history.clear(); + writeln!( + out, + "Compacted session history into a local summary ({} messages total compacted).", + self.state.compacted_messages + )?; + Ok(CommandResult::Continue) + } + + fn handle_stream_event( + renderer: &TerminalRenderer, + event: StreamEvent, + stream_spinner: &mut Spinner, + tool_spinner: &mut Spinner, + saw_text: &mut bool, + turn_usage: &mut UsageSummary, + out: &mut impl Write, + ) { + match event { + StreamEvent::TextDelta(delta) => { + if !*saw_text { + let _ = + stream_spinner.finish("Streaming response", renderer.color_theme(), out); + *saw_text = true; + } + let _ = write!(out, "{delta}"); + let _ = out.flush(); + } + StreamEvent::ToolCallStart { name, input } => { + if *saw_text { + let _ = writeln!(out); + } + let _ = tool_spinner.tick( + &format!("Running tool `{name}` with {input}"), + renderer.color_theme(), + out, + ); + } + StreamEvent::ToolCallResult { + name, + output, + is_error, + } => { + let label = if is_error { + format!("Tool `{name}` failed") + } else { + format!("Tool `{name}` completed") + }; + let _ = tool_spinner.finish(&label, renderer.color_theme(), out); + let rendered_output = format!("### Tool `{name}`\n\n```text\n{output}\n```\n"); + let _ = renderer.stream_markdown(&rendered_output, out); + } + StreamEvent::Usage(usage) => { + *turn_usage = usage; + } + } + } + + fn write_turn_output( + &self, + summary: &runtime::TurnSummary, + out: &mut impl Write, + ) -> io::Result<()> { + match self.config.output_format { + OutputFormat::Text => { + writeln!( + out, + "\nToken usage: {} input / {} output", + self.state.last_usage.input_tokens, self.state.last_usage.output_tokens + )?; + } + OutputFormat::Json => { + writeln!( + out, + "{}", + serde_json::json!({ + "message": summary.assistant_text, + "usage": { + "input_tokens": self.state.last_usage.input_tokens, + "output_tokens": self.state.last_usage.output_tokens, + } + }) + )?; + } + OutputFormat::Ndjson => { + writeln!( + out, + "{}", + serde_json::json!({ + "type": "message", + "text": summary.assistant_text, + "usage": { + "input_tokens": self.state.last_usage.input_tokens, + "output_tokens": self.state.last_usage.output_tokens, + } + }) + )?; + } + } + Ok(()) + } + + fn render_response(&mut self, input: &str, out: &mut impl Write) -> io::Result<()> { + let mut stream_spinner = Spinner::new(); + stream_spinner.tick( + "Opening conversation stream", + self.renderer.color_theme(), + out, + )?; + + let mut turn_usage = UsageSummary::default(); + let mut tool_spinner = Spinner::new(); + let mut saw_text = false; + let renderer = &self.renderer; + + let result = + self.conversation_client + .run_turn(&mut self.conversation_history, input, |event| { + Self::handle_stream_event( + renderer, + event, + &mut stream_spinner, + &mut tool_spinner, + &mut saw_text, + &mut turn_usage, + out, + ); + }); + + let summary = match result { + Ok(summary) => summary, + Err(error) => { + stream_spinner.fail( + "Streaming response failed", + self.renderer.color_theme(), + out, + )?; + return Err(io::Error::other(error)); + } + }; + self.state.last_usage = summary.usage.clone(); + if saw_text { + writeln!(out)?; + } else { + stream_spinner.finish("Streaming response", self.renderer.color_theme(), out)?; + } + + self.write_turn_output(&summary, out)?; + let _ = turn_usage; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use crate::args::{OutputFormat, PermissionMode}; + + use super::{CommandResult, SessionConfig, SlashCommand}; + + #[test] + fn parses_required_slash_commands() { + assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help)); + assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status)); + assert_eq!( + SlashCommand::parse("/compact now"), + Some(SlashCommand::Compact) + ); + } + + #[test] + fn help_output_lists_commands() { + let mut out = Vec::new(); + let result = super::CliApp::handle_help(&mut out).expect("help succeeds"); + assert_eq!(result, CommandResult::Continue); + let output = String::from_utf8_lossy(&out); + assert!(output.contains("/help")); + assert!(output.contains("/status")); + assert!(output.contains("/compact")); + } + + #[test] + fn session_state_tracks_config_values() { + let config = SessionConfig { + model: "sonnet".into(), + permission_mode: PermissionMode::DangerFullAccess, + config: Some(PathBuf::from("settings.toml")), + output_format: OutputFormat::Text, + }; + + assert_eq!(config.model, "sonnet"); + assert_eq!(config.permission_mode, PermissionMode::DangerFullAccess); + assert_eq!(config.config, Some(PathBuf::from("settings.toml"))); + } +} diff --git a/rust/crates/claw-cli/src/args.rs b/rust/crates/claw-cli/src/args.rs new file mode 100644 index 0000000..3c204a9 --- /dev/null +++ b/rust/crates/claw-cli/src/args.rs @@ -0,0 +1,104 @@ +use std::path::PathBuf; + +use clap::{Parser, Subcommand, ValueEnum}; + +#[derive(Debug, Clone, Parser, PartialEq, Eq)] +#[command(name = "claw-cli", version, about = "Claw Code CLI")] +pub struct Cli { + #[arg(long, default_value = "claude-opus-4-6")] + pub model: String, + + #[arg(long, value_enum, default_value_t = PermissionMode::DangerFullAccess)] + pub permission_mode: PermissionMode, + + #[arg(long)] + pub config: Option, + + #[arg(long, value_enum, default_value_t = OutputFormat::Text)] + pub output_format: OutputFormat, + + #[command(subcommand)] + pub command: Option, +} + +#[derive(Debug, Clone, Subcommand, PartialEq, Eq)] +pub enum Command { + /// Read upstream TS sources and print extracted counts + DumpManifests, + /// Print the current bootstrap phase skeleton + BootstrapPlan, + /// Start the OAuth login flow + Login, + /// Clear saved OAuth credentials + Logout, + /// Run a non-interactive prompt and exit + Prompt { prompt: Vec }, +} + +#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)] +pub enum PermissionMode { + ReadOnly, + WorkspaceWrite, + DangerFullAccess, +} + +#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)] +pub enum OutputFormat { + Text, + Json, + Ndjson, +} + +#[cfg(test)] +mod tests { + use clap::Parser; + + use super::{Cli, Command, OutputFormat, PermissionMode}; + + #[test] + fn parses_requested_flags() { + let cli = Cli::parse_from([ + "claw-cli", + "--model", + "claude-haiku-4-5-20251213", + "--permission-mode", + "read-only", + "--config", + "/tmp/config.toml", + "--output-format", + "ndjson", + "prompt", + "hello", + "world", + ]); + + assert_eq!(cli.model, "claude-haiku-4-5-20251213"); + assert_eq!(cli.permission_mode, PermissionMode::ReadOnly); + assert_eq!( + cli.config.as_deref(), + Some(std::path::Path::new("/tmp/config.toml")) + ); + assert_eq!(cli.output_format, OutputFormat::Ndjson); + assert_eq!( + cli.command, + Some(Command::Prompt { + prompt: vec!["hello".into(), "world".into()] + }) + ); + } + + #[test] + fn parses_login_and_logout_commands() { + let login = Cli::parse_from(["claw-cli", "login"]); + assert_eq!(login.command, Some(Command::Login)); + + let logout = Cli::parse_from(["claw-cli", "logout"]); + assert_eq!(logout.command, Some(Command::Logout)); + } + + #[test] + fn defaults_to_danger_full_access_permission_mode() { + let cli = Cli::parse_from(["claw-cli"]); + assert_eq!(cli.permission_mode, PermissionMode::DangerFullAccess); + } +} diff --git a/rust/crates/claw-cli/src/init.rs b/rust/crates/claw-cli/src/init.rs new file mode 100644 index 0000000..f4db53a --- /dev/null +++ b/rust/crates/claw-cli/src/init.rs @@ -0,0 +1,432 @@ +use std::fs; +use std::path::{Path, PathBuf}; + +const STARTER_CLAW_JSON: &str = concat!( + "{\n", + " \"permissions\": {\n", + " \"defaultMode\": \"dontAsk\"\n", + " }\n", + "}\n", +); +const GITIGNORE_COMMENT: &str = "# Claw Code local artifacts"; +const GITIGNORE_ENTRIES: [&str; 2] = [".claw/settings.local.json", ".claw/sessions/"]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum InitStatus { + Created, + Updated, + Skipped, +} + +impl InitStatus { + #[must_use] + pub(crate) fn label(self) -> &'static str { + match self { + Self::Created => "created", + Self::Updated => "updated", + Self::Skipped => "skipped (already exists)", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct InitArtifact { + pub(crate) name: &'static str, + pub(crate) status: InitStatus, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct InitReport { + pub(crate) project_root: PathBuf, + pub(crate) artifacts: Vec, +} + +impl InitReport { + #[must_use] + pub(crate) fn render(&self) -> String { + let mut lines = vec![ + "Init".to_string(), + format!(" Project {}", self.project_root.display()), + ]; + for artifact in &self.artifacts { + lines.push(format!( + " {:<16} {}", + artifact.name, + artifact.status.label() + )); + } + lines.push(" Next step Review and tailor the generated guidance".to_string()); + lines.join("\n") + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[allow(clippy::struct_excessive_bools)] +struct RepoDetection { + rust_workspace: bool, + rust_root: bool, + python: bool, + package_json: bool, + typescript: bool, + nextjs: bool, + react: bool, + vite: bool, + nest: bool, + src_dir: bool, + tests_dir: bool, + rust_dir: bool, +} + +pub(crate) fn initialize_repo(cwd: &Path) -> Result> { + let mut artifacts = Vec::new(); + + let claw_dir = cwd.join(".claw"); + artifacts.push(InitArtifact { + name: ".claw/", + status: ensure_dir(&claw_dir)?, + }); + + let claw_json = cwd.join(".claw.json"); + artifacts.push(InitArtifact { + name: ".claw.json", + status: write_file_if_missing(&claw_json, STARTER_CLAW_JSON)?, + }); + + let gitignore = cwd.join(".gitignore"); + artifacts.push(InitArtifact { + name: ".gitignore", + status: ensure_gitignore_entries(&gitignore)?, + }); + + let claw_md = cwd.join("CLAW.md"); + let content = render_init_claw_md(cwd); + artifacts.push(InitArtifact { + name: "CLAW.md", + status: write_file_if_missing(&claw_md, &content)?, + }); + + Ok(InitReport { + project_root: cwd.to_path_buf(), + artifacts, + }) +} + +fn ensure_dir(path: &Path) -> Result { + if path.is_dir() { + return Ok(InitStatus::Skipped); + } + fs::create_dir_all(path)?; + Ok(InitStatus::Created) +} + +fn write_file_if_missing(path: &Path, content: &str) -> Result { + if path.exists() { + return Ok(InitStatus::Skipped); + } + fs::write(path, content)?; + Ok(InitStatus::Created) +} + +fn ensure_gitignore_entries(path: &Path) -> Result { + if !path.exists() { + let mut lines = vec![GITIGNORE_COMMENT.to_string()]; + lines.extend(GITIGNORE_ENTRIES.iter().map(|entry| (*entry).to_string())); + fs::write(path, format!("{}\n", lines.join("\n")))?; + return Ok(InitStatus::Created); + } + + let existing = fs::read_to_string(path)?; + let mut lines = existing.lines().map(ToOwned::to_owned).collect::>(); + let mut changed = false; + + if !lines.iter().any(|line| line == GITIGNORE_COMMENT) { + lines.push(GITIGNORE_COMMENT.to_string()); + changed = true; + } + + for entry in GITIGNORE_ENTRIES { + if !lines.iter().any(|line| line == entry) { + lines.push(entry.to_string()); + changed = true; + } + } + + if !changed { + return Ok(InitStatus::Skipped); + } + + fs::write(path, format!("{}\n", lines.join("\n")))?; + Ok(InitStatus::Updated) +} + +pub(crate) fn render_init_claw_md(cwd: &Path) -> String { + let detection = detect_repo(cwd); + let mut lines = vec![ + "# CLAW.md".to_string(), + String::new(), + "This file provides guidance to Claw Code (clawcode.dev) when working with code in this repository.".to_string(), + String::new(), + ]; + + let detected_languages = detected_languages(&detection); + let detected_frameworks = detected_frameworks(&detection); + lines.push("## Detected stack".to_string()); + if detected_languages.is_empty() { + lines.push("- No specific language markers were detected yet; document the primary language and verification commands once the project structure settles.".to_string()); + } else { + lines.push(format!("- Languages: {}.", detected_languages.join(", "))); + } + if detected_frameworks.is_empty() { + lines.push("- Frameworks: none detected from the supported starter markers.".to_string()); + } else { + lines.push(format!( + "- Frameworks/tooling markers: {}.", + detected_frameworks.join(", ") + )); + } + lines.push(String::new()); + + let verification_lines = verification_lines(cwd, &detection); + if !verification_lines.is_empty() { + lines.push("## Verification".to_string()); + lines.extend(verification_lines); + lines.push(String::new()); + } + + let structure_lines = repository_shape_lines(&detection); + if !structure_lines.is_empty() { + lines.push("## Repository shape".to_string()); + lines.extend(structure_lines); + lines.push(String::new()); + } + + let framework_lines = framework_notes(&detection); + if !framework_lines.is_empty() { + lines.push("## Framework notes".to_string()); + lines.extend(framework_lines); + lines.push(String::new()); + } + + lines.push("## Working agreement".to_string()); + lines.push("- Prefer small, reviewable changes and keep generated bootstrap files aligned with actual repo workflows.".to_string()); + lines.push("- Keep shared defaults in `.claw.json`; reserve `.claw/settings.local.json` for machine-local overrides.".to_string()); + lines.push("- Do not overwrite existing `CLAW.md` content automatically; update it intentionally when repo workflows change.".to_string()); + lines.push(String::new()); + + lines.join("\n") +} + +fn detect_repo(cwd: &Path) -> RepoDetection { + let package_json_contents = fs::read_to_string(cwd.join("package.json")) + .unwrap_or_default() + .to_ascii_lowercase(); + RepoDetection { + rust_workspace: cwd.join("rust").join("Cargo.toml").is_file(), + rust_root: cwd.join("Cargo.toml").is_file(), + python: cwd.join("pyproject.toml").is_file() + || cwd.join("requirements.txt").is_file() + || cwd.join("setup.py").is_file(), + package_json: cwd.join("package.json").is_file(), + typescript: cwd.join("tsconfig.json").is_file() + || package_json_contents.contains("typescript"), + nextjs: package_json_contents.contains("\"next\""), + react: package_json_contents.contains("\"react\""), + vite: package_json_contents.contains("\"vite\""), + nest: package_json_contents.contains("@nestjs"), + src_dir: cwd.join("src").is_dir(), + tests_dir: cwd.join("tests").is_dir(), + rust_dir: cwd.join("rust").is_dir(), + } +} + +fn detected_languages(detection: &RepoDetection) -> Vec<&'static str> { + let mut languages = Vec::new(); + if detection.rust_workspace || detection.rust_root { + languages.push("Rust"); + } + if detection.python { + languages.push("Python"); + } + if detection.typescript { + languages.push("TypeScript"); + } else if detection.package_json { + languages.push("JavaScript/Node.js"); + } + languages +} + +fn detected_frameworks(detection: &RepoDetection) -> Vec<&'static str> { + let mut frameworks = Vec::new(); + if detection.nextjs { + frameworks.push("Next.js"); + } + if detection.react { + frameworks.push("React"); + } + if detection.vite { + frameworks.push("Vite"); + } + if detection.nest { + frameworks.push("NestJS"); + } + frameworks +} + +fn verification_lines(cwd: &Path, detection: &RepoDetection) -> Vec { + let mut lines = Vec::new(); + if detection.rust_workspace { + lines.push("- Run Rust verification from `rust/`: `cargo fmt`, `cargo clippy --workspace --all-targets -- -D warnings`, `cargo test --workspace`".to_string()); + } else if detection.rust_root { + lines.push("- Run Rust verification from the repo root: `cargo fmt`, `cargo clippy --workspace --all-targets -- -D warnings`, `cargo test --workspace`".to_string()); + } + if detection.python { + if cwd.join("pyproject.toml").is_file() { + lines.push("- Run the Python project checks declared in `pyproject.toml` (for example: `pytest`, `ruff check`, and `mypy` when configured).".to_string()); + } else { + lines.push( + "- Run the repo's Python test/lint commands before shipping changes.".to_string(), + ); + } + } + if detection.package_json { + lines.push("- Run the JavaScript/TypeScript checks from `package.json` before shipping changes (`npm test`, `npm run lint`, `npm run build`, or the repo equivalent).".to_string()); + } + if detection.tests_dir && detection.src_dir { + lines.push("- `src/` and `tests/` are both present; update both surfaces together when behavior changes.".to_string()); + } + lines +} + +fn repository_shape_lines(detection: &RepoDetection) -> Vec { + let mut lines = Vec::new(); + if detection.rust_dir { + lines.push( + "- `rust/` contains the Rust workspace and active CLI/runtime implementation." + .to_string(), + ); + } + if detection.src_dir { + lines.push("- `src/` contains source files that should stay consistent with generated guidance and tests.".to_string()); + } + if detection.tests_dir { + lines.push("- `tests/` contains validation surfaces that should be reviewed alongside code changes.".to_string()); + } + lines +} + +fn framework_notes(detection: &RepoDetection) -> Vec { + let mut lines = Vec::new(); + if detection.nextjs { + lines.push("- Next.js detected: preserve routing/data-fetching conventions and verify production builds after changing app structure.".to_string()); + } + if detection.react && !detection.nextjs { + lines.push("- React detected: keep component behavior covered with focused tests and avoid unnecessary prop/API churn.".to_string()); + } + if detection.vite { + lines.push("- Vite detected: validate the production bundle after changing build-sensitive configuration or imports.".to_string()); + } + if detection.nest { + lines.push("- NestJS detected: keep module/provider boundaries explicit and verify controller/service wiring after refactors.".to_string()); + } + lines +} + +#[cfg(test)] +mod tests { + use super::{initialize_repo, render_init_claw_md}; + use std::fs; + use std::path::Path; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir() -> std::path::PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("claw-init-{nanos}")) + } + + #[test] + fn initialize_repo_creates_expected_files_and_gitignore_entries() { + let root = temp_dir(); + fs::create_dir_all(root.join("rust")).expect("create rust dir"); + fs::write(root.join("rust").join("Cargo.toml"), "[workspace]\n").expect("write cargo"); + + let report = initialize_repo(&root).expect("init should succeed"); + let rendered = report.render(); + assert!(rendered.contains(".claw/ created")); + assert!(rendered.contains(".claw.json created")); + assert!(rendered.contains(".gitignore created")); + assert!(rendered.contains("CLAW.md created")); + assert!(root.join(".claw").is_dir()); + assert!(root.join(".claw.json").is_file()); + assert!(root.join("CLAW.md").is_file()); + assert_eq!( + fs::read_to_string(root.join(".claw.json")).expect("read claw json"), + concat!( + "{\n", + " \"permissions\": {\n", + " \"defaultMode\": \"dontAsk\"\n", + " }\n", + "}\n", + ) + ); + let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore"); + assert!(gitignore.contains(".claw/settings.local.json")); + assert!(gitignore.contains(".claw/sessions/")); + let claw_md = fs::read_to_string(root.join("CLAW.md")).expect("read claw md"); + assert!(claw_md.contains("Languages: Rust.")); + assert!(claw_md.contains("cargo clippy --workspace --all-targets -- -D warnings")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn initialize_repo_is_idempotent_and_preserves_existing_files() { + let root = temp_dir(); + fs::create_dir_all(&root).expect("create root"); + fs::write(root.join("CLAW.md"), "custom guidance\n").expect("write existing claw md"); + fs::write(root.join(".gitignore"), ".claw/settings.local.json\n").expect("write gitignore"); + + let first = initialize_repo(&root).expect("first init should succeed"); + assert!(first + .render() + .contains("CLAW.md skipped (already exists)")); + let second = initialize_repo(&root).expect("second init should succeed"); + let second_rendered = second.render(); + assert!(second_rendered.contains(".claw/ skipped (already exists)")); + assert!(second_rendered.contains(".claw.json skipped (already exists)")); + assert!(second_rendered.contains(".gitignore skipped (already exists)")); + assert!(second_rendered.contains("CLAW.md skipped (already exists)")); + assert_eq!( + fs::read_to_string(root.join("CLAW.md")).expect("read existing claw md"), + "custom guidance\n" + ); + let gitignore = fs::read_to_string(root.join(".gitignore")).expect("read gitignore"); + assert_eq!(gitignore.matches(".claw/settings.local.json").count(), 1); + assert_eq!(gitignore.matches(".claw/sessions/").count(), 1); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn render_init_template_mentions_detected_python_and_nextjs_markers() { + let root = temp_dir(); + fs::create_dir_all(&root).expect("create root"); + fs::write(root.join("pyproject.toml"), "[project]\nname = \"demo\"\n") + .expect("write pyproject"); + fs::write( + root.join("package.json"), + r#"{"dependencies":{"next":"14.0.0","react":"18.0.0"},"devDependencies":{"typescript":"5.0.0"}}"#, + ) + .expect("write package json"); + + let rendered = render_init_claw_md(Path::new(&root)); + assert!(rendered.contains("Languages: Python, TypeScript.")); + assert!(rendered.contains("Frameworks/tooling markers: Next.js, React.")); + assert!(rendered.contains("pyproject.toml")); + assert!(rendered.contains("Next.js detected")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } +} diff --git a/rust/crates/claw-cli/src/input.rs b/rust/crates/claw-cli/src/input.rs new file mode 100644 index 0000000..a718cd7 --- /dev/null +++ b/rust/crates/claw-cli/src/input.rs @@ -0,0 +1,1195 @@ +use std::borrow::Cow; +use std::io::{self, IsTerminal, Write}; + +use crossterm::cursor::{MoveToColumn, MoveUp}; +use crossterm::event::{self, Event, KeyCode, KeyEvent, KeyEventKind, KeyModifiers}; +use crossterm::queue; +use crossterm::terminal::{self, Clear, ClearType}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ReadOutcome { + Submit(String), + Cancel, + Exit, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum EditorMode { + Plain, + Insert, + Normal, + Visual, + Command, +} + +impl EditorMode { + fn indicator(self, vim_enabled: bool) -> Option<&'static str> { + if !vim_enabled { + return None; + } + + Some(match self { + Self::Plain => "PLAIN", + Self::Insert => "INSERT", + Self::Normal => "NORMAL", + Self::Visual => "VISUAL", + Self::Command => "COMMAND", + }) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +struct YankBuffer { + text: String, + linewise: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct EditSession { + text: String, + cursor: usize, + mode: EditorMode, + pending_operator: Option, + visual_anchor: Option, + command_buffer: String, + command_cursor: usize, + history_index: Option, + history_backup: Option, + rendered_cursor_row: usize, + rendered_lines: usize, +} + +impl EditSession { + fn new(vim_enabled: bool) -> Self { + Self { + text: String::new(), + cursor: 0, + mode: if vim_enabled { + EditorMode::Insert + } else { + EditorMode::Plain + }, + pending_operator: None, + visual_anchor: None, + command_buffer: String::new(), + command_cursor: 0, + history_index: None, + history_backup: None, + rendered_cursor_row: 0, + rendered_lines: 1, + } + } + + fn active_text(&self) -> &str { + if self.mode == EditorMode::Command { + &self.command_buffer + } else { + &self.text + } + } + + fn current_len(&self) -> usize { + self.active_text().len() + } + + fn has_input(&self) -> bool { + !self.active_text().is_empty() + } + + fn current_line(&self) -> String { + self.active_text().to_string() + } + + fn set_text_from_history(&mut self, entry: String) { + self.text = entry; + self.cursor = self.text.len(); + self.pending_operator = None; + self.visual_anchor = None; + if self.mode != EditorMode::Plain && self.mode != EditorMode::Insert { + self.mode = EditorMode::Normal; + } + } + + fn enter_insert_mode(&mut self) { + self.mode = EditorMode::Insert; + self.pending_operator = None; + self.visual_anchor = None; + } + + fn enter_normal_mode(&mut self) { + self.mode = EditorMode::Normal; + self.pending_operator = None; + self.visual_anchor = None; + } + + fn enter_visual_mode(&mut self) { + self.mode = EditorMode::Visual; + self.pending_operator = None; + self.visual_anchor = Some(self.cursor); + } + + fn enter_command_mode(&mut self) { + self.mode = EditorMode::Command; + self.pending_operator = None; + self.visual_anchor = None; + self.command_buffer.clear(); + self.command_buffer.push(':'); + self.command_cursor = self.command_buffer.len(); + } + + fn exit_command_mode(&mut self) { + self.command_buffer.clear(); + self.command_cursor = 0; + self.enter_normal_mode(); + } + + fn visible_buffer(&self) -> Cow<'_, str> { + if self.mode != EditorMode::Visual { + return Cow::Borrowed(self.active_text()); + } + + let Some(anchor) = self.visual_anchor else { + return Cow::Borrowed(self.active_text()); + }; + let Some((start, end)) = selection_bounds(&self.text, anchor, self.cursor) else { + return Cow::Borrowed(self.active_text()); + }; + + Cow::Owned(render_selected_text(&self.text, start, end)) + } + + fn prompt<'a>(&self, base_prompt: &'a str, vim_enabled: bool) -> Cow<'a, str> { + match self.mode.indicator(vim_enabled) { + Some(mode) => Cow::Owned(format!("[{mode}] {base_prompt}")), + None => Cow::Borrowed(base_prompt), + } + } + + fn clear_render(&self, out: &mut impl Write) -> io::Result<()> { + if self.rendered_cursor_row > 0 { + queue!(out, MoveUp(to_u16(self.rendered_cursor_row)?))?; + } + queue!(out, MoveToColumn(0), Clear(ClearType::FromCursorDown))?; + out.flush() + } + + fn render( + &mut self, + out: &mut impl Write, + base_prompt: &str, + vim_enabled: bool, + ) -> io::Result<()> { + self.clear_render(out)?; + + let prompt = self.prompt(base_prompt, vim_enabled); + let buffer = self.visible_buffer(); + write!(out, "{prompt}{buffer}")?; + + let (cursor_row, cursor_col, total_lines) = self.cursor_layout(prompt.as_ref()); + let rows_to_move_up = total_lines.saturating_sub(cursor_row + 1); + if rows_to_move_up > 0 { + queue!(out, MoveUp(to_u16(rows_to_move_up)?))?; + } + queue!(out, MoveToColumn(to_u16(cursor_col)?))?; + out.flush()?; + + self.rendered_cursor_row = cursor_row; + self.rendered_lines = total_lines; + Ok(()) + } + + fn finalize_render( + &self, + out: &mut impl Write, + base_prompt: &str, + vim_enabled: bool, + ) -> io::Result<()> { + self.clear_render(out)?; + let prompt = self.prompt(base_prompt, vim_enabled); + let buffer = self.visible_buffer(); + write!(out, "{prompt}{buffer}")?; + writeln!(out) + } + + fn cursor_layout(&self, prompt: &str) -> (usize, usize, usize) { + let active_text = self.active_text(); + let cursor = if self.mode == EditorMode::Command { + self.command_cursor + } else { + self.cursor + }; + + let cursor_prefix = &active_text[..cursor]; + let cursor_row = cursor_prefix.bytes().filter(|byte| *byte == b'\n').count(); + let cursor_col = match cursor_prefix.rsplit_once('\n') { + Some((_, suffix)) => suffix.chars().count(), + None => prompt.chars().count() + cursor_prefix.chars().count(), + }; + let total_lines = active_text.bytes().filter(|byte| *byte == b'\n').count() + 1; + (cursor_row, cursor_col, total_lines) + } +} + +enum KeyAction { + Continue, + Submit(String), + Cancel, + Exit, + ToggleVim, +} + +pub struct LineEditor { + prompt: String, + completions: Vec, + history: Vec, + yank_buffer: YankBuffer, + vim_enabled: bool, + completion_state: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct CompletionState { + prefix: String, + matches: Vec, + next_index: usize, +} + +impl LineEditor { + #[must_use] + pub fn new(prompt: impl Into, completions: Vec) -> Self { + Self { + prompt: prompt.into(), + completions, + history: Vec::new(), + yank_buffer: YankBuffer::default(), + vim_enabled: false, + completion_state: None, + } + } + + pub fn push_history(&mut self, entry: impl Into) { + let entry = entry.into(); + if entry.trim().is_empty() { + return; + } + + self.history.push(entry); + } + + pub fn read_line(&mut self) -> io::Result { + if !io::stdin().is_terminal() || !io::stdout().is_terminal() { + return self.read_line_fallback(); + } + + let _raw_mode = RawModeGuard::new()?; + let mut stdout = io::stdout(); + let mut session = EditSession::new(self.vim_enabled); + session.render(&mut stdout, &self.prompt, self.vim_enabled)?; + + loop { + let Event::Key(key) = event::read()? else { + continue; + }; + if !matches!(key.kind, KeyEventKind::Press | KeyEventKind::Repeat) { + continue; + } + + match self.handle_key_event(&mut session, key) { + KeyAction::Continue => { + session.render(&mut stdout, &self.prompt, self.vim_enabled)?; + } + KeyAction::Submit(line) => { + session.finalize_render(&mut stdout, &self.prompt, self.vim_enabled)?; + return Ok(ReadOutcome::Submit(line)); + } + KeyAction::Cancel => { + session.clear_render(&mut stdout)?; + writeln!(stdout)?; + return Ok(ReadOutcome::Cancel); + } + KeyAction::Exit => { + session.clear_render(&mut stdout)?; + writeln!(stdout)?; + return Ok(ReadOutcome::Exit); + } + KeyAction::ToggleVim => { + session.clear_render(&mut stdout)?; + self.vim_enabled = !self.vim_enabled; + writeln!( + stdout, + "Vim mode {}.", + if self.vim_enabled { + "enabled" + } else { + "disabled" + } + )?; + session = EditSession::new(self.vim_enabled); + session.render(&mut stdout, &self.prompt, self.vim_enabled)?; + } + } + } + } + + fn read_line_fallback(&mut self) -> io::Result { + loop { + let mut stdout = io::stdout(); + write!(stdout, "{}", self.prompt)?; + stdout.flush()?; + + let mut buffer = String::new(); + let bytes_read = io::stdin().read_line(&mut buffer)?; + if bytes_read == 0 { + return Ok(ReadOutcome::Exit); + } + + while matches!(buffer.chars().last(), Some('\n' | '\r')) { + buffer.pop(); + } + + if self.handle_submission(&buffer) == Submission::ToggleVim { + self.vim_enabled = !self.vim_enabled; + writeln!( + stdout, + "Vim mode {}.", + if self.vim_enabled { + "enabled" + } else { + "disabled" + } + )?; + continue; + } + + return Ok(ReadOutcome::Submit(buffer)); + } + } + + fn handle_key_event(&mut self, session: &mut EditSession, key: KeyEvent) -> KeyAction { + if key.code != KeyCode::Tab { + self.completion_state = None; + } + + if key.modifiers.contains(KeyModifiers::CONTROL) { + match key.code { + KeyCode::Char('c') | KeyCode::Char('C') => { + return if session.has_input() { + KeyAction::Cancel + } else { + KeyAction::Exit + }; + } + KeyCode::Char('j') | KeyCode::Char('J') => { + if session.mode != EditorMode::Normal && session.mode != EditorMode::Visual { + self.insert_active_text(session, "\n"); + } + return KeyAction::Continue; + } + KeyCode::Char('d') | KeyCode::Char('D') => { + if session.current_len() == 0 { + return KeyAction::Exit; + } + self.delete_char_under_cursor(session); + return KeyAction::Continue; + } + _ => {} + } + } + + match key.code { + KeyCode::Enter if key.modifiers.contains(KeyModifiers::SHIFT) => { + if session.mode != EditorMode::Normal && session.mode != EditorMode::Visual { + self.insert_active_text(session, "\n"); + } + KeyAction::Continue + } + KeyCode::Enter => self.submit_or_toggle(session), + KeyCode::Esc => self.handle_escape(session), + KeyCode::Backspace => { + self.handle_backspace(session); + KeyAction::Continue + } + KeyCode::Delete => { + self.delete_char_under_cursor(session); + KeyAction::Continue + } + KeyCode::Left => { + self.move_left(session); + KeyAction::Continue + } + KeyCode::Right => { + self.move_right(session); + KeyAction::Continue + } + KeyCode::Up => { + self.history_up(session); + KeyAction::Continue + } + KeyCode::Down => { + self.history_down(session); + KeyAction::Continue + } + KeyCode::Home => { + self.move_line_start(session); + KeyAction::Continue + } + KeyCode::End => { + self.move_line_end(session); + KeyAction::Continue + } + KeyCode::Tab => { + self.complete_slash_command(session); + KeyAction::Continue + } + KeyCode::Char(ch) => { + self.handle_char(session, ch); + KeyAction::Continue + } + _ => KeyAction::Continue, + } + } + + fn handle_char(&mut self, session: &mut EditSession, ch: char) { + match session.mode { + EditorMode::Plain => self.insert_active_char(session, ch), + EditorMode::Insert => self.insert_active_char(session, ch), + EditorMode::Normal => self.handle_normal_char(session, ch), + EditorMode::Visual => self.handle_visual_char(session, ch), + EditorMode::Command => self.insert_active_char(session, ch), + } + } + + fn handle_normal_char(&mut self, session: &mut EditSession, ch: char) { + if let Some(operator) = session.pending_operator.take() { + match (operator, ch) { + ('d', 'd') => { + self.delete_current_line(session); + return; + } + ('y', 'y') => { + self.yank_current_line(session); + return; + } + _ => {} + } + } + + match ch { + 'h' => self.move_left(session), + 'j' => self.move_down(session), + 'k' => self.move_up(session), + 'l' => self.move_right(session), + 'd' | 'y' => session.pending_operator = Some(ch), + 'p' => self.paste_after(session), + 'i' => session.enter_insert_mode(), + 'v' => session.enter_visual_mode(), + ':' => session.enter_command_mode(), + _ => {} + } + } + + fn handle_visual_char(&mut self, session: &mut EditSession, ch: char) { + match ch { + 'h' => self.move_left(session), + 'j' => self.move_down(session), + 'k' => self.move_up(session), + 'l' => self.move_right(session), + 'v' => session.enter_normal_mode(), + _ => {} + } + } + + fn handle_escape(&mut self, session: &mut EditSession) -> KeyAction { + match session.mode { + EditorMode::Plain => KeyAction::Continue, + EditorMode::Insert => { + if session.cursor > 0 { + session.cursor = previous_boundary(&session.text, session.cursor); + } + session.enter_normal_mode(); + KeyAction::Continue + } + EditorMode::Normal => KeyAction::Continue, + EditorMode::Visual => { + session.enter_normal_mode(); + KeyAction::Continue + } + EditorMode::Command => { + session.exit_command_mode(); + KeyAction::Continue + } + } + } + + fn handle_backspace(&mut self, session: &mut EditSession) { + match session.mode { + EditorMode::Normal | EditorMode::Visual => self.move_left(session), + EditorMode::Command => { + if session.command_cursor <= 1 { + session.exit_command_mode(); + } else { + remove_previous_char(&mut session.command_buffer, &mut session.command_cursor); + } + } + EditorMode::Plain | EditorMode::Insert => { + remove_previous_char(&mut session.text, &mut session.cursor); + } + } + } + + fn submit_or_toggle(&mut self, session: &EditSession) -> KeyAction { + let line = session.current_line(); + match self.handle_submission(&line) { + Submission::Submit => KeyAction::Submit(line), + Submission::ToggleVim => KeyAction::ToggleVim, + } + } + + fn handle_submission(&mut self, line: &str) -> Submission { + if line.trim() == "/vim" { + Submission::ToggleVim + } else { + Submission::Submit + } + } + + fn insert_active_char(&mut self, session: &mut EditSession, ch: char) { + let mut buffer = [0; 4]; + self.insert_active_text(session, ch.encode_utf8(&mut buffer)); + } + + fn insert_active_text(&mut self, session: &mut EditSession, text: &str) { + if session.mode == EditorMode::Command { + session + .command_buffer + .insert_str(session.command_cursor, text); + session.command_cursor += text.len(); + } else { + session.text.insert_str(session.cursor, text); + session.cursor += text.len(); + } + } + + fn move_left(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = + previous_command_boundary(&session.command_buffer, session.command_cursor); + } else { + session.cursor = previous_boundary(&session.text, session.cursor); + } + } + + fn move_right(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = next_boundary(&session.command_buffer, session.command_cursor); + } else { + session.cursor = next_boundary(&session.text, session.cursor); + } + } + + fn move_line_start(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = 1; + } else { + session.cursor = line_start(&session.text, session.cursor); + } + } + + fn move_line_end(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + session.command_cursor = session.command_buffer.len(); + } else { + session.cursor = line_end(&session.text, session.cursor); + } + } + + fn move_up(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + return; + } + session.cursor = move_vertical(&session.text, session.cursor, -1); + } + + fn move_down(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + return; + } + session.cursor = move_vertical(&session.text, session.cursor, 1); + } + + fn delete_char_under_cursor(&self, session: &mut EditSession) { + match session.mode { + EditorMode::Command => { + if session.command_cursor < session.command_buffer.len() { + let end = next_boundary(&session.command_buffer, session.command_cursor); + session.command_buffer.drain(session.command_cursor..end); + } + } + _ => { + if session.cursor < session.text.len() { + let end = next_boundary(&session.text, session.cursor); + session.text.drain(session.cursor..end); + } + } + } + } + + fn delete_current_line(&mut self, session: &mut EditSession) { + let (line_start_idx, line_end_idx, delete_start_idx) = + current_line_delete_range(&session.text, session.cursor); + self.yank_buffer.text = session.text[line_start_idx..line_end_idx].to_string(); + self.yank_buffer.linewise = true; + session.text.drain(delete_start_idx..line_end_idx); + session.cursor = delete_start_idx.min(session.text.len()); + } + + fn yank_current_line(&mut self, session: &mut EditSession) { + let (line_start_idx, line_end_idx, _) = + current_line_delete_range(&session.text, session.cursor); + self.yank_buffer.text = session.text[line_start_idx..line_end_idx].to_string(); + self.yank_buffer.linewise = true; + } + + fn paste_after(&mut self, session: &mut EditSession) { + if self.yank_buffer.text.is_empty() { + return; + } + + if self.yank_buffer.linewise { + let line_end_idx = line_end(&session.text, session.cursor); + let insert_at = if line_end_idx < session.text.len() { + line_end_idx + 1 + } else { + session.text.len() + }; + let mut insertion = self.yank_buffer.text.clone(); + if insert_at == session.text.len() + && !session.text.is_empty() + && !session.text.ends_with('\n') + { + insertion.insert(0, '\n'); + } + if insert_at < session.text.len() && !insertion.ends_with('\n') { + insertion.push('\n'); + } + session.text.insert_str(insert_at, &insertion); + session.cursor = if insertion.starts_with('\n') { + insert_at + 1 + } else { + insert_at + }; + return; + } + + let insert_at = next_boundary(&session.text, session.cursor); + session.text.insert_str(insert_at, &self.yank_buffer.text); + session.cursor = insert_at + self.yank_buffer.text.len(); + } + + fn complete_slash_command(&mut self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + self.completion_state = None; + return; + } + if let Some(state) = self + .completion_state + .as_mut() + .filter(|_| session.cursor == session.text.len()) + .filter(|state| { + state + .matches + .iter() + .any(|candidate| candidate == &session.text) + }) + { + let candidate = state.matches[state.next_index % state.matches.len()].clone(); + state.next_index += 1; + session.text.replace_range(..session.cursor, &candidate); + session.cursor = candidate.len(); + return; + } + let Some(prefix) = slash_command_prefix(&session.text, session.cursor) else { + self.completion_state = None; + return; + }; + let matches = self + .completions + .iter() + .filter(|candidate| candidate.starts_with(prefix) && candidate.as_str() != prefix) + .cloned() + .collect::>(); + if matches.is_empty() { + self.completion_state = None; + return; + } + + let candidate = if let Some(state) = self + .completion_state + .as_mut() + .filter(|state| state.prefix == prefix && state.matches == matches) + { + let index = state.next_index % state.matches.len(); + state.next_index += 1; + state.matches[index].clone() + } else { + let candidate = matches[0].clone(); + self.completion_state = Some(CompletionState { + prefix: prefix.to_string(), + matches, + next_index: 1, + }); + candidate + }; + + session.text.replace_range(..session.cursor, &candidate); + session.cursor = candidate.len(); + } + + fn history_up(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command || self.history.is_empty() { + return; + } + + let next_index = match session.history_index { + Some(index) => index.saturating_sub(1), + None => { + session.history_backup = Some(session.text.clone()); + self.history.len() - 1 + } + }; + + session.history_index = Some(next_index); + session.set_text_from_history(self.history[next_index].clone()); + } + + fn history_down(&self, session: &mut EditSession) { + if session.mode == EditorMode::Command { + return; + } + + let Some(index) = session.history_index else { + return; + }; + + if index + 1 < self.history.len() { + let next_index = index + 1; + session.history_index = Some(next_index); + session.set_text_from_history(self.history[next_index].clone()); + return; + } + + session.history_index = None; + let restored = session.history_backup.take().unwrap_or_default(); + session.set_text_from_history(restored); + if self.vim_enabled { + session.enter_insert_mode(); + } else { + session.mode = EditorMode::Plain; + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Submission { + Submit, + ToggleVim, +} + +struct RawModeGuard; + +impl RawModeGuard { + fn new() -> io::Result { + terminal::enable_raw_mode().map_err(io::Error::other)?; + Ok(Self) + } +} + +impl Drop for RawModeGuard { + fn drop(&mut self) { + let _ = terminal::disable_raw_mode(); + } +} + +fn previous_boundary(text: &str, cursor: usize) -> usize { + if cursor == 0 { + return 0; + } + + text[..cursor] + .char_indices() + .next_back() + .map_or(0, |(index, _)| index) +} + +fn previous_command_boundary(text: &str, cursor: usize) -> usize { + previous_boundary(text, cursor).max(1) +} + +fn next_boundary(text: &str, cursor: usize) -> usize { + if cursor >= text.len() { + return text.len(); + } + + text[cursor..] + .chars() + .next() + .map_or(text.len(), |ch| cursor + ch.len_utf8()) +} + +fn remove_previous_char(text: &mut String, cursor: &mut usize) { + if *cursor == 0 { + return; + } + + let start = previous_boundary(text, *cursor); + text.drain(start..*cursor); + *cursor = start; +} + +fn line_start(text: &str, cursor: usize) -> usize { + text[..cursor].rfind('\n').map_or(0, |index| index + 1) +} + +fn line_end(text: &str, cursor: usize) -> usize { + text[cursor..] + .find('\n') + .map_or(text.len(), |index| cursor + index) +} + +fn move_vertical(text: &str, cursor: usize, delta: isize) -> usize { + let starts = line_starts(text); + let current_row = text[..cursor].bytes().filter(|byte| *byte == b'\n').count(); + let current_start = starts[current_row]; + let current_col = text[current_start..cursor].chars().count(); + + let max_row = starts.len().saturating_sub(1) as isize; + let target_row = (current_row as isize + delta).clamp(0, max_row) as usize; + if target_row == current_row { + return cursor; + } + + let target_start = starts[target_row]; + let target_end = if target_row + 1 < starts.len() { + starts[target_row + 1] - 1 + } else { + text.len() + }; + byte_index_for_char_column(&text[target_start..target_end], current_col) + target_start +} + +fn line_starts(text: &str) -> Vec { + let mut starts = vec![0]; + for (index, ch) in text.char_indices() { + if ch == '\n' { + starts.push(index + 1); + } + } + starts +} + +fn byte_index_for_char_column(text: &str, column: usize) -> usize { + let mut current = 0; + for (index, _) in text.char_indices() { + if current == column { + return index; + } + current += 1; + } + text.len() +} + +fn current_line_delete_range(text: &str, cursor: usize) -> (usize, usize, usize) { + let line_start_idx = line_start(text, cursor); + let line_end_core = line_end(text, cursor); + let line_end_idx = if line_end_core < text.len() { + line_end_core + 1 + } else { + line_end_core + }; + let delete_start_idx = if line_end_idx == text.len() && line_start_idx > 0 { + line_start_idx - 1 + } else { + line_start_idx + }; + (line_start_idx, line_end_idx, delete_start_idx) +} + +fn selection_bounds(text: &str, anchor: usize, cursor: usize) -> Option<(usize, usize)> { + if text.is_empty() { + return None; + } + + if cursor >= anchor { + let end = next_boundary(text, cursor); + Some((anchor.min(text.len()), end.min(text.len()))) + } else { + let end = next_boundary(text, anchor); + Some((cursor.min(text.len()), end.min(text.len()))) + } +} + +fn render_selected_text(text: &str, start: usize, end: usize) -> String { + let mut rendered = String::new(); + let mut in_selection = false; + + for (index, ch) in text.char_indices() { + if !in_selection && index == start { + rendered.push_str("\x1b[7m"); + in_selection = true; + } + if in_selection && index == end { + rendered.push_str("\x1b[0m"); + in_selection = false; + } + rendered.push(ch); + } + + if in_selection { + rendered.push_str("\x1b[0m"); + } + + rendered +} + +fn slash_command_prefix(line: &str, pos: usize) -> Option<&str> { + if pos != line.len() { + return None; + } + + let prefix = &line[..pos]; + if prefix.contains(char::is_whitespace) || !prefix.starts_with('/') { + return None; + } + + Some(prefix) +} + +fn to_u16(value: usize) -> io::Result { + u16::try_from(value).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "terminal position overflowed u16", + ) + }) +} + +#[cfg(test)] +mod tests { + use super::{ + selection_bounds, slash_command_prefix, EditSession, EditorMode, KeyAction, LineEditor, + }; + use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; + + #[test] + fn extracts_only_terminal_slash_command_prefixes() { + // given + let complete_prefix = slash_command_prefix("/he", 3); + let whitespace_prefix = slash_command_prefix("/help me", 5); + let plain_text_prefix = slash_command_prefix("hello", 5); + let mid_buffer_prefix = slash_command_prefix("/help", 2); + + // when + let result = ( + complete_prefix, + whitespace_prefix, + plain_text_prefix, + mid_buffer_prefix, + ); + + // then + assert_eq!(result, (Some("/he"), None, None, None)); + } + + #[test] + fn toggle_submission_flips_vim_mode() { + // given + let mut editor = LineEditor::new("> ", vec!["/help".to_string(), "/vim".to_string()]); + + // when + let first = editor.handle_submission("/vim"); + editor.vim_enabled = true; + let second = editor.handle_submission("/vim"); + + // then + assert!(matches!(first, super::Submission::ToggleVim)); + assert!(matches!(second, super::Submission::ToggleVim)); + } + + #[test] + fn normal_mode_supports_motion_and_insert_transition() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "hello".to_string(); + session.cursor = session.text.len(); + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'h'); + editor.handle_char(&mut session, 'i'); + editor.handle_char(&mut session, '!'); + + // then + assert_eq!(session.mode, EditorMode::Insert); + assert_eq!(session.text, "hel!lo"); + } + + #[test] + fn yy_and_p_paste_yanked_line_after_current_line() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "alpha\nbeta\ngamma".to_string(); + session.cursor = 0; + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'y'); + editor.handle_char(&mut session, 'y'); + editor.handle_char(&mut session, 'p'); + + // then + assert_eq!(session.text, "alpha\nalpha\nbeta\ngamma"); + } + + #[test] + fn dd_and_p_paste_deleted_line_after_current_line() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "alpha\nbeta\ngamma".to_string(); + session.cursor = 0; + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'j'); + editor.handle_char(&mut session, 'd'); + editor.handle_char(&mut session, 'd'); + editor.handle_char(&mut session, 'p'); + + // then + assert_eq!(session.text, "alpha\ngamma\nbeta\n"); + } + + #[test] + fn visual_mode_tracks_selection_with_motions() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "alpha\nbeta".to_string(); + session.cursor = 0; + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, 'v'); + editor.handle_char(&mut session, 'j'); + editor.handle_char(&mut session, 'l'); + + // then + assert_eq!(session.mode, EditorMode::Visual); + assert_eq!( + selection_bounds( + &session.text, + session.visual_anchor.unwrap_or(0), + session.cursor + ), + Some((0, 8)) + ); + } + + #[test] + fn command_mode_submits_colon_prefixed_input() { + // given + let mut editor = LineEditor::new("> ", vec![]); + editor.vim_enabled = true; + let mut session = EditSession::new(true); + session.text = "draft".to_string(); + session.cursor = session.text.len(); + let _ = editor.handle_escape(&mut session); + + // when + editor.handle_char(&mut session, ':'); + editor.handle_char(&mut session, 'q'); + editor.handle_char(&mut session, '!'); + let action = editor.submit_or_toggle(&session); + + // then + assert_eq!(session.mode, EditorMode::Command); + assert_eq!(session.command_buffer, ":q!"); + assert!(matches!(action, KeyAction::Submit(line) if line == ":q!")); + } + + #[test] + fn push_history_ignores_blank_entries() { + // given + let mut editor = LineEditor::new("> ", vec!["/help".to_string()]); + + // when + editor.push_history(" "); + editor.push_history("/help"); + + // then + assert_eq!(editor.history, vec!["/help".to_string()]); + } + + #[test] + fn tab_completes_matching_slash_commands() { + // given + let mut editor = LineEditor::new("> ", vec!["/help".to_string(), "/hello".to_string()]); + let mut session = EditSession::new(false); + session.text = "/he".to_string(); + session.cursor = session.text.len(); + + // when + editor.complete_slash_command(&mut session); + + // then + assert_eq!(session.text, "/help"); + assert_eq!(session.cursor, 5); + } + + #[test] + fn tab_cycles_between_matching_slash_commands() { + // given + let mut editor = LineEditor::new( + "> ", + vec!["/permissions".to_string(), "/plugin".to_string()], + ); + let mut session = EditSession::new(false); + session.text = "/p".to_string(); + session.cursor = session.text.len(); + + // when + editor.complete_slash_command(&mut session); + let first = session.text.clone(); + session.cursor = session.text.len(); + editor.complete_slash_command(&mut session); + let second = session.text.clone(); + + // then + assert_eq!(first, "/permissions"); + assert_eq!(second, "/plugin"); + } + + #[test] + fn ctrl_c_cancels_when_input_exists() { + // given + let mut editor = LineEditor::new("> ", vec![]); + let mut session = EditSession::new(false); + session.text = "draft".to_string(); + session.cursor = session.text.len(); + + // when + let action = editor.handle_key_event( + &mut session, + KeyEvent::new(KeyCode::Char('c'), KeyModifiers::CONTROL), + ); + + // then + assert!(matches!(action, KeyAction::Cancel)); + } +} diff --git a/rust/crates/claw-cli/src/main.rs b/rust/crates/claw-cli/src/main.rs new file mode 100644 index 0000000..2b7d6f1 --- /dev/null +++ b/rust/crates/claw-cli/src/main.rs @@ -0,0 +1,5090 @@ +mod init; +mod input; +mod render; + +use std::collections::BTreeSet; +use std::env; +use std::fmt::Write as _; +use std::fs; +use std::io::{self, IsTerminal, Read, Write}; +use std::net::TcpListener; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::sync::mpsc::{self, RecvTimeoutError}; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use api::{ + resolve_startup_auth_source, AuthSource, ClawApiClient, ContentBlockDelta, InputContentBlock, + InputMessage, MessageRequest, MessageResponse, OutputContentBlock, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, +}; + +use commands::{ + handle_agents_slash_command, handle_plugins_slash_command, handle_skills_slash_command, + render_slash_command_help, resume_supported_slash_commands, slash_command_specs, + suggest_slash_commands, SlashCommand, +}; +use compat_harness::{extract_manifest, UpstreamPaths}; +use init::initialize_repo; +use plugins::{PluginManager, PluginManagerConfig}; +use render::{MarkdownStreamState, Spinner, TerminalRenderer}; +use runtime::{ + clear_oauth_credentials, generate_pkce_pair, generate_state, load_system_prompt, + parse_oauth_callback_request_target, save_oauth_credentials, ApiClient, ApiRequest, + AssistantEvent, CompactionConfig, ConfigLoader, ConfigSource, ContentBlock, + ConversationMessage, ConversationRuntime, MessageRole, OAuthAuthorizationRequest, OAuthConfig, + OAuthTokenExchangeRequest, PermissionMode, PermissionPolicy, ProjectContext, RuntimeError, + Session, TokenUsage, ToolError, ToolExecutor, UsageTracker, +}; +use serde_json::json; +use tools::GlobalToolRegistry; + +const DEFAULT_MODEL: &str = "claude-opus-4-6"; +fn max_tokens_for_model(model: &str) -> u32 { + if model.contains("opus") { + 32_000 + } else { + 64_000 + } +} +const DEFAULT_DATE: &str = "2026-03-31"; +const DEFAULT_OAUTH_CALLBACK_PORT: u16 = 4545; +const VERSION: &str = env!("CARGO_PKG_VERSION"); +const BUILD_TARGET: Option<&str> = option_env!("TARGET"); +const GIT_SHA: Option<&str> = option_env!("GIT_SHA"); +const INTERNAL_PROGRESS_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(3); + +type AllowedToolSet = BTreeSet; + +fn main() { + if let Err(error) = run() { + eprintln!("{}", render_cli_error(&error.to_string())); + std::process::exit(1); + } +} + +fn render_cli_error(problem: &str) -> String { + let mut lines = vec!["Error".to_string()]; + for (index, line) in problem.lines().enumerate() { + let label = if index == 0 { + " Problem " + } else { + " " + }; + lines.push(format!("{label}{line}")); + } + lines.push(" Help claw --help".to_string()); + lines.join("\n") +} + +fn run() -> Result<(), Box> { + let args: Vec = env::args().skip(1).collect(); + match parse_args(&args)? { + CliAction::DumpManifests => dump_manifests(), + CliAction::BootstrapPlan => print_bootstrap_plan(), + CliAction::Agents { args } => LiveCli::print_agents(args.as_deref())?, + CliAction::Skills { args } => LiveCli::print_skills(args.as_deref())?, + CliAction::PrintSystemPrompt { cwd, date } => print_system_prompt(cwd, date), + CliAction::Version => print_version(), + CliAction::ResumeSession { + session_path, + commands, + } => resume_session(&session_path, &commands), + CliAction::Prompt { + prompt, + model, + output_format, + allowed_tools, + permission_mode, + } => LiveCli::new(model, true, allowed_tools, permission_mode)? + .run_turn_with_output(&prompt, output_format)?, + CliAction::Login => run_login()?, + CliAction::Logout => run_logout()?, + CliAction::Init => run_init()?, + CliAction::Repl { + model, + allowed_tools, + permission_mode, + } => run_repl(model, allowed_tools, permission_mode)?, + CliAction::Help => print_help(), + } + Ok(()) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum CliAction { + DumpManifests, + BootstrapPlan, + Agents { + args: Option, + }, + Skills { + args: Option, + }, + PrintSystemPrompt { + cwd: PathBuf, + date: String, + }, + Version, + ResumeSession { + session_path: PathBuf, + commands: Vec, + }, + Prompt { + prompt: String, + model: String, + output_format: CliOutputFormat, + allowed_tools: Option, + permission_mode: PermissionMode, + }, + Login, + Logout, + Init, + Repl { + model: String, + allowed_tools: Option, + permission_mode: PermissionMode, + }, + // prompt-mode formatting is only supported for non-interactive runs + Help, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CliOutputFormat { + Text, + Json, +} + +impl CliOutputFormat { + fn parse(value: &str) -> Result { + match value { + "text" => Ok(Self::Text), + "json" => Ok(Self::Json), + other => Err(format!( + "unsupported value for --output-format: {other} (expected text or json)" + )), + } + } +} + +#[allow(clippy::too_many_lines)] +fn parse_args(args: &[String]) -> Result { + let mut model = DEFAULT_MODEL.to_string(); + let mut output_format = CliOutputFormat::Text; + let mut permission_mode = default_permission_mode(); + let mut wants_version = false; + let mut allowed_tool_values = Vec::new(); + let mut rest = Vec::new(); + let mut index = 0; + + while index < args.len() { + match args[index].as_str() { + "--version" | "-V" => { + wants_version = true; + index += 1; + } + "--model" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --model".to_string())?; + model = resolve_model_alias(value).to_string(); + index += 2; + } + flag if flag.starts_with("--model=") => { + model = resolve_model_alias(&flag[8..]).to_string(); + index += 1; + } + "--output-format" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --output-format".to_string())?; + output_format = CliOutputFormat::parse(value)?; + index += 2; + } + "--permission-mode" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --permission-mode".to_string())?; + permission_mode = parse_permission_mode_arg(value)?; + index += 2; + } + flag if flag.starts_with("--output-format=") => { + output_format = CliOutputFormat::parse(&flag[16..])?; + index += 1; + } + flag if flag.starts_with("--permission-mode=") => { + permission_mode = parse_permission_mode_arg(&flag[18..])?; + index += 1; + } + "--dangerously-skip-permissions" => { + permission_mode = PermissionMode::DangerFullAccess; + index += 1; + } + "-p" => { + // Claw Code compat: -p "prompt" = one-shot prompt + let prompt = args[index + 1..].join(" "); + if prompt.trim().is_empty() { + return Err("-p requires a prompt string".to_string()); + } + return Ok(CliAction::Prompt { + prompt, + model: resolve_model_alias(&model).to_string(), + output_format, + allowed_tools: normalize_allowed_tools(&allowed_tool_values)?, + permission_mode, + }); + } + "--print" => { + // Claw Code compat: --print makes output non-interactive + output_format = CliOutputFormat::Text; + index += 1; + } + "--allowedTools" | "--allowed-tools" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --allowedTools".to_string())?; + allowed_tool_values.push(value.clone()); + index += 2; + } + flag if flag.starts_with("--allowedTools=") => { + allowed_tool_values.push(flag[15..].to_string()); + index += 1; + } + flag if flag.starts_with("--allowed-tools=") => { + allowed_tool_values.push(flag[16..].to_string()); + index += 1; + } + other => { + rest.push(other.to_string()); + index += 1; + } + } + } + + if wants_version { + return Ok(CliAction::Version); + } + + let allowed_tools = normalize_allowed_tools(&allowed_tool_values)?; + + if rest.is_empty() { + return Ok(CliAction::Repl { + model, + allowed_tools, + permission_mode, + }); + } + if matches!(rest.first().map(String::as_str), Some("--help" | "-h")) { + return Ok(CliAction::Help); + } + if rest.first().map(String::as_str) == Some("--resume") { + return parse_resume_args(&rest[1..]); + } + + match rest[0].as_str() { + "dump-manifests" => Ok(CliAction::DumpManifests), + "bootstrap-plan" => Ok(CliAction::BootstrapPlan), + "agents" => Ok(CliAction::Agents { + args: join_optional_args(&rest[1..]), + }), + "skills" => Ok(CliAction::Skills { + args: join_optional_args(&rest[1..]), + }), + "system-prompt" => parse_system_prompt_args(&rest[1..]), + "login" => Ok(CliAction::Login), + "logout" => Ok(CliAction::Logout), + "init" => Ok(CliAction::Init), + "prompt" => { + let prompt = rest[1..].join(" "); + if prompt.trim().is_empty() { + return Err("prompt subcommand requires a prompt string".to_string()); + } + Ok(CliAction::Prompt { + prompt, + model, + output_format, + allowed_tools, + permission_mode, + }) + } + other if other.starts_with('/') => parse_direct_slash_cli_action(&rest), + _other => Ok(CliAction::Prompt { + prompt: rest.join(" "), + model, + output_format, + allowed_tools, + permission_mode, + }), + } +} + +fn join_optional_args(args: &[String]) -> Option { + let joined = args.join(" "); + let trimmed = joined.trim(); + (!trimmed.is_empty()).then(|| trimmed.to_string()) +} + +fn parse_direct_slash_cli_action(rest: &[String]) -> Result { + let raw = rest.join(" "); + match SlashCommand::parse(&raw) { + Some(SlashCommand::Help) => Ok(CliAction::Help), + Some(SlashCommand::Agents { args }) => Ok(CliAction::Agents { args }), + Some(SlashCommand::Skills { args }) => Ok(CliAction::Skills { args }), + Some(command) => Err(format_direct_slash_command_error( + match &command { + SlashCommand::Unknown(name) => format!("/{name}"), + _ => rest[0].clone(), + } + .as_str(), + matches!(command, SlashCommand::Unknown(_)), + )), + None => Err(format!("unknown subcommand: {}", rest[0])), + } +} + +fn format_direct_slash_command_error(command: &str, is_unknown: bool) -> String { + let trimmed = command.trim().trim_start_matches('/'); + let mut lines = vec![ + "Direct slash command unavailable".to_string(), + format!(" Command /{trimmed}"), + ]; + if is_unknown { + append_slash_command_suggestions(&mut lines, trimmed); + } else { + lines.push(" Try Start `claw` to use interactive slash commands".to_string()); + lines.push( + " Tip Resume-safe commands also work with `claw --resume SESSION.json ...`" + .to_string(), + ); + } + lines.join("\n") +} + +fn resolve_model_alias(model: &str) -> &str { + match model { + "opus" => "claude-opus-4-6", + "sonnet" => "claude-sonnet-4-6", + "haiku" => "claude-haiku-4-5-20251213", + _ => model, + } +} + +fn normalize_allowed_tools(values: &[String]) -> Result, String> { + current_tool_registry()?.normalize_allowed_tools(values) +} + +fn current_tool_registry() -> Result { + let cwd = env::current_dir().map_err(|error| error.to_string())?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load().map_err(|error| error.to_string())?; + let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let plugin_tools = plugin_manager + .aggregated_tools() + .map_err(|error| error.to_string())?; + GlobalToolRegistry::with_plugin_tools(plugin_tools) +} + +fn parse_permission_mode_arg(value: &str) -> Result { + normalize_permission_mode(value) + .ok_or_else(|| { + format!( + "unsupported permission mode '{value}'. Use read-only, workspace-write, or danger-full-access." + ) + }) + .map(permission_mode_from_label) +} + +fn permission_mode_from_label(mode: &str) -> PermissionMode { + match mode { + "read-only" => PermissionMode::ReadOnly, + "workspace-write" => PermissionMode::WorkspaceWrite, + "danger-full-access" => PermissionMode::DangerFullAccess, + other => panic!("unsupported permission mode label: {other}"), + } +} + +fn default_permission_mode() -> PermissionMode { + env::var("CLAW_PERMISSION_MODE") + .ok() + .as_deref() + .and_then(normalize_permission_mode) + .map_or(PermissionMode::DangerFullAccess, permission_mode_from_label) +} + +fn filter_tool_specs( + tool_registry: &GlobalToolRegistry, + allowed_tools: Option<&AllowedToolSet>, +) -> Vec { + tool_registry.definitions(allowed_tools) +} + +fn parse_system_prompt_args(args: &[String]) -> Result { + let mut cwd = env::current_dir().map_err(|error| error.to_string())?; + let mut date = DEFAULT_DATE.to_string(); + let mut index = 0; + + while index < args.len() { + match args[index].as_str() { + "--cwd" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --cwd".to_string())?; + cwd = PathBuf::from(value); + index += 2; + } + "--date" => { + let value = args + .get(index + 1) + .ok_or_else(|| "missing value for --date".to_string())?; + date.clone_from(value); + index += 2; + } + other => return Err(format!("unknown system-prompt option: {other}")), + } + } + + Ok(CliAction::PrintSystemPrompt { cwd, date }) +} + +fn parse_resume_args(args: &[String]) -> Result { + let session_path = args + .first() + .ok_or_else(|| "missing session path for --resume".to_string()) + .map(PathBuf::from)?; + let commands = args[1..].to_vec(); + if commands + .iter() + .any(|command| !command.trim_start().starts_with('/')) + { + return Err("--resume trailing arguments must be slash commands".to_string()); + } + Ok(CliAction::ResumeSession { + session_path, + commands, + }) +} + +fn dump_manifests() { + let workspace_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../.."); + let paths = UpstreamPaths::from_workspace_dir(&workspace_dir); + match extract_manifest(&paths) { + Ok(manifest) => { + println!("commands: {}", manifest.commands.entries().len()); + println!("tools: {}", manifest.tools.entries().len()); + println!("bootstrap phases: {}", manifest.bootstrap.phases().len()); + } + Err(error) => { + eprintln!("failed to extract manifests: {error}"); + std::process::exit(1); + } + } +} + +fn print_bootstrap_plan() { + for phase in runtime::BootstrapPlan::claw_default().phases() { + println!("- {phase:?}"); + } +} + +fn default_oauth_config() -> OAuthConfig { + OAuthConfig { + client_id: String::from("9d1c250a-e61b-44d9-88ed-5944d1962f5e"), + authorize_url: String::from("https://platform.claw.dev/oauth/authorize"), + token_url: String::from("https://platform.claw.dev/v1/oauth/token"), + callback_port: None, + manual_redirect_url: None, + scopes: vec![ + String::from("user:profile"), + String::from("user:inference"), + String::from("user:sessions:claw_code"), + ], + } +} + +fn run_login() -> Result<(), Box> { + let cwd = env::current_dir()?; + let config = ConfigLoader::default_for(&cwd).load()?; + let default_oauth = default_oauth_config(); + let oauth = config.oauth().unwrap_or(&default_oauth); + let callback_port = oauth.callback_port.unwrap_or(DEFAULT_OAUTH_CALLBACK_PORT); + let redirect_uri = runtime::loopback_redirect_uri(callback_port); + let pkce = generate_pkce_pair()?; + let state = generate_state()?; + let authorize_url = + OAuthAuthorizationRequest::from_config(oauth, redirect_uri.clone(), state.clone(), &pkce) + .build_url(); + + println!("Starting Claw OAuth login..."); + println!("Listening for callback on {redirect_uri}"); + if let Err(error) = open_browser(&authorize_url) { + eprintln!("warning: failed to open browser automatically: {error}"); + println!("Open this URL manually:\n{authorize_url}"); + } + + let callback = wait_for_oauth_callback(callback_port)?; + if let Some(error) = callback.error { + let description = callback + .error_description + .unwrap_or_else(|| "authorization failed".to_string()); + return Err(io::Error::other(format!("{error}: {description}")).into()); + } + let code = callback.code.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "callback did not include code") + })?; + let returned_state = callback.state.ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "callback did not include state") + })?; + if returned_state != state { + return Err(io::Error::new(io::ErrorKind::InvalidData, "oauth state mismatch").into()); + } + + let client = ClawApiClient::from_auth(AuthSource::None).with_base_url(api::read_base_url()); + let exchange_request = + OAuthTokenExchangeRequest::from_config(oauth, code, state, pkce.verifier, redirect_uri); + let runtime = tokio::runtime::Runtime::new()?; + let token_set = runtime.block_on(client.exchange_oauth_code(oauth, &exchange_request))?; + save_oauth_credentials(&runtime::OAuthTokenSet { + access_token: token_set.access_token, + refresh_token: token_set.refresh_token, + expires_at: token_set.expires_at, + scopes: token_set.scopes, + })?; + println!("Claw OAuth login complete."); + Ok(()) +} + +fn run_logout() -> Result<(), Box> { + clear_oauth_credentials()?; + println!("Claw OAuth credentials cleared."); + Ok(()) +} + +fn open_browser(url: &str) -> io::Result<()> { + let commands = if cfg!(target_os = "macos") { + vec![("open", vec![url])] + } else if cfg!(target_os = "windows") { + vec![("cmd", vec!["/C", "start", "", url])] + } else { + vec![("xdg-open", vec![url])] + }; + for (program, args) in commands { + match Command::new(program).args(args).spawn() { + Ok(_) => return Ok(()), + Err(error) if error.kind() == io::ErrorKind::NotFound => {} + Err(error) => return Err(error), + } + } + Err(io::Error::new( + io::ErrorKind::NotFound, + "no supported browser opener command found", + )) +} + +fn wait_for_oauth_callback( + port: u16, +) -> Result> { + let listener = TcpListener::bind(("127.0.0.1", port))?; + let (mut stream, _) = listener.accept()?; + let mut buffer = [0_u8; 4096]; + let bytes_read = stream.read(&mut buffer)?; + let request = String::from_utf8_lossy(&buffer[..bytes_read]); + let request_line = request.lines().next().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "missing callback request line") + })?; + let target = request_line.split_whitespace().nth(1).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "missing callback request target", + ) + })?; + let callback = parse_oauth_callback_request_target(target) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?; + let body = if callback.error.is_some() { + "Claw OAuth login failed. You can close this window." + } else { + "Claw OAuth login succeeded. You can close this window." + }; + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: text/plain; charset=utf-8\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + Ok(callback) +} + +fn print_system_prompt(cwd: PathBuf, date: String) { + match load_system_prompt(cwd, date, env::consts::OS, "unknown") { + Ok(sections) => println!("{}", sections.join("\n\n")), + Err(error) => { + eprintln!("failed to build system prompt: {error}"); + std::process::exit(1); + } + } +} + +fn print_version() { + println!("{}", render_version_report()); +} + +fn resume_session(session_path: &Path, commands: &[String]) { + let session = match Session::load_from_path(session_path) { + Ok(session) => session, + Err(error) => { + eprintln!("failed to restore session: {error}"); + std::process::exit(1); + } + }; + + if commands.is_empty() { + println!( + "Restored session from {} ({} messages).", + session_path.display(), + session.messages.len() + ); + return; + } + + let mut session = session; + for raw_command in commands { + let Some(command) = SlashCommand::parse(raw_command) else { + eprintln!("unsupported resumed command: {raw_command}"); + std::process::exit(2); + }; + match run_resume_command(session_path, &session, &command) { + Ok(ResumeCommandOutcome { + session: next_session, + message, + }) => { + session = next_session; + if let Some(message) = message { + println!("{message}"); + } + } + Err(error) => { + eprintln!("{error}"); + std::process::exit(2); + } + } + } +} + +#[derive(Debug, Clone)] +struct ResumeCommandOutcome { + session: Session, + message: Option, +} + +#[derive(Debug, Clone)] +struct StatusContext { + cwd: PathBuf, + session_path: Option, + loaded_config_files: usize, + discovered_config_files: usize, + memory_file_count: usize, + project_root: Option, + git_branch: Option, +} + +#[derive(Debug, Clone, Copy)] +struct StatusUsage { + message_count: usize, + turns: u32, + latest: TokenUsage, + cumulative: TokenUsage, + estimated_tokens: usize, +} + +fn format_model_report(model: &str, message_count: usize, turns: u32) -> String { + format!( + "Model + Current {model} + Session {message_count} messages · {turns} turns + +Aliases + opus claude-opus-4-6 + sonnet claude-sonnet-4-6 + haiku claude-haiku-4-5-20251213 + +Next + /model Show the current model + /model Switch models for this REPL session" + ) +} + +fn format_model_switch_report(previous: &str, next: &str, message_count: usize) -> String { + format!( + "Model updated + Previous {previous} + Current {next} + Preserved {message_count} messages + Tip Existing conversation context stayed attached" + ) +} + +fn format_permissions_report(mode: &str) -> String { + let modes = [ + ("read-only", "Read/search tools only", mode == "read-only"), + ( + "workspace-write", + "Edit files inside the workspace", + mode == "workspace-write", + ), + ( + "danger-full-access", + "Unrestricted tool access", + mode == "danger-full-access", + ), + ] + .into_iter() + .map(|(name, description, is_current)| { + let marker = if is_current { + "● current" + } else { + "○ available" + }; + format!(" {name:<18} {marker:<11} {description}") + }) + .collect::>() + .join( + " +", + ); + + let effect = match mode { + "read-only" => "Only read/search tools can run automatically", + "workspace-write" => "Editing tools can modify files in the workspace", + "danger-full-access" => "All tools can run without additional sandbox limits", + _ => "Unknown permission mode", + }; + + format!( + "Permissions + Active mode {mode} + Effect {effect} + +Modes +{modes} + +Next + /permissions Show the current mode + /permissions Switch modes for subsequent tool calls" + ) +} + +fn format_permissions_switch_report(previous: &str, next: &str) -> String { + format!( + "Permissions updated + Previous mode {previous} + Active mode {next} + Applies to Subsequent tool calls in this REPL + Tip Run /permissions to review all available modes" + ) +} + +fn format_cost_report(usage: TokenUsage) -> String { + format!( + "Cost + Input tokens {} + Output tokens {} + Cache create {} + Cache read {} + Total tokens {} + +Next + /status See session + workspace context + /compact Trim local history if the session is getting large", + usage.input_tokens, + usage.output_tokens, + usage.cache_creation_input_tokens, + usage.cache_read_input_tokens, + usage.total_tokens(), + ) +} + +fn format_resume_report(session_path: &str, message_count: usize, turns: u32) -> String { + format!( + "Session resumed + Session file {session_path} + History {message_count} messages · {turns} turns + Next /status · /diff · /export" + ) +} + +fn format_compact_report(removed: usize, resulting_messages: usize, skipped: bool) -> String { + if skipped { + format!( + "Compact + Result skipped + Reason Session is already below the compaction threshold + Messages kept {resulting_messages}" + ) + } else { + format!( + "Compact + Result compacted + Messages removed {removed} + Messages kept {resulting_messages} + Tip Use /status to review the trimmed session" + ) + } +} + +fn parse_git_status_metadata(status: Option<&str>) -> (Option, Option) { + let Some(status) = status else { + return (None, None); + }; + let branch = status.lines().next().and_then(|line| { + line.strip_prefix("## ") + .map(|line| { + line.split(['.', ' ']) + .next() + .unwrap_or_default() + .to_string() + }) + .filter(|value| !value.is_empty()) + }); + let project_root = find_git_root().ok(); + (project_root, branch) +} + +fn find_git_root() -> Result> { + let output = std::process::Command::new("git") + .args(["rev-parse", "--show-toplevel"]) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + return Err("not a git repository".into()); + } + let path = String::from_utf8(output.stdout)?.trim().to_string(); + if path.is_empty() { + return Err("empty git root".into()); + } + Ok(PathBuf::from(path)) +} + +#[allow(clippy::too_many_lines)] +fn run_resume_command( + session_path: &Path, + session: &Session, + command: &SlashCommand, +) -> Result> { + match command { + SlashCommand::Help => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_repl_help()), + }), + SlashCommand::Compact => { + let result = runtime::compact_session( + session, + CompactionConfig { + max_estimated_tokens: 0, + ..CompactionConfig::default() + }, + ); + let removed = result.removed_message_count; + let kept = result.compacted_session.messages.len(); + let skipped = removed == 0; + result.compacted_session.save_to_path(session_path)?; + Ok(ResumeCommandOutcome { + session: result.compacted_session, + message: Some(format_compact_report(removed, kept, skipped)), + }) + } + SlashCommand::Clear { confirm } => { + if !confirm { + return Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some( + "clear: confirmation required; rerun with /clear --confirm".to_string(), + ), + }); + } + let cleared = Session::new(); + cleared.save_to_path(session_path)?; + Ok(ResumeCommandOutcome { + session: cleared, + message: Some(format!( + "Cleared resumed session file {}.", + session_path.display() + )), + }) + } + SlashCommand::Status => { + let tracker = UsageTracker::from_session(session); + let usage = tracker.cumulative_usage(); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format_status_report( + "restored-session", + StatusUsage { + message_count: session.messages.len(), + turns: tracker.turns(), + latest: tracker.current_turn_usage(), + cumulative: usage, + estimated_tokens: 0, + }, + default_permission_mode().as_str(), + &status_context(Some(session_path))?, + )), + }) + } + SlashCommand::Cost => { + let usage = UsageTracker::from_session(session).cumulative_usage(); + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format_cost_report(usage)), + }) + } + SlashCommand::Config { section } => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_config_report(section.as_deref())?), + }), + SlashCommand::Memory => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_memory_report()?), + }), + SlashCommand::Init => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(init_claw_md()?), + }), + SlashCommand::Diff => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_diff_report()?), + }), + SlashCommand::Version => Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(render_version_report()), + }), + SlashCommand::Export { path } => { + let export_path = resolve_export_path(path.as_deref(), session)?; + fs::write(&export_path, render_export_text(session))?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(format!( + "Export\n Result wrote transcript\n File {}\n Messages {}", + export_path.display(), + session.messages.len(), + )), + }) + } + SlashCommand::Agents { args } => { + let cwd = env::current_dir()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(handle_agents_slash_command(args.as_deref(), &cwd)?), + }) + } + SlashCommand::Skills { args } => { + let cwd = env::current_dir()?; + Ok(ResumeCommandOutcome { + session: session.clone(), + message: Some(handle_skills_slash_command(args.as_deref(), &cwd)?), + }) + } + SlashCommand::Bughunter { .. } + | SlashCommand::Branch { .. } + | SlashCommand::Worktree { .. } + | SlashCommand::CommitPushPr { .. } + | SlashCommand::Commit + | SlashCommand::Pr { .. } + | SlashCommand::Issue { .. } + | SlashCommand::Ultraplan { .. } + | SlashCommand::Teleport { .. } + | SlashCommand::DebugToolCall + | SlashCommand::Resume { .. } + | SlashCommand::Model { .. } + | SlashCommand::Permissions { .. } + | SlashCommand::Session { .. } + | SlashCommand::Plugins { .. } + | SlashCommand::Unknown(_) => Err("unsupported resumed slash command".into()), + } +} + +fn run_repl( + model: String, + allowed_tools: Option, + permission_mode: PermissionMode, +) -> Result<(), Box> { + let mut cli = LiveCli::new(model, true, allowed_tools, permission_mode)?; + let mut editor = input::LineEditor::new("> ", slash_command_completion_candidates()); + println!("{}", cli.startup_banner()); + + loop { + match editor.read_line()? { + input::ReadOutcome::Submit(input) => { + let trimmed = input.trim(); + if trimmed.is_empty() { + continue; + } + if matches!(trimmed, "/exit" | "/quit") { + cli.persist_session()?; + break; + } + if let Some(command) = SlashCommand::parse(trimmed) { + if cli.handle_repl_command(command)? { + cli.persist_session()?; + } + continue; + } + editor.push_history(&input); + cli.run_turn(&input)?; + } + input::ReadOutcome::Cancel => {} + input::ReadOutcome::Exit => { + cli.persist_session()?; + break; + } + } + } + + Ok(()) +} + +#[derive(Debug, Clone)] +struct SessionHandle { + id: String, + path: PathBuf, +} + +#[derive(Debug, Clone)] +struct ManagedSessionSummary { + id: String, + path: PathBuf, + modified_epoch_secs: u64, + message_count: usize, +} + +struct LiveCli { + model: String, + allowed_tools: Option, + permission_mode: PermissionMode, + system_prompt: Vec, + runtime: ConversationRuntime, + session: SessionHandle, +} + +impl LiveCli { + fn new( + model: String, + enable_tools: bool, + allowed_tools: Option, + permission_mode: PermissionMode, + ) -> Result> { + let system_prompt = build_system_prompt()?; + let session = create_managed_session_handle()?; + let runtime = build_runtime( + Session::new(), + model.clone(), + system_prompt.clone(), + enable_tools, + true, + allowed_tools.clone(), + permission_mode, + None, + )?; + let cli = Self { + model, + allowed_tools, + permission_mode, + system_prompt, + runtime, + session, + }; + cli.persist_session()?; + Ok(cli) + } + + fn startup_banner(&self) -> String { + let color = io::stdout().is_terminal(); + let cwd = env::current_dir().ok(); + let cwd_display = cwd.as_ref().map_or_else( + || "".to_string(), + |path| path.display().to_string(), + ); + let workspace_name = cwd + .as_ref() + .and_then(|path| path.file_name()) + .and_then(|name| name.to_str()) + .unwrap_or("workspace"); + let git_branch = status_context(Some(&self.session.path)) + .ok() + .and_then(|context| context.git_branch); + let workspace_summary = git_branch.as_deref().map_or_else( + || workspace_name.to_string(), + |branch| format!("{workspace_name} · {branch}"), + ); + let has_claw_md = cwd + .as_ref() + .is_some_and(|path| path.join("CLAW.md").is_file()); + let mut lines = vec![ + format!( + "{} {}", + if color { + "\x1b[1;38;5;45m🦞 Claw Code\x1b[0m" + } else { + "Claw Code" + }, + if color { + "\x1b[2m· ready\x1b[0m" + } else { + "· ready" + } + ), + format!(" Workspace {workspace_summary}"), + format!(" Directory {cwd_display}"), + format!(" Model {}", self.model), + format!(" Permissions {}", self.permission_mode.as_str()), + format!(" Session {}", self.session.id), + format!( + " Quick start {}", + if has_claw_md { + "/help · /status · ask for a task" + } else { + "/init · /help · /status" + } + ), + " Editor Tab completes slash commands · /vim toggles modal editing" + .to_string(), + " Multiline Shift+Enter or Ctrl+J inserts a newline".to_string(), + ]; + if !has_claw_md { + lines.push( + " First run /init scaffolds CLAW.md, .claw.json, and local session files" + .to_string(), + ); + } + lines.join("\n") + } + + fn run_turn(&mut self, input: &str) -> Result<(), Box> { + let mut spinner = Spinner::new(); + let mut stdout = io::stdout(); + spinner.tick( + "🦀 Thinking...", + TerminalRenderer::new().color_theme(), + &mut stdout, + )?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let result = self.runtime.run_turn(input, Some(&mut permission_prompter)); + match result { + Ok(_) => { + spinner.finish( + "✨ Done", + TerminalRenderer::new().color_theme(), + &mut stdout, + )?; + println!(); + self.persist_session()?; + Ok(()) + } + Err(error) => { + spinner.fail( + "❌ Request failed", + TerminalRenderer::new().color_theme(), + &mut stdout, + )?; + Err(Box::new(error)) + } + } + } + + fn run_turn_with_output( + &mut self, + input: &str, + output_format: CliOutputFormat, + ) -> Result<(), Box> { + match output_format { + CliOutputFormat::Text => self.run_turn(input), + CliOutputFormat::Json => self.run_prompt_json(input), + } + } + + fn run_prompt_json(&mut self, input: &str) -> Result<(), Box> { + let session = self.runtime.session().clone(); + let mut runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + false, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let summary = runtime.run_turn(input, Some(&mut permission_prompter))?; + self.runtime = runtime; + self.persist_session()?; + println!( + "{}", + json!({ + "message": final_assistant_text(&summary), + "model": self.model, + "iterations": summary.iterations, + "tool_uses": collect_tool_uses(&summary), + "tool_results": collect_tool_results(&summary), + "usage": { + "input_tokens": summary.usage.input_tokens, + "output_tokens": summary.usage.output_tokens, + "cache_creation_input_tokens": summary.usage.cache_creation_input_tokens, + "cache_read_input_tokens": summary.usage.cache_read_input_tokens, + } + }) + ); + Ok(()) + } + + fn handle_repl_command( + &mut self, + command: SlashCommand, + ) -> Result> { + Ok(match command { + SlashCommand::Help => { + println!("{}", render_repl_help()); + false + } + SlashCommand::Status => { + self.print_status(); + false + } + SlashCommand::Bughunter { scope } => { + self.run_bughunter(scope.as_deref())?; + false + } + SlashCommand::Commit => { + self.run_commit()?; + true + } + SlashCommand::Pr { context } => { + self.run_pr(context.as_deref())?; + false + } + SlashCommand::Issue { context } => { + self.run_issue(context.as_deref())?; + false + } + SlashCommand::Ultraplan { task } => { + self.run_ultraplan(task.as_deref())?; + false + } + SlashCommand::Teleport { target } => { + self.run_teleport(target.as_deref())?; + false + } + SlashCommand::DebugToolCall => { + self.run_debug_tool_call()?; + false + } + SlashCommand::Compact => { + self.compact()?; + false + } + SlashCommand::Model { model } => self.set_model(model)?, + SlashCommand::Permissions { mode } => self.set_permissions(mode)?, + SlashCommand::Clear { confirm } => self.clear_session(confirm)?, + SlashCommand::Cost => { + self.print_cost(); + false + } + SlashCommand::Resume { session_path } => self.resume_session(session_path)?, + SlashCommand::Config { section } => { + Self::print_config(section.as_deref())?; + false + } + SlashCommand::Memory => { + Self::print_memory()?; + false + } + SlashCommand::Init => { + run_init()?; + false + } + SlashCommand::Diff => { + Self::print_diff()?; + false + } + SlashCommand::Version => { + Self::print_version(); + false + } + SlashCommand::Export { path } => { + self.export_session(path.as_deref())?; + false + } + SlashCommand::Session { action, target } => { + self.handle_session_command(action.as_deref(), target.as_deref())? + } + SlashCommand::Plugins { action, target } => { + self.handle_plugins_command(action.as_deref(), target.as_deref())? + } + SlashCommand::Agents { args } => { + Self::print_agents(args.as_deref())?; + false + } + SlashCommand::Skills { args } => { + Self::print_skills(args.as_deref())?; + false + } + SlashCommand::Branch { .. } => { + eprintln!( + "{}", + render_mode_unavailable("branch", "git branch commands") + ); + false + } + SlashCommand::Worktree { .. } => { + eprintln!( + "{}", + render_mode_unavailable("worktree", "git worktree commands") + ); + false + } + SlashCommand::CommitPushPr { .. } => { + eprintln!( + "{}", + render_mode_unavailable("commit-push-pr", "commit + push + PR automation") + ); + false + } + SlashCommand::Unknown(name) => { + eprintln!("{}", render_unknown_repl_command(&name)); + false + } + }) + } + + fn persist_session(&self) -> Result<(), Box> { + self.runtime.session().save_to_path(&self.session.path)?; + Ok(()) + } + + fn print_status(&self) { + let cumulative = self.runtime.usage().cumulative_usage(); + let latest = self.runtime.usage().current_turn_usage(); + println!( + "{}", + format_status_report( + &self.model, + StatusUsage { + message_count: self.runtime.session().messages.len(), + turns: self.runtime.usage().turns(), + latest, + cumulative, + estimated_tokens: self.runtime.estimated_tokens(), + }, + self.permission_mode.as_str(), + &status_context(Some(&self.session.path)).expect("status context should load"), + ) + ); + } + + fn set_model(&mut self, model: Option) -> Result> { + let Some(model) = model else { + println!( + "{}", + format_model_report( + &self.model, + self.runtime.session().messages.len(), + self.runtime.usage().turns(), + ) + ); + return Ok(false); + }; + + let model = resolve_model_alias(&model).to_string(); + + if model == self.model { + println!( + "{}", + format_model_report( + &self.model, + self.runtime.session().messages.len(), + self.runtime.usage().turns(), + ) + ); + return Ok(false); + } + + let previous = self.model.clone(); + let session = self.runtime.session().clone(); + let message_count = session.messages.len(); + self.runtime = build_runtime( + session, + model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.model.clone_from(&model); + println!( + "{}", + format_model_switch_report(&previous, &model, message_count) + ); + Ok(true) + } + + fn set_permissions( + &mut self, + mode: Option, + ) -> Result> { + let Some(mode) = mode else { + println!( + "{}", + format_permissions_report(self.permission_mode.as_str()) + ); + return Ok(false); + }; + + let normalized = normalize_permission_mode(&mode).ok_or_else(|| { + format!( + "unsupported permission mode '{mode}'. Use read-only, workspace-write, or danger-full-access." + ) + })?; + + if normalized == self.permission_mode.as_str() { + println!("{}", format_permissions_report(normalized)); + return Ok(false); + } + + let previous = self.permission_mode.as_str().to_string(); + let session = self.runtime.session().clone(); + self.permission_mode = permission_mode_from_label(normalized); + self.runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + println!( + "{}", + format_permissions_switch_report(&previous, normalized) + ); + Ok(true) + } + + fn clear_session(&mut self, confirm: bool) -> Result> { + if !confirm { + println!( + "clear: confirmation required; run /clear --confirm to start a fresh session." + ); + return Ok(false); + } + + self.session = create_managed_session_handle()?; + self.runtime = build_runtime( + Session::new(), + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + println!( + "Session cleared\n Mode fresh session\n Preserved model {}\n Permission mode {}\n Session {}", + self.model, + self.permission_mode.as_str(), + self.session.id, + ); + Ok(true) + } + + fn print_cost(&self) { + let cumulative = self.runtime.usage().cumulative_usage(); + println!("{}", format_cost_report(cumulative)); + } + + fn resume_session( + &mut self, + session_path: Option, + ) -> Result> { + let Some(session_ref) = session_path else { + println!("Usage: /resume "); + return Ok(false); + }; + + let handle = resolve_session_reference(&session_ref)?; + let session = Session::load_from_path(&handle.path)?; + let message_count = session.messages.len(); + self.runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.session = handle; + println!( + "{}", + format_resume_report( + &self.session.path.display().to_string(), + message_count, + self.runtime.usage().turns(), + ) + ); + Ok(true) + } + + fn print_config(section: Option<&str>) -> Result<(), Box> { + println!("{}", render_config_report(section)?); + Ok(()) + } + + fn print_memory() -> Result<(), Box> { + println!("{}", render_memory_report()?); + Ok(()) + } + + fn print_agents(args: Option<&str>) -> Result<(), Box> { + let cwd = env::current_dir()?; + println!("{}", handle_agents_slash_command(args, &cwd)?); + Ok(()) + } + + fn print_skills(args: Option<&str>) -> Result<(), Box> { + let cwd = env::current_dir()?; + println!("{}", handle_skills_slash_command(args, &cwd)?); + Ok(()) + } + + fn print_diff() -> Result<(), Box> { + println!("{}", render_diff_report()?); + Ok(()) + } + + fn print_version() { + println!("{}", render_version_report()); + } + + fn export_session( + &self, + requested_path: Option<&str>, + ) -> Result<(), Box> { + let export_path = resolve_export_path(requested_path, self.runtime.session())?; + fs::write(&export_path, render_export_text(self.runtime.session()))?; + println!( + "Export\n Result wrote transcript\n File {}\n Messages {}", + export_path.display(), + self.runtime.session().messages.len(), + ); + Ok(()) + } + + fn handle_session_command( + &mut self, + action: Option<&str>, + target: Option<&str>, + ) -> Result> { + match action { + None | Some("list") => { + println!("{}", render_session_list(&self.session.id)?); + Ok(false) + } + Some("switch") => { + let Some(target) = target else { + println!("Usage: /session switch "); + return Ok(false); + }; + let handle = resolve_session_reference(target)?; + let session = Session::load_from_path(&handle.path)?; + let message_count = session.messages.len(); + self.runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.session = handle; + println!( + "Session switched\n Active session {}\n File {}\n Messages {}", + self.session.id, + self.session.path.display(), + message_count, + ); + Ok(true) + } + Some(other) => { + println!("Unknown /session action '{other}'. Use /session list or /session switch ."); + Ok(false) + } + } + } + + fn handle_plugins_command( + &mut self, + action: Option<&str>, + target: Option<&str>, + ) -> Result> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load()?; + let mut manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let result = handle_plugins_slash_command(action, target, &mut manager)?; + println!("{}", result.message); + if result.reload_runtime { + self.reload_runtime_features()?; + } + Ok(false) + } + + fn reload_runtime_features(&mut self) -> Result<(), Box> { + self.runtime = build_runtime( + self.runtime.session().clone(), + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.persist_session() + } + + fn compact(&mut self) -> Result<(), Box> { + let result = self.runtime.compact(CompactionConfig::default()); + let removed = result.removed_message_count; + let kept = result.compacted_session.messages.len(); + let skipped = removed == 0; + self.runtime = build_runtime( + result.compacted_session, + self.model.clone(), + self.system_prompt.clone(), + true, + true, + self.allowed_tools.clone(), + self.permission_mode, + None, + )?; + self.persist_session()?; + println!("{}", format_compact_report(removed, kept, skipped)); + Ok(()) + } + + fn run_internal_prompt_text_with_progress( + &self, + prompt: &str, + enable_tools: bool, + progress: Option, + ) -> Result> { + let session = self.runtime.session().clone(); + let mut runtime = build_runtime( + session, + self.model.clone(), + self.system_prompt.clone(), + enable_tools, + false, + self.allowed_tools.clone(), + self.permission_mode, + progress, + )?; + let mut permission_prompter = CliPermissionPrompter::new(self.permission_mode); + let summary = runtime.run_turn(prompt, Some(&mut permission_prompter))?; + Ok(final_assistant_text(&summary).trim().to_string()) + } + + fn run_internal_prompt_text( + &self, + prompt: &str, + enable_tools: bool, + ) -> Result> { + self.run_internal_prompt_text_with_progress(prompt, enable_tools, None) + } + + fn run_bughunter(&self, scope: Option<&str>) -> Result<(), Box> { + let scope = scope.unwrap_or("the current repository"); + let prompt = format!( + "You are /bughunter. Inspect {scope} and identify the most likely bugs or correctness issues. Prioritize concrete findings with file paths, severity, and suggested fixes. Use tools if needed." + ); + println!("{}", self.run_internal_prompt_text(&prompt, true)?); + Ok(()) + } + + fn run_ultraplan(&self, task: Option<&str>) -> Result<(), Box> { + let task = task.unwrap_or("the current repo work"); + let prompt = format!( + "You are /ultraplan. Produce a deep multi-step execution plan for {task}. Include goals, risks, implementation sequence, verification steps, and rollback considerations. Use tools if needed." + ); + let mut progress = InternalPromptProgressRun::start_ultraplan(task); + match self.run_internal_prompt_text_with_progress(&prompt, true, Some(progress.reporter())) + { + Ok(plan) => { + progress.finish_success(); + println!("{plan}"); + Ok(()) + } + Err(error) => { + progress.finish_failure(&error.to_string()); + Err(error) + } + } + } + + #[allow(clippy::unused_self)] + fn run_teleport(&self, target: Option<&str>) -> Result<(), Box> { + let Some(target) = target.map(str::trim).filter(|value| !value.is_empty()) else { + println!("Usage: /teleport "); + return Ok(()); + }; + + println!("{}", render_teleport_report(target)?); + Ok(()) + } + + fn run_debug_tool_call(&self) -> Result<(), Box> { + println!("{}", render_last_tool_debug_report(self.runtime.session())?); + Ok(()) + } + + fn run_commit(&mut self) -> Result<(), Box> { + let status = git_output(&["status", "--short"])?; + if status.trim().is_empty() { + println!("Commit\n Result skipped\n Reason no workspace changes"); + return Ok(()); + } + + git_status_ok(&["add", "-A"])?; + let staged_stat = git_output(&["diff", "--cached", "--stat"])?; + let prompt = format!( + "Generate a git commit message in plain text Lore format only. Base it on this staged diff summary:\n\n{}\n\nRecent conversation context:\n{}", + truncate_for_prompt(&staged_stat, 8_000), + recent_user_context(self.runtime.session(), 6) + ); + let message = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); + if message.trim().is_empty() { + return Err("generated commit message was empty".into()); + } + + let path = write_temp_text_file("claw-commit-message.txt", &message)?; + let output = Command::new("git") + .args(["commit", "--file"]) + .arg(&path) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git commit failed: {stderr}").into()); + } + + println!( + "Commit\n Result created\n Message file {}\n\n{}", + path.display(), + message.trim() + ); + Ok(()) + } + + fn run_pr(&self, context: Option<&str>) -> Result<(), Box> { + let staged = git_output(&["diff", "--stat"])?; + let prompt = format!( + "Generate a pull request title and body from this conversation and diff summary. Output plain text in this format exactly:\nTITLE: \nBODY:\n<body markdown>\n\nContext hint: {}\n\nDiff summary:\n{}", + context.unwrap_or("none"), + truncate_for_prompt(&staged, 10_000) + ); + let draft = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); + let (title, body) = parse_titled_body(&draft) + .ok_or_else(|| "failed to parse generated PR title/body".to_string())?; + + if command_exists("gh") { + let body_path = write_temp_text_file("claw-pr-body.md", &body)?; + let output = Command::new("gh") + .args(["pr", "create", "--title", &title, "--body-file"]) + .arg(&body_path) + .current_dir(env::current_dir()?) + .output()?; + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + println!( + "PR\n Result created\n Title {title}\n URL {}", + if stdout.is_empty() { "<unknown>" } else { &stdout } + ); + return Ok(()); + } + } + + println!("PR draft\n Title {title}\n\n{body}"); + Ok(()) + } + + fn run_issue(&self, context: Option<&str>) -> Result<(), Box<dyn std::error::Error>> { + let prompt = format!( + "Generate a GitHub issue title and body from this conversation. Output plain text in this format exactly:\nTITLE: <title>\nBODY:\n<body markdown>\n\nContext hint: {}\n\nConversation context:\n{}", + context.unwrap_or("none"), + truncate_for_prompt(&recent_user_context(self.runtime.session(), 10), 10_000) + ); + let draft = sanitize_generated_message(&self.run_internal_prompt_text(&prompt, false)?); + let (title, body) = parse_titled_body(&draft) + .ok_or_else(|| "failed to parse generated issue title/body".to_string())?; + + if command_exists("gh") { + let body_path = write_temp_text_file("claw-issue-body.md", &body)?; + let output = Command::new("gh") + .args(["issue", "create", "--title", &title, "--body-file"]) + .arg(&body_path) + .current_dir(env::current_dir()?) + .output()?; + if output.status.success() { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + println!( + "Issue\n Result created\n Title {title}\n URL {}", + if stdout.is_empty() { "<unknown>" } else { &stdout } + ); + return Ok(()); + } + } + + println!("Issue draft\n Title {title}\n\n{body}"); + Ok(()) + } +} + +fn sessions_dir() -> Result<PathBuf, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let path = cwd.join(".claw").join("sessions"); + fs::create_dir_all(&path)?; + Ok(path) +} + +fn create_managed_session_handle() -> Result<SessionHandle, Box<dyn std::error::Error>> { + let id = generate_session_id(); + let path = sessions_dir()?.join(format!("{id}.json")); + Ok(SessionHandle { id, path }) +} + +fn generate_session_id() -> String { + let millis = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_millis()) + .unwrap_or_default(); + format!("session-{millis}") +} + +fn resolve_session_reference(reference: &str) -> Result<SessionHandle, Box<dyn std::error::Error>> { + let direct = PathBuf::from(reference); + let path = if direct.exists() { + direct + } else { + sessions_dir()?.join(format!("{reference}.json")) + }; + if !path.exists() { + return Err(format!("session not found: {reference}").into()); + } + let id = path + .file_stem() + .and_then(|value| value.to_str()) + .unwrap_or(reference) + .to_string(); + Ok(SessionHandle { id, path }) +} + +fn list_managed_sessions() -> Result<Vec<ManagedSessionSummary>, Box<dyn std::error::Error>> { + let mut sessions = Vec::new(); + for entry in fs::read_dir(sessions_dir()?)? { + let entry = entry?; + let path = entry.path(); + if path.extension().and_then(|ext| ext.to_str()) != Some("json") { + continue; + } + let metadata = entry.metadata()?; + let modified_epoch_secs = metadata + .modified() + .ok() + .and_then(|time| time.duration_since(UNIX_EPOCH).ok()) + .map(|duration| duration.as_secs()) + .unwrap_or_default(); + let message_count = Session::load_from_path(&path) + .map(|session| session.messages.len()) + .unwrap_or_default(); + let id = path + .file_stem() + .and_then(|value| value.to_str()) + .unwrap_or("unknown") + .to_string(); + sessions.push(ManagedSessionSummary { + id, + path, + modified_epoch_secs, + message_count, + }); + } + sessions.sort_by(|left, right| right.modified_epoch_secs.cmp(&left.modified_epoch_secs)); + Ok(sessions) +} + +fn format_relative_timestamp(epoch_secs: u64) -> String { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_secs()) + .unwrap_or(epoch_secs); + let elapsed = now.saturating_sub(epoch_secs); + match elapsed { + 0..=59 => format!("{elapsed}s ago"), + 60..=3_599 => format!("{}m ago", elapsed / 60), + 3_600..=86_399 => format!("{}h ago", elapsed / 3_600), + _ => format!("{}d ago", elapsed / 86_400), + } +} + +fn render_session_list(active_session_id: &str) -> Result<String, Box<dyn std::error::Error>> { + let sessions = list_managed_sessions()?; + let mut lines = vec![ + "Sessions".to_string(), + format!(" Directory {}", sessions_dir()?.display()), + ]; + if sessions.is_empty() { + lines.push(" No managed sessions saved yet.".to_string()); + return Ok(lines.join("\n")); + } + for session in sessions { + let marker = if session.id == active_session_id { + "● current" + } else { + "○ saved" + }; + lines.push(format!( + " {id:<20} {marker:<10} {msgs:>3} msgs · updated {modified}", + id = session.id, + msgs = session.message_count, + modified = format_relative_timestamp(session.modified_epoch_secs), + )); + lines.push(format!(" {}", session.path.display())); + } + Ok(lines.join("\n")) +} + +fn render_repl_help() -> String { + [ + "Interactive REPL".to_string(), + " Quick start Ask a task in plain English or use one of the core commands below." + .to_string(), + " Core commands /help · /status · /model · /permissions · /compact".to_string(), + " Exit /exit or /quit".to_string(), + " Vim mode /vim toggles modal editing".to_string(), + " History Up/Down recalls previous prompts".to_string(), + " Completion Tab cycles slash command matches".to_string(), + " Cancel Ctrl-C clears input (or exits on an empty prompt)".to_string(), + " Multiline Shift+Enter or Ctrl+J inserts a newline".to_string(), + String::new(), + render_slash_command_help(), + ] + .join( + " +", + ) +} + +fn append_slash_command_suggestions(lines: &mut Vec<String>, name: &str) { + let suggestions = suggest_slash_commands(name, 3); + if suggestions.is_empty() { + lines.push(" Try /help shows the full slash command map".to_string()); + return; + } + + lines.push(" Try /help shows the full slash command map".to_string()); + lines.push("Suggestions".to_string()); + lines.extend( + suggestions + .into_iter() + .map(|suggestion| format!(" {suggestion}")), + ); +} + +fn render_unknown_repl_command(name: &str) -> String { + let mut lines = vec![ + "Unknown slash command".to_string(), + format!(" Command /{name}"), + ]; + append_repl_command_suggestions(&mut lines, name); + lines.join("\n") +} + +fn append_repl_command_suggestions(lines: &mut Vec<String>, name: &str) { + let suggestions = suggest_repl_commands(name); + if suggestions.is_empty() { + lines.push(" Try /help shows the full slash command map".to_string()); + return; + } + + lines.push(" Try /help shows the full slash command map".to_string()); + lines.push("Suggestions".to_string()); + lines.extend( + suggestions + .into_iter() + .map(|suggestion| format!(" {suggestion}")), + ); +} + +fn render_mode_unavailable(command: &str, label: &str) -> String { + [ + "Command unavailable in this REPL mode".to_string(), + format!(" Command /{command}"), + format!(" Feature {label}"), + " Tip Use /help to find currently wired REPL commands".to_string(), + ] + .join("\n") +} + +fn status_context( + session_path: Option<&Path>, +) -> Result<StatusContext, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let discovered_config_files = loader.discover().len(); + let runtime_config = loader.load()?; + let project_context = ProjectContext::discover_with_git(&cwd, DEFAULT_DATE)?; + let (project_root, git_branch) = + parse_git_status_metadata(project_context.git_status.as_deref()); + Ok(StatusContext { + cwd, + session_path: session_path.map(Path::to_path_buf), + loaded_config_files: runtime_config.loaded_entries().len(), + discovered_config_files, + memory_file_count: project_context.instruction_files.len(), + project_root, + git_branch, + }) +} + +fn format_status_report( + model: &str, + usage: StatusUsage, + permission_mode: &str, + context: &StatusContext, +) -> String { + [ + format!( + "Session + Model {model} + Permissions {permission_mode} + Activity {} messages · {} turns + Tokens est {} · latest {} · total {}", + usage.message_count, + usage.turns, + usage.estimated_tokens, + usage.latest.total_tokens(), + usage.cumulative.total_tokens(), + ), + format!( + "Usage + Cumulative input {} + Cumulative output {} + Cache create {} + Cache read {}", + usage.cumulative.input_tokens, + usage.cumulative.output_tokens, + usage.cumulative.cache_creation_input_tokens, + usage.cumulative.cache_read_input_tokens, + ), + format!( + "Workspace + Folder {} + Project root {} + Git branch {} + Session file {} + Config files loaded {}/{} + Memory files {} + +Next + /help Browse commands + /session list Inspect saved sessions + /diff Review current workspace changes", + context.cwd.display(), + context + .project_root + .as_ref() + .map_or_else(|| "unknown".to_string(), |path| path.display().to_string()), + context.git_branch.as_deref().unwrap_or("unknown"), + context.session_path.as_ref().map_or_else( + || "live-repl".to_string(), + |path| path.display().to_string() + ), + context.loaded_config_files, + context.discovered_config_files, + context.memory_file_count, + ), + ] + .join( + " + +", + ) +} + +fn render_config_report(section: Option<&str>) -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let discovered = loader.discover(); + let runtime_config = loader.load()?; + + let mut lines = vec![ + format!( + "Config + Working directory {} + Loaded files {} + Merged keys {}", + cwd.display(), + runtime_config.loaded_entries().len(), + runtime_config.merged().len() + ), + "Discovered files".to_string(), + ]; + for entry in discovered { + let source = match entry.source { + ConfigSource::User => "user", + ConfigSource::Project => "project", + ConfigSource::Local => "local", + }; + let status = if runtime_config + .loaded_entries() + .iter() + .any(|loaded_entry| loaded_entry.path == entry.path) + { + "loaded" + } else { + "missing" + }; + lines.push(format!( + " {source:<7} {status:<7} {}", + entry.path.display() + )); + } + + if let Some(section) = section { + lines.push(format!("Merged section: {section}")); + let value = match section { + "env" => runtime_config.get("env"), + "hooks" => runtime_config.get("hooks"), + "model" => runtime_config.get("model"), + "plugins" => runtime_config + .get("plugins") + .or_else(|| runtime_config.get("enabledPlugins")), + other => { + lines.push(format!( + " Unsupported config section '{other}'. Use env, hooks, model, or plugins." + )); + return Ok(lines.join( + " +", + )); + } + }; + lines.push(format!( + " {}", + match value { + Some(value) => value.render(), + None => "<unset>".to_string(), + } + )); + return Ok(lines.join( + " +", + )); + } + + lines.push("Merged JSON".to_string()); + lines.push(format!(" {}", runtime_config.as_json().render())); + Ok(lines.join( + " +", + )) +} + +fn render_memory_report() -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let project_context = ProjectContext::discover(&cwd, DEFAULT_DATE)?; + let mut lines = vec![format!( + "Memory + Working directory {} + Instruction files {}", + cwd.display(), + project_context.instruction_files.len() + )]; + if project_context.instruction_files.is_empty() { + lines.push("Discovered files".to_string()); + lines.push( + " No CLAW instruction files discovered in the current directory ancestry.".to_string(), + ); + } else { + lines.push("Discovered files".to_string()); + for (index, file) in project_context.instruction_files.iter().enumerate() { + let preview = file.content.lines().next().unwrap_or("").trim(); + let preview = if preview.is_empty() { + "<empty>" + } else { + preview + }; + lines.push(format!(" {}. {}", index + 1, file.path.display(),)); + lines.push(format!( + " lines={} preview={}", + file.content.lines().count(), + preview + )); + } + } + Ok(lines.join( + " +", + )) +} + +fn init_claw_md() -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + Ok(initialize_repo(&cwd)?.render()) +} + +fn run_init() -> Result<(), Box<dyn std::error::Error>> { + println!("{}", init_claw_md()?); + Ok(()) +} + +fn normalize_permission_mode(mode: &str) -> Option<&'static str> { + match mode.trim() { + "read-only" => Some("read-only"), + "workspace-write" => Some("workspace-write"), + "danger-full-access" => Some("danger-full-access"), + _ => None, + } +} + +fn render_diff_report() -> Result<String, Box<dyn std::error::Error>> { + let output = std::process::Command::new("git") + .args(["diff", "--", ":(exclude).omx"]) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git diff failed: {stderr}").into()); + } + let diff = String::from_utf8(output.stdout)?; + if diff.trim().is_empty() { + return Ok( + "Diff\n Result clean working tree\n Detail no current changes" + .to_string(), + ); + } + Ok(format!("Diff\n\n{}", diff.trim_end())) +} + +fn render_teleport_report(target: &str) -> Result<String, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + + let file_list = Command::new("rg") + .args(["--files"]) + .current_dir(&cwd) + .output()?; + let file_matches = if file_list.status.success() { + String::from_utf8(file_list.stdout)? + .lines() + .filter(|line| line.contains(target)) + .take(10) + .map(ToOwned::to_owned) + .collect::<Vec<_>>() + } else { + Vec::new() + }; + + let content_output = Command::new("rg") + .args(["-n", "-S", "--color", "never", target, "."]) + .current_dir(&cwd) + .output()?; + + let mut lines = vec![format!("Teleport\n Target {target}")]; + if !file_matches.is_empty() { + lines.push(String::new()); + lines.push("File matches".to_string()); + lines.extend(file_matches.into_iter().map(|path| format!(" {path}"))); + } + + if content_output.status.success() { + let matches = String::from_utf8(content_output.stdout)?; + if !matches.trim().is_empty() { + lines.push(String::new()); + lines.push("Content matches".to_string()); + lines.push(truncate_for_prompt(&matches, 4_000)); + } + } + + if lines.len() == 1 { + lines.push(" Result no matches found".to_string()); + } + + Ok(lines.join("\n")) +} + +fn render_last_tool_debug_report(session: &Session) -> Result<String, Box<dyn std::error::Error>> { + let last_tool_use = session + .messages + .iter() + .rev() + .find_map(|message| { + message.blocks.iter().rev().find_map(|block| match block { + ContentBlock::ToolUse { id, name, input } => { + Some((id.clone(), name.clone(), input.clone())) + } + _ => None, + }) + }) + .ok_or_else(|| "no prior tool call found in session".to_string())?; + + let tool_result = session.messages.iter().rev().find_map(|message| { + message.blocks.iter().rev().find_map(|block| match block { + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } if tool_use_id == &last_tool_use.0 => { + Some((tool_name.clone(), output.clone(), *is_error)) + } + _ => None, + }) + }); + + let mut lines = vec![ + "Debug tool call".to_string(), + format!(" Tool id {}", last_tool_use.0), + format!(" Tool name {}", last_tool_use.1), + " Input".to_string(), + indent_block(&last_tool_use.2, 4), + ]; + + match tool_result { + Some((tool_name, output, is_error)) => { + lines.push(" Result".to_string()); + lines.push(format!(" name {tool_name}")); + lines.push(format!( + " status {}", + if is_error { "error" } else { "ok" } + )); + lines.push(indent_block(&output, 4)); + } + None => lines.push(" Result missing tool result".to_string()), + } + + Ok(lines.join("\n")) +} + +fn indent_block(value: &str, spaces: usize) -> String { + let indent = " ".repeat(spaces); + value + .lines() + .map(|line| format!("{indent}{line}")) + .collect::<Vec<_>>() + .join("\n") +} + +fn git_output(args: &[&str]) -> Result<String, Box<dyn std::error::Error>> { + let output = Command::new("git") + .args(args) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git {} failed: {stderr}", args.join(" ")).into()); + } + Ok(String::from_utf8(output.stdout)?) +} + +fn git_status_ok(args: &[&str]) -> Result<(), Box<dyn std::error::Error>> { + let output = Command::new("git") + .args(args) + .current_dir(env::current_dir()?) + .output()?; + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(format!("git {} failed: {stderr}", args.join(" ")).into()); + } + Ok(()) +} + +fn command_exists(name: &str) -> bool { + Command::new("which") + .arg(name) + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +fn write_temp_text_file( + filename: &str, + contents: &str, +) -> Result<PathBuf, Box<dyn std::error::Error>> { + let path = env::temp_dir().join(filename); + fs::write(&path, contents)?; + Ok(path) +} + +fn recent_user_context(session: &Session, limit: usize) -> String { + let requests = session + .messages + .iter() + .filter(|message| message.role == MessageRole::User) + .filter_map(|message| { + message.blocks.iter().find_map(|block| match block { + ContentBlock::Text { text } => Some(text.trim().to_string()), + _ => None, + }) + }) + .rev() + .take(limit) + .collect::<Vec<_>>(); + + if requests.is_empty() { + "<no prior user messages>".to_string() + } else { + requests + .into_iter() + .rev() + .enumerate() + .map(|(index, text)| format!("{}. {}", index + 1, text)) + .collect::<Vec<_>>() + .join("\n") + } +} + +fn truncate_for_prompt(value: &str, limit: usize) -> String { + if value.chars().count() <= limit { + value.trim().to_string() + } else { + let truncated = value.chars().take(limit).collect::<String>(); + format!("{}\n…[truncated]", truncated.trim_end()) + } +} + +fn sanitize_generated_message(value: &str) -> String { + value.trim().trim_matches('`').trim().replace("\r\n", "\n") +} + +fn parse_titled_body(value: &str) -> Option<(String, String)> { + let normalized = sanitize_generated_message(value); + let title = normalized + .lines() + .find_map(|line| line.strip_prefix("TITLE:").map(str::trim))?; + let body_start = normalized.find("BODY:")?; + let body = normalized[body_start + "BODY:".len()..].trim(); + Some((title.to_string(), body.to_string())) +} + +fn render_version_report() -> String { + let git_sha = GIT_SHA.unwrap_or("unknown"); + let target = BUILD_TARGET.unwrap_or("unknown"); + format!( + "Claw Code\n Version {VERSION}\n Git SHA {git_sha}\n Target {target}\n Build date {DEFAULT_DATE}\n\nSupport\n Help claw --help\n REPL /help" + ) +} + +fn render_export_text(session: &Session) -> String { + let mut lines = vec!["# Conversation Export".to_string(), String::new()]; + for (index, message) in session.messages.iter().enumerate() { + let role = match message.role { + MessageRole::System => "system", + MessageRole::User => "user", + MessageRole::Assistant => "assistant", + MessageRole::Tool => "tool", + }; + lines.push(format!("## {}. {role}", index + 1)); + for block in &message.blocks { + match block { + ContentBlock::Text { text } => lines.push(text.clone()), + ContentBlock::ToolUse { id, name, input } => { + lines.push(format!("[tool_use id={id} name={name}] {input}")); + } + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } => { + lines.push(format!( + "[tool_result id={tool_use_id} name={tool_name} error={is_error}] {output}" + )); + } + } + } + lines.push(String::new()); + } + lines.join("\n") +} + +fn default_export_filename(session: &Session) -> String { + let stem = session + .messages + .iter() + .find_map(|message| match message.role { + MessageRole::User => message.blocks.iter().find_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }), + _ => None, + }) + .map_or("conversation", |text| { + text.lines().next().unwrap_or("conversation") + }) + .chars() + .map(|ch| { + if ch.is_ascii_alphanumeric() { + ch.to_ascii_lowercase() + } else { + '-' + } + }) + .collect::<String>() + .split('-') + .filter(|part| !part.is_empty()) + .take(8) + .collect::<Vec<_>>() + .join("-"); + let fallback = if stem.is_empty() { + "conversation" + } else { + &stem + }; + format!("{fallback}.txt") +} + +fn resolve_export_path( + requested_path: Option<&str>, + session: &Session, +) -> Result<PathBuf, Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let file_name = + requested_path.map_or_else(|| default_export_filename(session), ToOwned::to_owned); + let final_name = if Path::new(&file_name) + .extension() + .is_some_and(|ext| ext.eq_ignore_ascii_case("txt")) + { + file_name + } else { + format!("{file_name}.txt") + }; + Ok(cwd.join(final_name)) +} + +fn build_system_prompt() -> Result<Vec<String>, Box<dyn std::error::Error>> { + Ok(load_system_prompt( + env::current_dir()?, + DEFAULT_DATE, + env::consts::OS, + "unknown", + )?) +} + +fn build_runtime_plugin_state( +) -> Result<(runtime::RuntimeFeatureConfig, GlobalToolRegistry), Box<dyn std::error::Error>> { + let cwd = env::current_dir()?; + let loader = ConfigLoader::default_for(&cwd); + let runtime_config = loader.load()?; + let plugin_manager = build_plugin_manager(&cwd, &loader, &runtime_config); + let tool_registry = GlobalToolRegistry::with_plugin_tools(plugin_manager.aggregated_tools()?)?; + Ok((runtime_config.feature_config().clone(), tool_registry)) +} + +fn build_plugin_manager( + cwd: &Path, + loader: &ConfigLoader, + runtime_config: &runtime::RuntimeConfig, +) -> PluginManager { + let plugin_settings = runtime_config.plugins(); + let mut plugin_config = PluginManagerConfig::new(loader.config_home().to_path_buf()); + plugin_config.enabled_plugins = plugin_settings.enabled_plugins().clone(); + plugin_config.external_dirs = plugin_settings + .external_directories() + .iter() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)) + .collect(); + plugin_config.install_root = plugin_settings + .install_root() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); + plugin_config.registry_path = plugin_settings + .registry_path() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); + plugin_config.bundled_root = plugin_settings + .bundled_root() + .map(|path| resolve_plugin_path(cwd, loader.config_home(), path)); + PluginManager::new(plugin_config) +} + +fn resolve_plugin_path(cwd: &Path, config_home: &Path, value: &str) -> PathBuf { + let path = PathBuf::from(value); + if path.is_absolute() { + path + } else if value.starts_with('.') { + cwd.join(path) + } else { + config_home.join(path) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct InternalPromptProgressState { + command_label: &'static str, + task_label: String, + step: usize, + phase: String, + detail: Option<String>, + saw_final_text: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum InternalPromptProgressEvent { + Started, + Update, + Heartbeat, + Complete, + Failed, +} + +#[derive(Debug)] +struct InternalPromptProgressShared { + state: Mutex<InternalPromptProgressState>, + output_lock: Mutex<()>, + started_at: Instant, +} + +#[derive(Debug, Clone)] +struct InternalPromptProgressReporter { + shared: Arc<InternalPromptProgressShared>, +} + +#[derive(Debug)] +struct InternalPromptProgressRun { + reporter: InternalPromptProgressReporter, + heartbeat_stop: Option<mpsc::Sender<()>>, + heartbeat_handle: Option<thread::JoinHandle<()>>, +} + +impl InternalPromptProgressReporter { + fn ultraplan(task: &str) -> Self { + Self { + shared: Arc::new(InternalPromptProgressShared { + state: Mutex::new(InternalPromptProgressState { + command_label: "Ultraplan", + task_label: task.to_string(), + step: 0, + phase: "planning started".to_string(), + detail: Some(format!("task: {task}")), + saw_final_text: false, + }), + output_lock: Mutex::new(()), + started_at: Instant::now(), + }), + } + } + + fn emit(&self, event: InternalPromptProgressEvent, error: Option<&str>) { + let snapshot = self.snapshot(); + let line = format_internal_prompt_progress_line(event, &snapshot, self.elapsed(), error); + self.write_line(&line); + } + + fn mark_model_phase(&self) { + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + state.step += 1; + state.phase = if state.step == 1 { + "analyzing request".to_string() + } else { + "reviewing findings".to_string() + }; + state.detail = Some(format!("task: {}", state.task_label)); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn mark_tool_phase(&self, name: &str, input: &str) { + let detail = describe_tool_progress(name, input); + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + state.step += 1; + state.phase = format!("running {name}"); + state.detail = Some(detail); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn mark_text_phase(&self, text: &str) { + let trimmed = text.trim(); + if trimmed.is_empty() { + return; + } + let detail = truncate_for_summary(first_visible_line(trimmed), 120); + let snapshot = { + let mut state = self + .shared + .state + .lock() + .expect("internal prompt progress state poisoned"); + if state.saw_final_text { + return; + } + state.saw_final_text = true; + state.step += 1; + state.phase = "drafting final plan".to_string(); + state.detail = (!detail.is_empty()).then_some(detail); + state.clone() + }; + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Update, + &snapshot, + self.elapsed(), + None, + )); + } + + fn emit_heartbeat(&self) { + let snapshot = self.snapshot(); + self.write_line(&format_internal_prompt_progress_line( + InternalPromptProgressEvent::Heartbeat, + &snapshot, + self.elapsed(), + None, + )); + } + + fn snapshot(&self) -> InternalPromptProgressState { + self.shared + .state + .lock() + .expect("internal prompt progress state poisoned") + .clone() + } + + fn elapsed(&self) -> Duration { + self.shared.started_at.elapsed() + } + + fn write_line(&self, line: &str) { + let _guard = self + .shared + .output_lock + .lock() + .expect("internal prompt progress output lock poisoned"); + let mut stdout = io::stdout(); + let _ = writeln!(stdout, "{line}"); + let _ = stdout.flush(); + } +} + +impl InternalPromptProgressRun { + fn start_ultraplan(task: &str) -> Self { + let reporter = InternalPromptProgressReporter::ultraplan(task); + reporter.emit(InternalPromptProgressEvent::Started, None); + + let (heartbeat_stop, heartbeat_rx) = mpsc::channel(); + let heartbeat_reporter = reporter.clone(); + let heartbeat_handle = thread::spawn(move || loop { + match heartbeat_rx.recv_timeout(INTERNAL_PROGRESS_HEARTBEAT_INTERVAL) { + Ok(()) | Err(RecvTimeoutError::Disconnected) => break, + Err(RecvTimeoutError::Timeout) => heartbeat_reporter.emit_heartbeat(), + } + }); + + Self { + reporter, + heartbeat_stop: Some(heartbeat_stop), + heartbeat_handle: Some(heartbeat_handle), + } + } + + fn reporter(&self) -> InternalPromptProgressReporter { + self.reporter.clone() + } + + fn finish_success(&mut self) { + self.stop_heartbeat(); + self.reporter + .emit(InternalPromptProgressEvent::Complete, None); + } + + fn finish_failure(&mut self, error: &str) { + self.stop_heartbeat(); + self.reporter + .emit(InternalPromptProgressEvent::Failed, Some(error)); + } + + fn stop_heartbeat(&mut self) { + if let Some(sender) = self.heartbeat_stop.take() { + let _ = sender.send(()); + } + if let Some(handle) = self.heartbeat_handle.take() { + let _ = handle.join(); + } + } +} + +impl Drop for InternalPromptProgressRun { + fn drop(&mut self) { + self.stop_heartbeat(); + } +} + +fn format_internal_prompt_progress_line( + event: InternalPromptProgressEvent, + snapshot: &InternalPromptProgressState, + elapsed: Duration, + error: Option<&str>, +) -> String { + let elapsed_seconds = elapsed.as_secs(); + let step_label = if snapshot.step == 0 { + "current step pending".to_string() + } else { + format!("current step {}", snapshot.step) + }; + let mut status_bits = vec![step_label, format!("phase {}", snapshot.phase)]; + if let Some(detail) = snapshot + .detail + .as_deref() + .filter(|detail| !detail.is_empty()) + { + status_bits.push(detail.to_string()); + } + let status = status_bits.join(" · "); + match event { + InternalPromptProgressEvent::Started => { + format!( + "🧭 {} status · planning started · {status}", + snapshot.command_label + ) + } + InternalPromptProgressEvent::Update => { + format!("… {} status · {status}", snapshot.command_label) + } + InternalPromptProgressEvent::Heartbeat => format!( + "… {} heartbeat · {elapsed_seconds}s elapsed · {status}", + snapshot.command_label + ), + InternalPromptProgressEvent::Complete => format!( + "✔ {} status · completed · {elapsed_seconds}s elapsed · {} steps total", + snapshot.command_label, snapshot.step + ), + InternalPromptProgressEvent::Failed => format!( + "✘ {} status · failed · {elapsed_seconds}s elapsed · {}", + snapshot.command_label, + error.unwrap_or("unknown error") + ), + } +} + +fn describe_tool_progress(name: &str, input: &str) -> String { + let parsed: serde_json::Value = + serde_json::from_str(input).unwrap_or(serde_json::Value::String(input.to_string())); + match name { + "bash" | "Bash" => { + let command = parsed + .get("command") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + if command.is_empty() { + "running shell command".to_string() + } else { + format!("command {}", truncate_for_summary(command.trim(), 100)) + } + } + "read_file" | "Read" => format!("reading {}", extract_tool_path(&parsed)), + "write_file" | "Write" => format!("writing {}", extract_tool_path(&parsed)), + "edit_file" | "Edit" => format!("editing {}", extract_tool_path(&parsed)), + "glob_search" | "Glob" => { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("glob `{pattern}` in {scope}") + } + "grep_search" | "Grep" => { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("grep `{pattern}` in {scope}") + } + "web_search" | "WebSearch" => parsed + .get("query") + .and_then(|value| value.as_str()) + .map_or_else( + || "running web search".to_string(), + |query| format!("query {}", truncate_for_summary(query, 100)), + ), + _ => { + let summary = summarize_tool_payload(input); + if summary.is_empty() { + format!("running {name}") + } else { + format!("{name}: {summary}") + } + } + } +} + +#[allow(clippy::needless_pass_by_value)] +#[allow(clippy::too_many_arguments)] +fn build_runtime( + session: Session, + model: String, + system_prompt: Vec<String>, + enable_tools: bool, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + permission_mode: PermissionMode, + progress_reporter: Option<InternalPromptProgressReporter>, +) -> Result<ConversationRuntime<DefaultRuntimeClient, CliToolExecutor>, Box<dyn std::error::Error>> +{ + let (feature_config, tool_registry) = build_runtime_plugin_state()?; + Ok(ConversationRuntime::new_with_features( + session, + DefaultRuntimeClient::new( + model, + enable_tools, + emit_output, + allowed_tools.clone(), + tool_registry.clone(), + progress_reporter, + )?, + CliToolExecutor::new(allowed_tools.clone(), emit_output, tool_registry.clone()), + permission_policy(permission_mode, &tool_registry), + system_prompt, + feature_config, + )) +} + +struct CliPermissionPrompter { + current_mode: PermissionMode, +} + +impl CliPermissionPrompter { + fn new(current_mode: PermissionMode) -> Self { + Self { current_mode } + } +} + +impl runtime::PermissionPrompter for CliPermissionPrompter { + fn decide( + &mut self, + request: &runtime::PermissionRequest, + ) -> runtime::PermissionPromptDecision { + println!(); + println!("Permission approval required"); + println!(" Tool {}", request.tool_name); + println!(" Current mode {}", self.current_mode.as_str()); + println!(" Required mode {}", request.required_mode.as_str()); + println!(" Input {}", request.input); + print!("Approve this tool call? [y/N]: "); + let _ = io::stdout().flush(); + + let mut response = String::new(); + match io::stdin().read_line(&mut response) { + Ok(_) => { + let normalized = response.trim().to_ascii_lowercase(); + if matches!(normalized.as_str(), "y" | "yes") { + runtime::PermissionPromptDecision::Allow + } else { + runtime::PermissionPromptDecision::Deny { + reason: format!( + "tool '{}' denied by user approval prompt", + request.tool_name + ), + } + } + } + Err(error) => runtime::PermissionPromptDecision::Deny { + reason: format!("permission approval failed: {error}"), + }, + } + } +} + +struct DefaultRuntimeClient { + runtime: tokio::runtime::Runtime, + client: ClawApiClient, + model: String, + enable_tools: bool, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + tool_registry: GlobalToolRegistry, + progress_reporter: Option<InternalPromptProgressReporter>, +} + +impl DefaultRuntimeClient { + fn new( + model: String, + enable_tools: bool, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + tool_registry: GlobalToolRegistry, + progress_reporter: Option<InternalPromptProgressReporter>, + ) -> Result<Self, Box<dyn std::error::Error>> { + Ok(Self { + runtime: tokio::runtime::Runtime::new()?, + client: ClawApiClient::from_auth(resolve_cli_auth_source()?) + .with_base_url(api::read_base_url()), + model, + enable_tools, + emit_output, + allowed_tools, + tool_registry, + progress_reporter, + }) + } +} + +fn resolve_cli_auth_source() -> Result<AuthSource, Box<dyn std::error::Error>> { + Ok(resolve_startup_auth_source(|| { + let cwd = env::current_dir().map_err(api::ApiError::from)?; + let config = ConfigLoader::default_for(&cwd).load().map_err(|error| { + api::ApiError::Auth(format!("failed to load runtime OAuth config: {error}")) + })?; + Ok(config.oauth().cloned()) + })?) +} + +impl ApiClient for DefaultRuntimeClient { + #[allow(clippy::too_many_lines)] + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_model_phase(); + } + let message_request = MessageRequest { + model: self.model.clone(), + max_tokens: max_tokens_for_model(&self.model), + messages: convert_messages(&request.messages), + system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), + tools: self + .enable_tools + .then(|| filter_tool_specs(&self.tool_registry, self.allowed_tools.as_ref())), + tool_choice: self.enable_tools.then_some(ToolChoice::Auto), + stream: true, + }; + + self.runtime.block_on(async { + let mut stream = self + .client + .stream_message(&message_request) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + let mut stdout = io::stdout(); + let mut sink = io::sink(); + let out: &mut dyn Write = if self.emit_output { + &mut stdout + } else { + &mut sink + }; + let renderer = TerminalRenderer::new(); + let mut markdown_stream = MarkdownStreamState::default(); + let mut events = Vec::new(); + let mut pending_tool: Option<(String, String, String)> = None; + let mut saw_stop = false; + + while let Some(event) = stream + .next_event() + .await + .map_err(|error| RuntimeError::new(error.to_string()))? + { + match event { + ApiStreamEvent::MessageStart(start) => { + for block in start.message.content { + push_output_block(block, out, &mut events, &mut pending_tool, true)?; + } + } + ApiStreamEvent::ContentBlockStart(start) => { + push_output_block( + start.content_block, + out, + &mut events, + &mut pending_tool, + true, + )?; + } + ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { + ContentBlockDelta::TextDelta { text } => { + if !text.is_empty() { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_text_phase(&text); + } + if let Some(rendered) = markdown_stream.push(&renderer, &text) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + events.push(AssistantEvent::TextDelta(text)); + } + } + ContentBlockDelta::InputJsonDelta { partial_json } => { + if let Some((_, _, input)) = &mut pending_tool { + input.push_str(&partial_json); + } + } + ContentBlockDelta::ThinkingDelta { .. } + | ContentBlockDelta::SignatureDelta { .. } => {} + }, + ApiStreamEvent::ContentBlockStop(_) => { + if let Some(rendered) = markdown_stream.flush(&renderer) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + if let Some((id, name, input)) = pending_tool.take() { + if let Some(progress_reporter) = &self.progress_reporter { + progress_reporter.mark_tool_phase(&name, &input); + } + // Display tool call now that input is fully accumulated + writeln!(out, "\n{}", format_tool_call_start(&name, &input)) + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + ApiStreamEvent::MessageDelta(delta) => { + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: delta.usage.input_tokens, + output_tokens: delta.usage.output_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + })); + } + ApiStreamEvent::MessageStop(_) => { + saw_stop = true; + if let Some(rendered) = markdown_stream.flush(&renderer) { + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + } + events.push(AssistantEvent::MessageStop); + } + } + } + + if !saw_stop + && events.iter().any(|event| { + matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) + || matches!(event, AssistantEvent::ToolUse { .. }) + }) + { + events.push(AssistantEvent::MessageStop); + } + + if events + .iter() + .any(|event| matches!(event, AssistantEvent::MessageStop)) + { + return Ok(events); + } + + let response = self + .client + .send_message(&MessageRequest { + stream: false, + ..message_request.clone() + }) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + response_to_events(response, out) + }) + } +} + +fn final_assistant_text(summary: &runtime::TurnSummary) -> String { + summary + .assistant_messages + .last() + .map(|message| { + message + .blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::<Vec<_>>() + .join("") + }) + .unwrap_or_default() +} + +fn collect_tool_uses(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> { + summary + .assistant_messages + .iter() + .flat_map(|message| message.blocks.iter()) + .filter_map(|block| match block { + ContentBlock::ToolUse { id, name, input } => Some(json!({ + "id": id, + "name": name, + "input": input, + })), + _ => None, + }) + .collect() +} + +fn collect_tool_results(summary: &runtime::TurnSummary) -> Vec<serde_json::Value> { + summary + .tool_results + .iter() + .flat_map(|message| message.blocks.iter()) + .filter_map(|block| match block { + ContentBlock::ToolResult { + tool_use_id, + tool_name, + output, + is_error, + } => Some(json!({ + "tool_use_id": tool_use_id, + "tool_name": tool_name, + "output": output, + "is_error": is_error, + })), + _ => None, + }) + .collect() +} + +fn slash_command_completion_candidates() -> Vec<String> { + let mut candidates = slash_command_specs() + .iter() + .flat_map(|spec| { + std::iter::once(spec.name) + .chain(spec.aliases.iter().copied()) + .map(|name| format!("/{name}")) + .collect::<Vec<_>>() + }) + .collect::<Vec<_>>(); + candidates.extend([ + String::from("/vim"), + String::from("/exit"), + String::from("/quit"), + ]); + candidates.sort(); + candidates.dedup(); + candidates +} + +fn suggest_repl_commands(name: &str) -> Vec<String> { + let normalized = name.trim().trim_start_matches('/').to_ascii_lowercase(); + if normalized.is_empty() { + return Vec::new(); + } + + let mut ranked = slash_command_completion_candidates() + .into_iter() + .filter_map(|candidate| { + let raw = candidate.trim_start_matches('/').to_ascii_lowercase(); + let distance = edit_distance(&normalized, &raw); + let prefix_match = raw.starts_with(&normalized) || normalized.starts_with(&raw); + let near_match = distance <= 2; + (prefix_match || near_match).then_some((distance, candidate)) + }) + .collect::<Vec<_>>(); + ranked.sort(); + ranked.dedup_by(|left, right| left.1 == right.1); + ranked + .into_iter() + .map(|(_, candidate)| candidate) + .take(3) + .collect() +} + +fn edit_distance(left: &str, right: &str) -> usize { + if left == right { + return 0; + } + if left.is_empty() { + return right.chars().count(); + } + if right.is_empty() { + return left.chars().count(); + } + + let right_chars = right.chars().collect::<Vec<_>>(); + let mut previous = (0..=right_chars.len()).collect::<Vec<_>>(); + let mut current = vec![0; right_chars.len() + 1]; + + for (left_index, left_char) in left.chars().enumerate() { + current[0] = left_index + 1; + for (right_index, right_char) in right_chars.iter().enumerate() { + let substitution_cost = usize::from(left_char != *right_char); + current[right_index + 1] = (previous[right_index + 1] + 1) + .min(current[right_index] + 1) + .min(previous[right_index] + substitution_cost); + } + std::mem::swap(&mut previous, &mut current); + } + + previous[right_chars.len()] +} + +fn format_tool_call_start(name: &str, input: &str) -> String { + let parsed: serde_json::Value = + serde_json::from_str(input).unwrap_or(serde_json::Value::String(input.to_string())); + + let detail = match name { + "bash" | "Bash" => format_bash_call(&parsed), + "read_file" | "Read" => { + let path = extract_tool_path(&parsed); + format!("\x1b[2m📄 Reading {path}…\x1b[0m") + } + "write_file" | "Write" => { + let path = extract_tool_path(&parsed); + let lines = parsed + .get("content") + .and_then(|value| value.as_str()) + .map_or(0, |content| content.lines().count()); + format!("\x1b[1;32m✏️ Writing {path}\x1b[0m \x1b[2m({lines} lines)\x1b[0m") + } + "edit_file" | "Edit" => { + let path = extract_tool_path(&parsed); + let old_value = parsed + .get("old_string") + .or_else(|| parsed.get("oldString")) + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let new_value = parsed + .get("new_string") + .or_else(|| parsed.get("newString")) + .and_then(|value| value.as_str()) + .unwrap_or_default(); + format!( + "\x1b[1;33m📝 Editing {path}\x1b[0m{}", + format_patch_preview(old_value, new_value) + .map(|preview| format!("\n{preview}")) + .unwrap_or_default() + ) + } + "glob_search" | "Glob" => format_search_start("🔎 Glob", &parsed), + "grep_search" | "Grep" => format_search_start("🔎 Grep", &parsed), + "web_search" | "WebSearch" => parsed + .get("query") + .and_then(|value| value.as_str()) + .unwrap_or("?") + .to_string(), + _ => summarize_tool_payload(input), + }; + + let border = "─".repeat(name.len() + 8); + format!( + "\x1b[38;5;245m╭─ \x1b[1;36m{name}\x1b[0;38;5;245m ─╮\x1b[0m\n\x1b[38;5;245m│\x1b[0m {detail}\n\x1b[38;5;245m╰{border}╯\x1b[0m" + ) +} + +fn format_tool_result(name: &str, output: &str, is_error: bool) -> String { + let icon = if is_error { + "\x1b[1;31m✗\x1b[0m" + } else { + "\x1b[1;32m✓\x1b[0m" + }; + if is_error { + let summary = truncate_for_summary(output.trim(), 160); + return if summary.is_empty() { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m") + } else { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m\n\x1b[38;5;203m{summary}\x1b[0m") + }; + } + + let parsed: serde_json::Value = + serde_json::from_str(output).unwrap_or(serde_json::Value::String(output.to_string())); + match name { + "bash" | "Bash" => format_bash_result(icon, &parsed), + "read_file" | "Read" => format_read_result(icon, &parsed), + "write_file" | "Write" => format_write_result(icon, &parsed), + "edit_file" | "Edit" => format_edit_result(icon, &parsed), + "glob_search" | "Glob" => format_glob_result(icon, &parsed), + "grep_search" | "Grep" => format_grep_result(icon, &parsed), + _ => format_generic_tool_result(icon, name, &parsed), + } +} + +const DISPLAY_TRUNCATION_NOTICE: &str = + "\x1b[2m… output truncated for display; full result preserved in session.\x1b[0m"; +const READ_DISPLAY_MAX_LINES: usize = 80; +const READ_DISPLAY_MAX_CHARS: usize = 6_000; +const TOOL_OUTPUT_DISPLAY_MAX_LINES: usize = 60; +const TOOL_OUTPUT_DISPLAY_MAX_CHARS: usize = 4_000; + +fn extract_tool_path(parsed: &serde_json::Value) -> String { + parsed + .get("file_path") + .or_else(|| parsed.get("filePath")) + .or_else(|| parsed.get("path")) + .and_then(|value| value.as_str()) + .unwrap_or("?") + .to_string() +} + +fn format_search_start(label: &str, parsed: &serde_json::Value) -> String { + let pattern = parsed + .get("pattern") + .and_then(|value| value.as_str()) + .unwrap_or("?"); + let scope = parsed + .get("path") + .and_then(|value| value.as_str()) + .unwrap_or("."); + format!("{label} {pattern}\n\x1b[2min {scope}\x1b[0m") +} + +fn format_patch_preview(old_value: &str, new_value: &str) -> Option<String> { + if old_value.is_empty() && new_value.is_empty() { + return None; + } + Some(format!( + "\x1b[38;5;203m- {}\x1b[0m\n\x1b[38;5;70m+ {}\x1b[0m", + truncate_for_summary(first_visible_line(old_value), 72), + truncate_for_summary(first_visible_line(new_value), 72) + )) +} + +fn format_bash_call(parsed: &serde_json::Value) -> String { + let command = parsed + .get("command") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + if command.is_empty() { + String::new() + } else { + format!( + "\x1b[48;5;236;38;5;255m $ {} \x1b[0m", + truncate_for_summary(command, 160) + ) + } +} + +fn first_visible_line(text: &str) -> &str { + text.lines() + .find(|line| !line.trim().is_empty()) + .unwrap_or(text) +} + +fn format_bash_result(icon: &str, parsed: &serde_json::Value) -> String { + let mut lines = vec![format!("{icon} \x1b[38;5;245mbash\x1b[0m")]; + if let Some(task_id) = parsed + .get("backgroundTaskId") + .and_then(|value| value.as_str()) + { + write!(&mut lines[0], " backgrounded ({task_id})").expect("write to string"); + } else if let Some(status) = parsed + .get("returnCodeInterpretation") + .and_then(|value| value.as_str()) + .filter(|status| !status.is_empty()) + { + write!(&mut lines[0], " {status}").expect("write to string"); + } + + if let Some(stdout) = parsed.get("stdout").and_then(|value| value.as_str()) { + if !stdout.trim().is_empty() { + lines.push(truncate_output_for_display( + stdout, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + )); + } + } + if let Some(stderr) = parsed.get("stderr").and_then(|value| value.as_str()) { + if !stderr.trim().is_empty() { + lines.push(format!( + "\x1b[38;5;203m{}\x1b[0m", + truncate_output_for_display( + stderr, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ) + )); + } + } + + lines.join("\n\n") +} + +fn format_read_result(icon: &str, parsed: &serde_json::Value) -> String { + let file = parsed.get("file").unwrap_or(parsed); + let path = extract_tool_path(file); + let start_line = file + .get("startLine") + .and_then(serde_json::Value::as_u64) + .unwrap_or(1); + let num_lines = file + .get("numLines") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let total_lines = file + .get("totalLines") + .and_then(serde_json::Value::as_u64) + .unwrap_or(num_lines); + let content = file + .get("content") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let end_line = start_line.saturating_add(num_lines.saturating_sub(1)); + + format!( + "{icon} \x1b[2m📄 Read {path} (lines {}-{} of {})\x1b[0m\n{}", + start_line, + end_line.max(start_line), + total_lines, + truncate_output_for_display(content, READ_DISPLAY_MAX_LINES, READ_DISPLAY_MAX_CHARS) + ) +} + +fn format_write_result(icon: &str, parsed: &serde_json::Value) -> String { + let path = extract_tool_path(parsed); + let kind = parsed + .get("type") + .and_then(|value| value.as_str()) + .unwrap_or("write"); + let line_count = parsed + .get("content") + .and_then(|value| value.as_str()) + .map_or(0, |content| content.lines().count()); + format!( + "{icon} \x1b[1;32m✏️ {} {path}\x1b[0m \x1b[2m({line_count} lines)\x1b[0m", + if kind == "create" { "Wrote" } else { "Updated" }, + ) +} + +fn format_structured_patch_preview(parsed: &serde_json::Value) -> Option<String> { + let hunks = parsed.get("structuredPatch")?.as_array()?; + let mut preview = Vec::new(); + for hunk in hunks.iter().take(2) { + let lines = hunk.get("lines")?.as_array()?; + for line in lines.iter().filter_map(|value| value.as_str()).take(6) { + match line.chars().next() { + Some('+') => preview.push(format!("\x1b[38;5;70m{line}\x1b[0m")), + Some('-') => preview.push(format!("\x1b[38;5;203m{line}\x1b[0m")), + _ => preview.push(line.to_string()), + } + } + } + if preview.is_empty() { + None + } else { + Some(preview.join("\n")) + } +} + +fn format_edit_result(icon: &str, parsed: &serde_json::Value) -> String { + let path = extract_tool_path(parsed); + let suffix = if parsed + .get("replaceAll") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + { + " (replace all)" + } else { + "" + }; + let preview = format_structured_patch_preview(parsed).or_else(|| { + let old_value = parsed + .get("oldString") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let new_value = parsed + .get("newString") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + format_patch_preview(old_value, new_value) + }); + + match preview { + Some(preview) => format!("{icon} \x1b[1;33m📝 Edited {path}{suffix}\x1b[0m\n{preview}"), + None => format!("{icon} \x1b[1;33m📝 Edited {path}{suffix}\x1b[0m"), + } +} + +fn format_glob_result(icon: &str, parsed: &serde_json::Value) -> String { + let num_files = parsed + .get("numFiles") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let filenames = parsed + .get("filenames") + .and_then(|value| value.as_array()) + .map(|files| { + files + .iter() + .filter_map(|value| value.as_str()) + .take(8) + .collect::<Vec<_>>() + .join("\n") + }) + .unwrap_or_default(); + if filenames.is_empty() { + format!("{icon} \x1b[38;5;245mglob_search\x1b[0m matched {num_files} files") + } else { + format!("{icon} \x1b[38;5;245mglob_search\x1b[0m matched {num_files} files\n{filenames}") + } +} + +fn format_grep_result(icon: &str, parsed: &serde_json::Value) -> String { + let num_matches = parsed + .get("numMatches") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let num_files = parsed + .get("numFiles") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + let content = parsed + .get("content") + .and_then(|value| value.as_str()) + .unwrap_or_default(); + let filenames = parsed + .get("filenames") + .and_then(|value| value.as_array()) + .map(|files| { + files + .iter() + .filter_map(|value| value.as_str()) + .take(8) + .collect::<Vec<_>>() + .join("\n") + }) + .unwrap_or_default(); + let summary = format!( + "{icon} \x1b[38;5;245mgrep_search\x1b[0m {num_matches} matches across {num_files} files" + ); + if !content.trim().is_empty() { + format!( + "{summary}\n{}", + truncate_output_for_display( + content, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ) + ) + } else if !filenames.is_empty() { + format!("{summary}\n{filenames}") + } else { + summary + } +} + +fn format_generic_tool_result(icon: &str, name: &str, parsed: &serde_json::Value) -> String { + let rendered_output = match parsed { + serde_json::Value::String(text) => text.clone(), + serde_json::Value::Null => String::new(), + serde_json::Value::Object(_) | serde_json::Value::Array(_) => { + serde_json::to_string_pretty(parsed).unwrap_or_else(|_| parsed.to_string()) + } + _ => parsed.to_string(), + }; + let preview = truncate_output_for_display( + &rendered_output, + TOOL_OUTPUT_DISPLAY_MAX_LINES, + TOOL_OUTPUT_DISPLAY_MAX_CHARS, + ); + + if preview.is_empty() { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m") + } else if preview.contains('\n') { + format!("{icon} \x1b[38;5;245m{name}\x1b[0m\n{preview}") + } else { + format!("{icon} \x1b[38;5;245m{name}:\x1b[0m {preview}") + } +} + +fn summarize_tool_payload(payload: &str) -> String { + let compact = match serde_json::from_str::<serde_json::Value>(payload) { + Ok(value) => value.to_string(), + Err(_) => payload.trim().to_string(), + }; + truncate_for_summary(&compact, 96) +} + +fn truncate_for_summary(value: &str, limit: usize) -> String { + let mut chars = value.chars(); + let truncated = chars.by_ref().take(limit).collect::<String>(); + if chars.next().is_some() { + format!("{truncated}…") + } else { + truncated + } +} + +fn truncate_output_for_display(content: &str, max_lines: usize, max_chars: usize) -> String { + let original = content.trim_end_matches('\n'); + if original.is_empty() { + return String::new(); + } + + let mut preview_lines = Vec::new(); + let mut used_chars = 0usize; + let mut truncated = false; + + for (index, line) in original.lines().enumerate() { + if index >= max_lines { + truncated = true; + break; + } + + let newline_cost = usize::from(!preview_lines.is_empty()); + let available = max_chars.saturating_sub(used_chars + newline_cost); + if available == 0 { + truncated = true; + break; + } + + let line_chars = line.chars().count(); + if line_chars > available { + preview_lines.push(line.chars().take(available).collect::<String>()); + truncated = true; + break; + } + + preview_lines.push(line.to_string()); + used_chars += newline_cost + line_chars; + } + + let mut preview = preview_lines.join("\n"); + if truncated { + if !preview.is_empty() { + preview.push('\n'); + } + preview.push_str(DISPLAY_TRUNCATION_NOTICE); + } + preview +} + +fn push_output_block( + block: OutputContentBlock, + out: &mut (impl Write + ?Sized), + events: &mut Vec<AssistantEvent>, + pending_tool: &mut Option<(String, String, String)>, + streaming_tool_input: bool, +) -> Result<(), RuntimeError> { + match block { + OutputContentBlock::Text { text } => { + if !text.is_empty() { + let rendered = TerminalRenderer::new().markdown_to_ansi(&text); + write!(out, "{rendered}") + .and_then(|()| out.flush()) + .map_err(|error| RuntimeError::new(error.to_string()))?; + events.push(AssistantEvent::TextDelta(text)); + } + } + OutputContentBlock::ToolUse { id, name, input } => { + // During streaming, the initial content_block_start has an empty input ({}). + // The real input arrives via input_json_delta events. In + // non-streaming responses, preserve a legitimate empty object. + let initial_input = if streaming_tool_input + && input.is_object() + && input.as_object().is_some_and(serde_json::Map::is_empty) + { + String::new() + } else { + input.to_string() + }; + *pending_tool = Some((id, name, initial_input)); + } + OutputContentBlock::Thinking { .. } | OutputContentBlock::RedactedThinking { .. } => {} + } + Ok(()) +} + +fn response_to_events( + response: MessageResponse, + out: &mut (impl Write + ?Sized), +) -> Result<Vec<AssistantEvent>, RuntimeError> { + let mut events = Vec::new(); + let mut pending_tool = None; + + for block in response.content { + push_output_block(block, out, &mut events, &mut pending_tool, false)?; + if let Some((id, name, input)) = pending_tool.take() { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: response.usage.input_tokens, + output_tokens: response.usage.output_tokens, + cache_creation_input_tokens: response.usage.cache_creation_input_tokens, + cache_read_input_tokens: response.usage.cache_read_input_tokens, + })); + events.push(AssistantEvent::MessageStop); + Ok(events) +} + +struct CliToolExecutor { + renderer: TerminalRenderer, + emit_output: bool, + allowed_tools: Option<AllowedToolSet>, + tool_registry: GlobalToolRegistry, +} + +impl CliToolExecutor { + fn new( + allowed_tools: Option<AllowedToolSet>, + emit_output: bool, + tool_registry: GlobalToolRegistry, + ) -> Self { + Self { + renderer: TerminalRenderer::new(), + emit_output, + allowed_tools, + tool_registry, + } + } +} + +impl ToolExecutor for CliToolExecutor { + fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> { + if self + .allowed_tools + .as_ref() + .is_some_and(|allowed| !allowed.contains(tool_name)) + { + return Err(ToolError::new(format!( + "tool `{tool_name}` is not enabled by the current --allowedTools setting" + ))); + } + let value = serde_json::from_str(input) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + match self.tool_registry.execute(tool_name, &value) { + Ok(output) => { + if self.emit_output { + let markdown = format_tool_result(tool_name, &output, false); + self.renderer + .stream_markdown(&markdown, &mut io::stdout()) + .map_err(|error| ToolError::new(error.to_string()))?; + } + Ok(output) + } + Err(error) => { + if self.emit_output { + let markdown = format_tool_result(tool_name, &error, true); + self.renderer + .stream_markdown(&markdown, &mut io::stdout()) + .map_err(|stream_error| ToolError::new(stream_error.to_string()))?; + } + Err(ToolError::new(error)) + } + } + } +} + +fn permission_policy(mode: PermissionMode, tool_registry: &GlobalToolRegistry) -> PermissionPolicy { + tool_registry.permission_specs(None).into_iter().fold( + PermissionPolicy::new(mode), + |policy, (name, required_permission)| { + policy.with_tool_requirement(name, required_permission) + }, + ) +} + +fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> { + messages + .iter() + .filter_map(|message| { + let role = match message.role { + MessageRole::System | MessageRole::User | MessageRole::Tool => "user", + MessageRole::Assistant => "assistant", + }; + let content = message + .blocks + .iter() + .map(|block| match block { + ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() }, + ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: serde_json::from_str(input) + .unwrap_or_else(|_| serde_json::json!({ "raw": input })), + }, + ContentBlock::ToolResult { + tool_use_id, + output, + is_error, + .. + } => InputContentBlock::ToolResult { + tool_use_id: tool_use_id.clone(), + content: vec![ToolResultContentBlock::Text { + text: output.clone(), + }], + is_error: *is_error, + }, + }) + .collect::<Vec<_>>(); + (!content.is_empty()).then(|| InputMessage { + role: role.to_string(), + content, + }) + }) + .collect() +} + +fn print_help_to(out: &mut impl Write) -> io::Result<()> { + writeln!(out, "Claw Code CLI v{VERSION}")?; + writeln!( + out, + " Interactive coding assistant for the current workspace." + )?; + writeln!(out)?; + writeln!(out, "Quick start")?; + writeln!( + out, + " claw Start the interactive REPL" + )?; + writeln!( + out, + " claw \"summarize this repo\" Run one prompt and exit" + )?; + writeln!( + out, + " claw prompt \"explain src/main.rs\" Explicit one-shot prompt" + )?; + writeln!( + out, + " claw --resume SESSION.json /status Inspect a saved session" + )?; + writeln!(out)?; + writeln!(out, "Interactive essentials")?; + writeln!( + out, + " /help Browse the full slash command map" + )?; + writeln!( + out, + " /status Inspect session + workspace state" + )?; + writeln!( + out, + " /model <name> Switch models mid-session" + )?; + writeln!( + out, + " /permissions <mode> Adjust tool access" + )?; + writeln!( + out, + " Tab Complete slash commands" + )?; + writeln!( + out, + " /vim Toggle modal editing" + )?; + writeln!( + out, + " Shift+Enter / Ctrl+J Insert a newline" + )?; + writeln!(out)?; + writeln!(out, "Commands")?; + writeln!( + out, + " claw dump-manifests Read upstream TS sources and print extracted counts" + )?; + writeln!( + out, + " claw bootstrap-plan Print the bootstrap phase skeleton" + )?; + writeln!( + out, + " claw agents List configured agents" + )?; + writeln!( + out, + " claw skills List installed skills" + )?; + writeln!(out, " claw system-prompt [--cwd PATH] [--date YYYY-MM-DD]")?; + writeln!( + out, + " claw login Start the OAuth login flow" + )?; + writeln!( + out, + " claw logout Clear saved OAuth credentials" + )?; + writeln!( + out, + " claw init Scaffold CLAW.md + local files" + )?; + writeln!(out)?; + writeln!(out, "Flags")?; + writeln!( + out, + " --model MODEL Override the active model" + )?; + writeln!( + out, + " --output-format FORMAT Non-interactive output: text or json" + )?; + writeln!( + out, + " --permission-mode MODE Set read-only, workspace-write, or danger-full-access" + )?; + writeln!( + out, + " --dangerously-skip-permissions Skip all permission checks" + )?; + writeln!( + out, + " --allowedTools TOOLS Restrict enabled tools (repeatable; comma-separated aliases supported)" + )?; + writeln!( + out, + " --version, -V Print version and build information" + )?; + writeln!(out)?; + writeln!(out, "Slash command reference")?; + writeln!(out, "{}", render_slash_command_help())?; + writeln!(out)?; + let resume_commands = resume_supported_slash_commands() + .into_iter() + .map(|spec| match spec.argument_hint { + Some(argument_hint) => format!("/{} {}", spec.name, argument_hint), + None => format!("/{}", spec.name), + }) + .collect::<Vec<_>>() + .join(", "); + writeln!(out, "Resume-safe commands: {resume_commands}")?; + writeln!(out, "Examples")?; + writeln!(out, " claw --model opus \"summarize this repo\"")?; + writeln!( + out, + " claw --output-format json prompt \"explain src/main.rs\"" + )?; + writeln!( + out, + " claw --allowedTools read,glob \"summarize Cargo.toml\"" + )?; + writeln!( + out, + " claw --resume session.json /status /diff /export notes.txt" + )?; + writeln!(out, " claw agents")?; + writeln!(out, " claw /skills")?; + writeln!(out, " claw login")?; + writeln!(out, " claw init")?; + Ok(()) +} + +fn print_help() { + let _ = print_help_to(&mut io::stdout()); +} + +#[cfg(test)] +mod tests { + use super::{ + describe_tool_progress, filter_tool_specs, format_compact_report, format_cost_report, + format_internal_prompt_progress_line, format_model_report, format_model_switch_report, + format_permissions_report, format_permissions_switch_report, format_resume_report, + format_status_report, format_tool_call_start, format_tool_result, + normalize_permission_mode, parse_args, parse_git_status_metadata, permission_policy, + print_help_to, push_output_block, render_config_report, render_memory_report, + render_repl_help, render_unknown_repl_command, resolve_model_alias, response_to_events, + resume_supported_slash_commands, slash_command_completion_candidates, status_context, + CliAction, CliOutputFormat, InternalPromptProgressEvent, InternalPromptProgressState, + SlashCommand, StatusUsage, DEFAULT_MODEL, + }; + use api::{MessageResponse, OutputContentBlock, Usage}; + use plugins::{PluginTool, PluginToolDefinition, PluginToolPermission}; + use runtime::{AssistantEvent, ContentBlock, ConversationMessage, MessageRole, PermissionMode}; + use serde_json::json; + use std::path::PathBuf; + use std::time::Duration; + use tools::GlobalToolRegistry; + + fn registry_with_plugin_tool() -> GlobalToolRegistry { + GlobalToolRegistry::with_plugin_tools(vec![PluginTool::new( + "plugin-demo@external", + "plugin-demo", + PluginToolDefinition { + name: "plugin_echo".to_string(), + description: Some("Echo plugin payload".to_string()), + input_schema: json!({ + "type": "object", + "properties": { + "message": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + }), + }, + "echo".to_string(), + Vec::new(), + PluginToolPermission::WorkspaceWrite, + None, + )]) + .expect("plugin tool registry should build") + } + + #[test] + fn defaults_to_repl_when_no_args() { + assert_eq!( + parse_args(&[]).expect("args should parse"), + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn parses_prompt_subcommand() { + let args = vec![ + "prompt".to_string(), + "hello".to_string(), + "world".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Prompt { + prompt: "hello world".to_string(), + model: DEFAULT_MODEL.to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn parses_bare_prompt_and_json_output_flag() { + let args = vec![ + "--output-format=json".to_string(), + "--model".to_string(), + "custom-opus".to_string(), + "explain".to_string(), + "this".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Prompt { + prompt: "explain this".to_string(), + model: "custom-opus".to_string(), + output_format: CliOutputFormat::Json, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn resolves_model_aliases_in_args() { + let args = vec![ + "--model".to_string(), + "opus".to_string(), + "explain".to_string(), + "this".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Prompt { + prompt: "explain this".to_string(), + model: "claude-opus-4-6".to_string(), + output_format: CliOutputFormat::Text, + allowed_tools: None, + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn resolves_known_model_aliases() { + assert_eq!(resolve_model_alias("opus"), "claude-opus-4-6"); + assert_eq!(resolve_model_alias("sonnet"), "claude-sonnet-4-6"); + assert_eq!(resolve_model_alias("haiku"), "claude-haiku-4-5-20251213"); + assert_eq!(resolve_model_alias("custom-opus"), "custom-opus"); + } + + #[test] + fn parses_version_flags_without_initializing_prompt_mode() { + assert_eq!( + parse_args(&["--version".to_string()]).expect("args should parse"), + CliAction::Version + ); + assert_eq!( + parse_args(&["-V".to_string()]).expect("args should parse"), + CliAction::Version + ); + } + + #[test] + fn parses_permission_mode_flag() { + let args = vec!["--permission-mode=read-only".to_string()]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: None, + permission_mode: PermissionMode::ReadOnly, + } + ); + } + + #[test] + fn parses_allowed_tools_flags_with_aliases_and_lists() { + let args = vec![ + "--allowedTools".to_string(), + "read,glob".to_string(), + "--allowed-tools=write_file".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::Repl { + model: DEFAULT_MODEL.to_string(), + allowed_tools: Some( + ["glob_search", "read_file", "write_file"] + .into_iter() + .map(str::to_string) + .collect() + ), + permission_mode: PermissionMode::DangerFullAccess, + } + ); + } + + #[test] + fn rejects_unknown_allowed_tools() { + let error = parse_args(&["--allowedTools".to_string(), "teleport".to_string()]) + .expect_err("tool should be rejected"); + assert!(error.contains("unsupported tool in --allowedTools: teleport")); + } + + #[test] + fn parses_system_prompt_options() { + let args = vec![ + "system-prompt".to_string(), + "--cwd".to_string(), + "/tmp/project".to_string(), + "--date".to_string(), + "2026-04-01".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::PrintSystemPrompt { + cwd: PathBuf::from("/tmp/project"), + date: "2026-04-01".to_string(), + } + ); + } + + #[test] + fn parses_login_and_logout_subcommands() { + assert_eq!( + parse_args(&["login".to_string()]).expect("login should parse"), + CliAction::Login + ); + assert_eq!( + parse_args(&["logout".to_string()]).expect("logout should parse"), + CliAction::Logout + ); + assert_eq!( + parse_args(&["init".to_string()]).expect("init should parse"), + CliAction::Init + ); + assert_eq!( + parse_args(&["agents".to_string()]).expect("agents should parse"), + CliAction::Agents { args: None } + ); + assert_eq!( + parse_args(&["skills".to_string()]).expect("skills should parse"), + CliAction::Skills { args: None } + ); + assert_eq!( + parse_args(&["agents".to_string(), "--help".to_string()]) + .expect("agents help should parse"), + CliAction::Agents { + args: Some("--help".to_string()) + } + ); + } + + #[test] + fn parses_direct_agents_and_skills_slash_commands() { + assert_eq!( + parse_args(&["/agents".to_string()]).expect("/agents should parse"), + CliAction::Agents { args: None } + ); + assert_eq!( + parse_args(&["/skills".to_string()]).expect("/skills should parse"), + CliAction::Skills { args: None } + ); + assert_eq!( + parse_args(&["/skills".to_string(), "help".to_string()]) + .expect("/skills help should parse"), + CliAction::Skills { + args: Some("help".to_string()) + } + ); + let error = parse_args(&["/status".to_string()]) + .expect_err("/status should remain REPL-only when invoked directly"); + assert!(error.contains("Direct slash command unavailable")); + assert!(error.contains("/status")); + } + + #[test] + fn parses_resume_flag_with_slash_command() { + let args = vec![ + "--resume".to_string(), + "session.json".to_string(), + "/compact".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("session.json"), + commands: vec!["/compact".to_string()], + } + ); + } + + #[test] + fn parses_resume_flag_with_multiple_slash_commands() { + let args = vec![ + "--resume".to_string(), + "session.json".to_string(), + "/status".to_string(), + "/compact".to_string(), + "/cost".to_string(), + ]; + assert_eq!( + parse_args(&args).expect("args should parse"), + CliAction::ResumeSession { + session_path: PathBuf::from("session.json"), + commands: vec![ + "/status".to_string(), + "/compact".to_string(), + "/cost".to_string(), + ], + } + ); + } + + #[test] + fn filtered_tool_specs_respect_allowlist() { + let allowed = ["read_file", "grep_search"] + .into_iter() + .map(str::to_string) + .collect(); + let filtered = filter_tool_specs(&GlobalToolRegistry::builtin(), Some(&allowed)); + let names = filtered + .into_iter() + .map(|spec| spec.name) + .collect::<Vec<_>>(); + assert_eq!(names, vec!["read_file", "grep_search"]); + } + + #[test] + fn filtered_tool_specs_include_plugin_tools() { + let filtered = filter_tool_specs(®istry_with_plugin_tool(), None); + let names = filtered + .into_iter() + .map(|definition| definition.name) + .collect::<Vec<_>>(); + assert!(names.contains(&"bash".to_string())); + assert!(names.contains(&"plugin_echo".to_string())); + } + + #[test] + fn permission_policy_uses_plugin_tool_permissions() { + let policy = permission_policy(PermissionMode::ReadOnly, ®istry_with_plugin_tool()); + let required = policy.required_mode_for("plugin_echo"); + assert_eq!(required, PermissionMode::WorkspaceWrite); + } + + #[test] + fn shared_help_uses_resume_annotation_copy() { + let help = commands::render_slash_command_help(); + assert!(help.contains("Slash commands")); + assert!(help.contains("Tab completes commands inside the REPL.")); + assert!(help.contains("available via claw --resume SESSION.json")); + } + + #[test] + fn repl_help_includes_shared_commands_and_exit() { + let help = render_repl_help(); + assert!(help.contains("Interactive REPL")); + assert!(help.contains("/help")); + assert!(help.contains("/status")); + assert!(help.contains("/model [model]")); + assert!(help.contains("/permissions [read-only|workspace-write|danger-full-access]")); + assert!(help.contains("/clear [--confirm]")); + assert!(help.contains("/cost")); + assert!(help.contains("/resume <session-path>")); + assert!(help.contains("/config [env|hooks|model|plugins]")); + assert!(help.contains("/memory")); + assert!(help.contains("/init")); + assert!(help.contains("/diff")); + assert!(help.contains("/version")); + assert!(help.contains("/export [file]")); + assert!(help.contains("/session [list|switch <session-id>]")); + assert!(help.contains( + "/plugin [list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]" + )); + assert!(help.contains("aliases: /plugins, /marketplace")); + assert!(help.contains("/agents")); + assert!(help.contains("/skills")); + assert!(help.contains("/exit")); + assert!(help.contains("Tab cycles slash command matches")); + } + + #[test] + fn completion_candidates_include_repl_only_exit_commands() { + let candidates = slash_command_completion_candidates(); + assert!(candidates.contains(&"/help".to_string())); + assert!(candidates.contains(&"/vim".to_string())); + assert!(candidates.contains(&"/exit".to_string())); + assert!(candidates.contains(&"/quit".to_string())); + } + + #[test] + fn unknown_repl_command_suggestions_include_repl_shortcuts() { + let rendered = render_unknown_repl_command("exi"); + assert!(rendered.contains("Unknown slash command")); + assert!(rendered.contains("/exit")); + assert!(rendered.contains("/help")); + } + + #[test] + fn resume_supported_command_list_matches_expected_surface() { + let names = resume_supported_slash_commands() + .into_iter() + .map(|spec| spec.name) + .collect::<Vec<_>>(); + assert_eq!( + names, + vec![ + "help", "status", "compact", "clear", "cost", "config", "memory", "init", "diff", + "version", "export", "agents", "skills", + ] + ); + } + + #[test] + fn resume_report_uses_sectioned_layout() { + let report = format_resume_report("session.json", 14, 6); + assert!(report.contains("Session resumed")); + assert!(report.contains("Session file session.json")); + assert!(report.contains("History 14 messages · 6 turns")); + assert!(report.contains("/status · /diff · /export")); + } + + #[test] + fn compact_report_uses_structured_output() { + let compacted = format_compact_report(8, 5, false); + assert!(compacted.contains("Compact")); + assert!(compacted.contains("Result compacted")); + assert!(compacted.contains("Messages removed 8")); + assert!(compacted.contains("Use /status")); + let skipped = format_compact_report(0, 3, true); + assert!(skipped.contains("Result skipped")); + } + + #[test] + fn cost_report_uses_sectioned_layout() { + let report = format_cost_report(runtime::TokenUsage { + input_tokens: 20, + output_tokens: 8, + cache_creation_input_tokens: 3, + cache_read_input_tokens: 1, + }); + assert!(report.contains("Cost")); + assert!(report.contains("Input tokens 20")); + assert!(report.contains("Output tokens 8")); + assert!(report.contains("Cache create 3")); + assert!(report.contains("Cache read 1")); + assert!(report.contains("Total tokens 32")); + assert!(report.contains("/compact")); + } + + #[test] + fn permissions_report_uses_sectioned_layout() { + let report = format_permissions_report("workspace-write"); + assert!(report.contains("Permissions")); + assert!(report.contains("Active mode workspace-write")); + assert!(report.contains("Effect Editing tools can modify files in the workspace")); + assert!(report.contains("Modes")); + assert!(report.contains("read-only ○ available Read/search tools only")); + assert!(report.contains("workspace-write ● current Edit files inside the workspace")); + assert!(report.contains("danger-full-access ○ available Unrestricted tool access")); + } + + #[test] + fn permissions_switch_report_is_structured() { + let report = format_permissions_switch_report("read-only", "workspace-write"); + assert!(report.contains("Permissions updated")); + assert!(report.contains("Previous mode read-only")); + assert!(report.contains("Active mode workspace-write")); + assert!(report.contains("Applies to Subsequent tool calls in this REPL")); + } + + #[test] + fn init_help_mentions_direct_subcommand() { + let mut help = Vec::new(); + print_help_to(&mut help).expect("help should render"); + let help = String::from_utf8(help).expect("help should be utf8"); + assert!(help.contains("claw init")); + assert!(help.contains("claw agents")); + assert!(help.contains("claw skills")); + assert!(help.contains("claw /skills")); + } + + #[test] + fn model_report_uses_sectioned_layout() { + let report = format_model_report("sonnet", 12, 4); + assert!(report.contains("Model")); + assert!(report.contains("Current sonnet")); + assert!(report.contains("Session 12 messages · 4 turns")); + assert!(report.contains("Aliases")); + assert!(report.contains("/model <name> Switch models for this REPL session")); + } + + #[test] + fn model_switch_report_preserves_context_summary() { + let report = format_model_switch_report("sonnet", "opus", 9); + assert!(report.contains("Model updated")); + assert!(report.contains("Previous sonnet")); + assert!(report.contains("Current opus")); + assert!(report.contains("Preserved 9 messages")); + } + + #[test] + fn status_line_reports_model_and_token_totals() { + let status = format_status_report( + "sonnet", + StatusUsage { + message_count: 7, + turns: 3, + latest: runtime::TokenUsage { + input_tokens: 5, + output_tokens: 4, + cache_creation_input_tokens: 1, + cache_read_input_tokens: 0, + }, + cumulative: runtime::TokenUsage { + input_tokens: 20, + output_tokens: 8, + cache_creation_input_tokens: 2, + cache_read_input_tokens: 1, + }, + estimated_tokens: 128, + }, + "workspace-write", + &super::StatusContext { + cwd: PathBuf::from("/tmp/project"), + session_path: Some(PathBuf::from("session.json")), + loaded_config_files: 2, + discovered_config_files: 3, + memory_file_count: 4, + project_root: Some(PathBuf::from("/tmp")), + git_branch: Some("main".to_string()), + }, + ); + assert!(status.contains("Session")); + assert!(status.contains("Model sonnet")); + assert!(status.contains("Permissions workspace-write")); + assert!(status.contains("Activity 7 messages · 3 turns")); + assert!(status.contains("Tokens est 128 · latest 10 · total 31")); + assert!(status.contains("Folder /tmp/project")); + assert!(status.contains("Project root /tmp")); + assert!(status.contains("Git branch main")); + assert!(status.contains("Session file session.json")); + assert!(status.contains("Config files loaded 2/3")); + assert!(status.contains("Memory files 4")); + assert!(status.contains("/session list")); + } + + #[test] + fn config_report_supports_section_views() { + let report = render_config_report(Some("env")).expect("config report should render"); + assert!(report.contains("Merged section: env")); + let plugins_report = + render_config_report(Some("plugins")).expect("plugins config report should render"); + assert!(plugins_report.contains("Merged section: plugins")); + } + + #[test] + fn memory_report_uses_sectioned_layout() { + let report = render_memory_report().expect("memory report should render"); + assert!(report.contains("Memory")); + assert!(report.contains("Working directory")); + assert!(report.contains("Instruction files")); + assert!(report.contains("Discovered files")); + } + + #[test] + fn config_report_uses_sectioned_layout() { + let report = render_config_report(None).expect("config report should render"); + assert!(report.contains("Config")); + assert!(report.contains("Discovered files")); + assert!(report.contains("Merged JSON")); + } + + #[test] + fn parses_git_status_metadata() { + let (root, branch) = parse_git_status_metadata(Some( + "## rcc/cli...origin/rcc/cli + M src/main.rs", + )); + assert_eq!(branch.as_deref(), Some("rcc/cli")); + let _ = root; + } + + #[test] + fn status_context_reads_real_workspace_metadata() { + let context = status_context(None).expect("status context should load"); + assert!(context.cwd.is_absolute()); + assert_eq!(context.discovered_config_files, 5); + assert!(context.loaded_config_files <= context.discovered_config_files); + } + + #[test] + fn normalizes_supported_permission_modes() { + assert_eq!(normalize_permission_mode("read-only"), Some("read-only")); + assert_eq!( + normalize_permission_mode("workspace-write"), + Some("workspace-write") + ); + assert_eq!( + normalize_permission_mode("danger-full-access"), + Some("danger-full-access") + ); + assert_eq!(normalize_permission_mode("unknown"), None); + } + + #[test] + fn clear_command_requires_explicit_confirmation_flag() { + assert_eq!( + SlashCommand::parse("/clear"), + Some(SlashCommand::Clear { confirm: false }) + ); + assert_eq!( + SlashCommand::parse("/clear --confirm"), + Some(SlashCommand::Clear { confirm: true }) + ); + } + + #[test] + fn parses_resume_and_config_slash_commands() { + assert_eq!( + SlashCommand::parse("/resume saved-session.json"), + Some(SlashCommand::Resume { + session_path: Some("saved-session.json".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/clear --confirm"), + Some(SlashCommand::Clear { confirm: true }) + ); + assert_eq!( + SlashCommand::parse("/config"), + Some(SlashCommand::Config { section: None }) + ); + assert_eq!( + SlashCommand::parse("/config env"), + Some(SlashCommand::Config { + section: Some("env".to_string()) + }) + ); + assert_eq!(SlashCommand::parse("/memory"), Some(SlashCommand::Memory)); + assert_eq!(SlashCommand::parse("/init"), Some(SlashCommand::Init)); + } + + #[test] + fn init_template_mentions_detected_rust_workspace() { + let rendered = crate::init::render_init_claw_md(std::path::Path::new(".")); + assert!(rendered.contains("# CLAW.md")); + assert!(rendered.contains("cargo clippy --workspace --all-targets -- -D warnings")); + } + + #[test] + fn converts_tool_roundtrip_messages() { + let messages = vec![ + ConversationMessage::user_text("hello"), + ConversationMessage::assistant(vec![ContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "bash".to_string(), + input: "{\"command\":\"pwd\"}".to_string(), + }]), + ConversationMessage { + role: MessageRole::Tool, + blocks: vec![ContentBlock::ToolResult { + tool_use_id: "tool-1".to_string(), + tool_name: "bash".to_string(), + output: "ok".to_string(), + is_error: false, + }], + usage: None, + }, + ]; + + let converted = super::convert_messages(&messages); + assert_eq!(converted.len(), 3); + assert_eq!(converted[1].role, "assistant"); + assert_eq!(converted[2].role, "user"); + } + #[test] + fn repl_help_mentions_history_completion_and_multiline() { + let help = render_repl_help(); + assert!(help.contains("Up/Down")); + assert!(help.contains("Tab cycles")); + assert!(help.contains("Shift+Enter or Ctrl+J")); + } + + #[test] + fn tool_rendering_helpers_compact_output() { + let start = format_tool_call_start("read_file", r#"{"path":"src/main.rs"}"#); + assert!(start.contains("read_file")); + assert!(start.contains("src/main.rs")); + + let done = format_tool_result( + "read_file", + r#"{"file":{"filePath":"src/main.rs","content":"hello","numLines":1,"startLine":1,"totalLines":1}}"#, + false, + ); + assert!(done.contains("📄 Read src/main.rs")); + assert!(done.contains("hello")); + } + + #[test] + fn tool_rendering_truncates_large_read_output_for_display_only() { + let content = (0..200) + .map(|index| format!("line {index:03}")) + .collect::<Vec<_>>() + .join("\n"); + let output = json!({ + "file": { + "filePath": "src/main.rs", + "content": content, + "numLines": 200, + "startLine": 1, + "totalLines": 200 + } + }) + .to_string(); + + let rendered = format_tool_result("read_file", &output, false); + + assert!(rendered.contains("line 000")); + assert!(rendered.contains("line 079")); + assert!(!rendered.contains("line 199")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("line 199")); + } + + #[test] + fn tool_rendering_truncates_large_bash_output_for_display_only() { + let stdout = (0..120) + .map(|index| format!("stdout {index:03}")) + .collect::<Vec<_>>() + .join("\n"); + let output = json!({ + "stdout": stdout, + "stderr": "", + "returnCodeInterpretation": "completed successfully" + }) + .to_string(); + + let rendered = format_tool_result("bash", &output, false); + + assert!(rendered.contains("stdout 000")); + assert!(rendered.contains("stdout 059")); + assert!(!rendered.contains("stdout 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("stdout 119")); + } + + #[test] + fn tool_rendering_truncates_generic_long_output_for_display_only() { + let items = (0..120) + .map(|index| format!("payload {index:03}")) + .collect::<Vec<_>>(); + let output = json!({ + "summary": "plugin payload", + "items": items, + }) + .to_string(); + + let rendered = format_tool_result("plugin_echo", &output, false); + + assert!(rendered.contains("plugin_echo")); + assert!(rendered.contains("payload 000")); + assert!(rendered.contains("payload 040")); + assert!(!rendered.contains("payload 080")); + assert!(!rendered.contains("payload 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("payload 119")); + } + + #[test] + fn tool_rendering_truncates_raw_generic_output_for_display_only() { + let output = (0..120) + .map(|index| format!("raw {index:03}")) + .collect::<Vec<_>>() + .join("\n"); + + let rendered = format_tool_result("plugin_echo", &output, false); + + assert!(rendered.contains("plugin_echo")); + assert!(rendered.contains("raw 000")); + assert!(rendered.contains("raw 059")); + assert!(!rendered.contains("raw 119")); + assert!(rendered.contains("full result preserved in session")); + assert!(output.contains("raw 119")); + } + + #[test] + fn ultraplan_progress_lines_include_phase_step_and_elapsed_status() { + let snapshot = InternalPromptProgressState { + command_label: "Ultraplan", + task_label: "ship plugin progress".to_string(), + step: 3, + phase: "running read_file".to_string(), + detail: Some("reading rust/crates/claw-cli/src/main.rs".to_string()), + saw_final_text: false, + }; + + let started = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Started, + &snapshot, + Duration::from_secs(0), + None, + ); + let heartbeat = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Heartbeat, + &snapshot, + Duration::from_secs(9), + None, + ); + let completed = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Complete, + &snapshot, + Duration::from_secs(12), + None, + ); + let failed = format_internal_prompt_progress_line( + InternalPromptProgressEvent::Failed, + &snapshot, + Duration::from_secs(12), + Some("network timeout"), + ); + + assert!(started.contains("planning started")); + assert!(started.contains("current step 3")); + assert!(heartbeat.contains("heartbeat")); + assert!(heartbeat.contains("9s elapsed")); + assert!(heartbeat.contains("phase running read_file")); + assert!(completed.contains("completed")); + assert!(completed.contains("3 steps total")); + assert!(failed.contains("failed")); + assert!(failed.contains("network timeout")); + } + + #[test] + fn describe_tool_progress_summarizes_known_tools() { + assert_eq!( + describe_tool_progress("read_file", r#"{"path":"src/main.rs"}"#), + "reading src/main.rs" + ); + assert!( + describe_tool_progress("bash", r#"{"command":"cargo test -p claw-cli"}"#) + .contains("cargo test -p claw-cli") + ); + assert_eq!( + describe_tool_progress("grep_search", r#"{"pattern":"ultraplan","path":"rust"}"#), + "grep `ultraplan` in rust" + ); + } + + #[test] + fn push_output_block_renders_markdown_text() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tool = None; + + push_output_block( + OutputContentBlock::Text { + text: "# Heading".to_string(), + }, + &mut out, + &mut events, + &mut pending_tool, + false, + ) + .expect("text block should render"); + + let rendered = String::from_utf8(out).expect("utf8"); + assert!(rendered.contains("Heading")); + assert!(rendered.contains('\u{1b}')); + } + + #[test] + fn push_output_block_skips_empty_object_prefix_for_tool_streams() { + let mut out = Vec::new(); + let mut events = Vec::new(); + let mut pending_tool = None; + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + &mut out, + &mut events, + &mut pending_tool, + true, + ) + .expect("tool block should accumulate"); + + assert!(events.is_empty()); + assert_eq!( + pending_tool, + Some(("tool-1".to_string(), "read_file".to_string(), String::new(),)) + ); + } + + #[test] + fn response_to_events_preserves_empty_object_json_input_outside_streaming() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-1".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }], + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::ToolUse { name, input, .. } + if name == "read_file" && input == "{}" + )); + } + + #[test] + fn response_to_events_preserves_non_empty_json_input_outside_streaming() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-2".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "read_file".to_string(), + input: json!({ "path": "rust/Cargo.toml" }), + }], + stop_reason: Some("tool_use".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::ToolUse { name, input, .. } + if name == "read_file" && input == "{\"path\":\"rust/Cargo.toml\"}" + )); + } + + #[test] + fn response_to_events_ignores_thinking_blocks() { + let mut out = Vec::new(); + let events = response_to_events( + MessageResponse { + id: "msg-3".to_string(), + kind: "message".to_string(), + model: "claude-opus-4-6".to_string(), + role: "assistant".to_string(), + content: vec![ + OutputContentBlock::Thinking { + thinking: "step 1".to_string(), + signature: Some("sig_123".to_string()), + }, + OutputContentBlock::Text { + text: "Final answer".to_string(), + }, + ], + stop_reason: Some("end_turn".to_string()), + stop_sequence: None, + usage: Usage { + input_tokens: 1, + output_tokens: 1, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }, + request_id: None, + }, + &mut out, + ) + .expect("response conversion should succeed"); + + assert!(matches!( + &events[0], + AssistantEvent::TextDelta(text) if text == "Final answer" + )); + assert!(!String::from_utf8(out).expect("utf8").contains("step 1")); + } +} diff --git a/rust/crates/claw-cli/src/render.rs b/rust/crates/claw-cli/src/render.rs new file mode 100644 index 0000000..01751fd --- /dev/null +++ b/rust/crates/claw-cli/src/render.rs @@ -0,0 +1,797 @@ +use std::fmt::Write as FmtWrite; +use std::io::{self, Write}; + +use crossterm::cursor::{MoveToColumn, RestorePosition, SavePosition}; +use crossterm::style::{Color, Print, ResetColor, SetForegroundColor, Stylize}; +use crossterm::terminal::{Clear, ClearType}; +use crossterm::{execute, queue}; +use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag, TagEnd}; +use syntect::easy::HighlightLines; +use syntect::highlighting::{Theme, ThemeSet}; +use syntect::parsing::SyntaxSet; +use syntect::util::{as_24_bit_terminal_escaped, LinesWithEndings}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ColorTheme { + heading: Color, + emphasis: Color, + strong: Color, + inline_code: Color, + link: Color, + quote: Color, + table_border: Color, + code_block_border: Color, + spinner_active: Color, + spinner_done: Color, + spinner_failed: Color, +} + +impl Default for ColorTheme { + fn default() -> Self { + Self { + heading: Color::Cyan, + emphasis: Color::Magenta, + strong: Color::Yellow, + inline_code: Color::Green, + link: Color::Blue, + quote: Color::DarkGrey, + table_border: Color::DarkCyan, + code_block_border: Color::DarkGrey, + spinner_active: Color::Blue, + spinner_done: Color::Green, + spinner_failed: Color::Red, + } + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct Spinner { + frame_index: usize, +} + +impl Spinner { + const FRAMES: [&str; 10] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]; + + #[must_use] + pub fn new() -> Self { + Self::default() + } + + pub fn tick( + &mut self, + label: &str, + theme: &ColorTheme, + out: &mut impl Write, + ) -> io::Result<()> { + let frame = Self::FRAMES[self.frame_index % Self::FRAMES.len()]; + self.frame_index += 1; + queue!( + out, + SavePosition, + MoveToColumn(0), + Clear(ClearType::CurrentLine), + SetForegroundColor(theme.spinner_active), + Print(format!("{frame} {label}")), + ResetColor, + RestorePosition + )?; + out.flush() + } + + pub fn finish( + &mut self, + label: &str, + theme: &ColorTheme, + out: &mut impl Write, + ) -> io::Result<()> { + self.frame_index = 0; + execute!( + out, + MoveToColumn(0), + Clear(ClearType::CurrentLine), + SetForegroundColor(theme.spinner_done), + Print(format!("✔ {label}\n")), + ResetColor + )?; + out.flush() + } + + pub fn fail( + &mut self, + label: &str, + theme: &ColorTheme, + out: &mut impl Write, + ) -> io::Result<()> { + self.frame_index = 0; + execute!( + out, + MoveToColumn(0), + Clear(ClearType::CurrentLine), + SetForegroundColor(theme.spinner_failed), + Print(format!("✘ {label}\n")), + ResetColor + )?; + out.flush() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ListKind { + Unordered, + Ordered { next_index: u64 }, +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct TableState { + headers: Vec<String>, + rows: Vec<Vec<String>>, + current_row: Vec<String>, + current_cell: String, + in_head: bool, +} + +impl TableState { + fn push_cell(&mut self) { + let cell = self.current_cell.trim().to_string(); + self.current_row.push(cell); + self.current_cell.clear(); + } + + fn finish_row(&mut self) { + if self.current_row.is_empty() { + return; + } + let row = std::mem::take(&mut self.current_row); + if self.in_head { + self.headers = row; + } else { + self.rows.push(row); + } + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct RenderState { + emphasis: usize, + strong: usize, + heading_level: Option<u8>, + quote: usize, + list_stack: Vec<ListKind>, + link_stack: Vec<LinkState>, + table: Option<TableState>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct LinkState { + destination: String, + text: String, +} + +impl RenderState { + fn style_text(&self, text: &str, theme: &ColorTheme) -> String { + let mut style = text.stylize(); + + if matches!(self.heading_level, Some(1 | 2)) || self.strong > 0 { + style = style.bold(); + } + if self.emphasis > 0 { + style = style.italic(); + } + + if let Some(level) = self.heading_level { + style = match level { + 1 => style.with(theme.heading), + 2 => style.white(), + 3 => style.with(Color::Blue), + _ => style.with(Color::Grey), + }; + } else if self.strong > 0 { + style = style.with(theme.strong); + } else if self.emphasis > 0 { + style = style.with(theme.emphasis); + } + + if self.quote > 0 { + style = style.with(theme.quote); + } + + format!("{style}") + } + + fn append_raw(&mut self, output: &mut String, text: &str) { + if let Some(link) = self.link_stack.last_mut() { + link.text.push_str(text); + } else if let Some(table) = self.table.as_mut() { + table.current_cell.push_str(text); + } else { + output.push_str(text); + } + } + + fn append_styled(&mut self, output: &mut String, text: &str, theme: &ColorTheme) { + let styled = self.style_text(text, theme); + self.append_raw(output, &styled); + } +} + +#[derive(Debug)] +pub struct TerminalRenderer { + syntax_set: SyntaxSet, + syntax_theme: Theme, + color_theme: ColorTheme, +} + +impl Default for TerminalRenderer { + fn default() -> Self { + let syntax_set = SyntaxSet::load_defaults_newlines(); + let syntax_theme = ThemeSet::load_defaults() + .themes + .remove("base16-ocean.dark") + .unwrap_or_default(); + Self { + syntax_set, + syntax_theme, + color_theme: ColorTheme::default(), + } + } +} + +impl TerminalRenderer { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + #[must_use] + pub fn color_theme(&self) -> &ColorTheme { + &self.color_theme + } + + #[must_use] + pub fn render_markdown(&self, markdown: &str) -> String { + let mut output = String::new(); + let mut state = RenderState::default(); + let mut code_language = String::new(); + let mut code_buffer = String::new(); + let mut in_code_block = false; + + for event in Parser::new_ext(markdown, Options::all()) { + self.render_event( + event, + &mut state, + &mut output, + &mut code_buffer, + &mut code_language, + &mut in_code_block, + ); + } + + output.trim_end().to_string() + } + + #[must_use] + pub fn markdown_to_ansi(&self, markdown: &str) -> String { + self.render_markdown(markdown) + } + + #[allow(clippy::too_many_lines)] + fn render_event( + &self, + event: Event<'_>, + state: &mut RenderState, + output: &mut String, + code_buffer: &mut String, + code_language: &mut String, + in_code_block: &mut bool, + ) { + match event { + Event::Start(Tag::Heading { level, .. }) => { + self.start_heading(state, level as u8, output); + } + Event::End(TagEnd::Paragraph) => output.push_str("\n\n"), + Event::Start(Tag::BlockQuote(..)) => self.start_quote(state, output), + Event::End(TagEnd::BlockQuote(..)) => { + state.quote = state.quote.saturating_sub(1); + output.push('\n'); + } + Event::End(TagEnd::Heading(..)) => { + state.heading_level = None; + output.push_str("\n\n"); + } + Event::End(TagEnd::Item) | Event::SoftBreak | Event::HardBreak => { + state.append_raw(output, "\n"); + } + Event::Start(Tag::List(first_item)) => { + let kind = match first_item { + Some(index) => ListKind::Ordered { next_index: index }, + None => ListKind::Unordered, + }; + state.list_stack.push(kind); + } + Event::End(TagEnd::List(..)) => { + state.list_stack.pop(); + output.push('\n'); + } + Event::Start(Tag::Item) => Self::start_item(state, output), + Event::Start(Tag::CodeBlock(kind)) => { + *in_code_block = true; + *code_language = match kind { + CodeBlockKind::Indented => String::from("text"), + CodeBlockKind::Fenced(lang) => lang.to_string(), + }; + code_buffer.clear(); + self.start_code_block(code_language, output); + } + Event::End(TagEnd::CodeBlock) => { + self.finish_code_block(code_buffer, code_language, output); + *in_code_block = false; + code_language.clear(); + code_buffer.clear(); + } + Event::Start(Tag::Emphasis) => state.emphasis += 1, + Event::End(TagEnd::Emphasis) => state.emphasis = state.emphasis.saturating_sub(1), + Event::Start(Tag::Strong) => state.strong += 1, + Event::End(TagEnd::Strong) => state.strong = state.strong.saturating_sub(1), + Event::Code(code) => { + let rendered = + format!("{}", format!("`{code}`").with(self.color_theme.inline_code)); + state.append_raw(output, &rendered); + } + Event::Rule => output.push_str("---\n"), + Event::Text(text) => { + self.push_text(text.as_ref(), state, output, code_buffer, *in_code_block); + } + Event::Html(html) | Event::InlineHtml(html) => { + state.append_raw(output, &html); + } + Event::FootnoteReference(reference) => { + state.append_raw(output, &format!("[{reference}]")); + } + Event::TaskListMarker(done) => { + state.append_raw(output, if done { "[x] " } else { "[ ] " }); + } + Event::InlineMath(math) | Event::DisplayMath(math) => { + state.append_raw(output, &math); + } + Event::Start(Tag::Link { dest_url, .. }) => { + state.link_stack.push(LinkState { + destination: dest_url.to_string(), + text: String::new(), + }); + } + Event::End(TagEnd::Link) => { + if let Some(link) = state.link_stack.pop() { + let label = if link.text.is_empty() { + link.destination.clone() + } else { + link.text + }; + let rendered = format!( + "{}", + format!("[{label}]({})", link.destination) + .underlined() + .with(self.color_theme.link) + ); + state.append_raw(output, &rendered); + } + } + Event::Start(Tag::Image { dest_url, .. }) => { + let rendered = format!( + "{}", + format!("[image:{dest_url}]").with(self.color_theme.link) + ); + state.append_raw(output, &rendered); + } + Event::Start(Tag::Table(..)) => state.table = Some(TableState::default()), + Event::End(TagEnd::Table) => { + if let Some(table) = state.table.take() { + output.push_str(&self.render_table(&table)); + output.push_str("\n\n"); + } + } + Event::Start(Tag::TableHead) => { + if let Some(table) = state.table.as_mut() { + table.in_head = true; + } + } + Event::End(TagEnd::TableHead) => { + if let Some(table) = state.table.as_mut() { + table.finish_row(); + table.in_head = false; + } + } + Event::Start(Tag::TableRow) => { + if let Some(table) = state.table.as_mut() { + table.current_row.clear(); + table.current_cell.clear(); + } + } + Event::End(TagEnd::TableRow) => { + if let Some(table) = state.table.as_mut() { + table.finish_row(); + } + } + Event::Start(Tag::TableCell) => { + if let Some(table) = state.table.as_mut() { + table.current_cell.clear(); + } + } + Event::End(TagEnd::TableCell) => { + if let Some(table) = state.table.as_mut() { + table.push_cell(); + } + } + Event::Start(Tag::Paragraph | Tag::MetadataBlock(..) | _) + | Event::End(TagEnd::Image | TagEnd::MetadataBlock(..) | _) => {} + } + } + + #[allow(clippy::unused_self)] + fn start_heading(&self, state: &mut RenderState, level: u8, output: &mut String) { + state.heading_level = Some(level); + if !output.is_empty() { + output.push('\n'); + } + } + + fn start_quote(&self, state: &mut RenderState, output: &mut String) { + state.quote += 1; + let _ = write!(output, "{}", "│ ".with(self.color_theme.quote)); + } + + fn start_item(state: &mut RenderState, output: &mut String) { + let depth = state.list_stack.len().saturating_sub(1); + output.push_str(&" ".repeat(depth)); + + let marker = match state.list_stack.last_mut() { + Some(ListKind::Ordered { next_index }) => { + let value = *next_index; + *next_index += 1; + format!("{value}. ") + } + _ => "• ".to_string(), + }; + output.push_str(&marker); + } + + fn start_code_block(&self, code_language: &str, output: &mut String) { + let label = if code_language.is_empty() { + "code".to_string() + } else { + code_language.to_string() + }; + let _ = writeln!( + output, + "{}", + format!("╭─ {label}") + .bold() + .with(self.color_theme.code_block_border) + ); + } + + fn finish_code_block(&self, code_buffer: &str, code_language: &str, output: &mut String) { + output.push_str(&self.highlight_code(code_buffer, code_language)); + let _ = write!( + output, + "{}", + "╰─".bold().with(self.color_theme.code_block_border) + ); + output.push_str("\n\n"); + } + + fn push_text( + &self, + text: &str, + state: &mut RenderState, + output: &mut String, + code_buffer: &mut String, + in_code_block: bool, + ) { + if in_code_block { + code_buffer.push_str(text); + } else { + state.append_styled(output, text, &self.color_theme); + } + } + + fn render_table(&self, table: &TableState) -> String { + let mut rows = Vec::new(); + if !table.headers.is_empty() { + rows.push(table.headers.clone()); + } + rows.extend(table.rows.iter().cloned()); + + if rows.is_empty() { + return String::new(); + } + + let column_count = rows.iter().map(Vec::len).max().unwrap_or(0); + let widths = (0..column_count) + .map(|column| { + rows.iter() + .filter_map(|row| row.get(column)) + .map(|cell| visible_width(cell)) + .max() + .unwrap_or(0) + }) + .collect::<Vec<_>>(); + + let border = format!("{}", "│".with(self.color_theme.table_border)); + let separator = widths + .iter() + .map(|width| "─".repeat(*width + 2)) + .collect::<Vec<_>>() + .join(&format!("{}", "┼".with(self.color_theme.table_border))); + let separator = format!("{border}{separator}{border}"); + + let mut output = String::new(); + if !table.headers.is_empty() { + output.push_str(&self.render_table_row(&table.headers, &widths, true)); + output.push('\n'); + output.push_str(&separator); + if !table.rows.is_empty() { + output.push('\n'); + } + } + + for (index, row) in table.rows.iter().enumerate() { + output.push_str(&self.render_table_row(row, &widths, false)); + if index + 1 < table.rows.len() { + output.push('\n'); + } + } + + output + } + + fn render_table_row(&self, row: &[String], widths: &[usize], is_header: bool) -> String { + let border = format!("{}", "│".with(self.color_theme.table_border)); + let mut line = String::new(); + line.push_str(&border); + + for (index, width) in widths.iter().enumerate() { + let cell = row.get(index).map_or("", String::as_str); + line.push(' '); + if is_header { + let _ = write!(line, "{}", cell.bold().with(self.color_theme.heading)); + } else { + line.push_str(cell); + } + let padding = width.saturating_sub(visible_width(cell)); + line.push_str(&" ".repeat(padding + 1)); + line.push_str(&border); + } + + line + } + + #[must_use] + pub fn highlight_code(&self, code: &str, language: &str) -> String { + let syntax = self + .syntax_set + .find_syntax_by_token(language) + .unwrap_or_else(|| self.syntax_set.find_syntax_plain_text()); + let mut syntax_highlighter = HighlightLines::new(syntax, &self.syntax_theme); + let mut colored_output = String::new(); + + for line in LinesWithEndings::from(code) { + match syntax_highlighter.highlight_line(line, &self.syntax_set) { + Ok(ranges) => { + let escaped = as_24_bit_terminal_escaped(&ranges[..], false); + colored_output.push_str(&apply_code_block_background(&escaped)); + } + Err(_) => colored_output.push_str(&apply_code_block_background(line)), + } + } + + colored_output + } + + pub fn stream_markdown(&self, markdown: &str, out: &mut impl Write) -> io::Result<()> { + let rendered_markdown = self.markdown_to_ansi(markdown); + write!(out, "{rendered_markdown}")?; + if !rendered_markdown.ends_with('\n') { + writeln!(out)?; + } + out.flush() + } +} + +#[derive(Debug, Default, Clone, PartialEq, Eq)] +pub struct MarkdownStreamState { + pending: String, +} + +impl MarkdownStreamState { + #[must_use] + pub fn push(&mut self, renderer: &TerminalRenderer, delta: &str) -> Option<String> { + self.pending.push_str(delta); + let split = find_stream_safe_boundary(&self.pending)?; + let ready = self.pending[..split].to_string(); + self.pending.drain(..split); + Some(renderer.markdown_to_ansi(&ready)) + } + + #[must_use] + pub fn flush(&mut self, renderer: &TerminalRenderer) -> Option<String> { + if self.pending.trim().is_empty() { + self.pending.clear(); + None + } else { + let pending = std::mem::take(&mut self.pending); + Some(renderer.markdown_to_ansi(&pending)) + } + } +} + +fn apply_code_block_background(line: &str) -> String { + let trimmed = line.trim_end_matches('\n'); + let trailing_newline = if trimmed.len() == line.len() { + "" + } else { + "\n" + }; + let with_background = trimmed.replace("\u{1b}[0m", "\u{1b}[0;48;5;236m"); + format!("\u{1b}[48;5;236m{with_background}\u{1b}[0m{trailing_newline}") +} + +fn find_stream_safe_boundary(markdown: &str) -> Option<usize> { + let mut in_fence = false; + let mut last_boundary = None; + + for (offset, line) in markdown.split_inclusive('\n').scan(0usize, |cursor, line| { + let start = *cursor; + *cursor += line.len(); + Some((start, line)) + }) { + let trimmed = line.trim_start(); + if trimmed.starts_with("```") || trimmed.starts_with("~~~") { + in_fence = !in_fence; + if !in_fence { + last_boundary = Some(offset + line.len()); + } + continue; + } + + if in_fence { + continue; + } + + if trimmed.is_empty() { + last_boundary = Some(offset + line.len()); + } + } + + last_boundary +} + +fn visible_width(input: &str) -> usize { + strip_ansi(input).chars().count() +} + +fn strip_ansi(input: &str) -> String { + let mut output = String::new(); + let mut chars = input.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '\u{1b}' { + if chars.peek() == Some(&'[') { + chars.next(); + for next in chars.by_ref() { + if next.is_ascii_alphabetic() { + break; + } + } + } + } else { + output.push(ch); + } + } + + output +} + +#[cfg(test)] +mod tests { + use super::{strip_ansi, MarkdownStreamState, Spinner, TerminalRenderer}; + + #[test] + fn renders_markdown_with_styling_and_lists() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = terminal_renderer + .render_markdown("# Heading\n\nThis is **bold** and *italic*.\n\n- item\n\n`code`"); + + assert!(markdown_output.contains("Heading")); + assert!(markdown_output.contains("• item")); + assert!(markdown_output.contains("code")); + assert!(markdown_output.contains('\u{1b}')); + } + + #[test] + fn renders_links_as_colored_markdown_labels() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.render_markdown("See [Claw](https://example.com/docs) now."); + let plain_text = strip_ansi(&markdown_output); + + assert!(plain_text.contains("[Claw](https://example.com/docs)")); + assert!(markdown_output.contains('\u{1b}')); + } + + #[test] + fn highlights_fenced_code_blocks() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.markdown_to_ansi("```rust\nfn hi() { println!(\"hi\"); }\n```"); + let plain_text = strip_ansi(&markdown_output); + + assert!(plain_text.contains("╭─ rust")); + assert!(plain_text.contains("fn hi")); + assert!(markdown_output.contains('\u{1b}')); + assert!(markdown_output.contains("[48;5;236m")); + } + + #[test] + fn renders_ordered_and_nested_lists() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = + terminal_renderer.render_markdown("1. first\n2. second\n - nested\n - child"); + let plain_text = strip_ansi(&markdown_output); + + assert!(plain_text.contains("1. first")); + assert!(plain_text.contains("2. second")); + assert!(plain_text.contains(" • nested")); + assert!(plain_text.contains(" • child")); + } + + #[test] + fn renders_tables_with_alignment() { + let terminal_renderer = TerminalRenderer::new(); + let markdown_output = terminal_renderer + .render_markdown("| Name | Value |\n| ---- | ----- |\n| alpha | 1 |\n| beta | 22 |"); + let plain_text = strip_ansi(&markdown_output); + let lines = plain_text.lines().collect::<Vec<_>>(); + + assert_eq!(lines[0], "│ Name │ Value │"); + assert_eq!(lines[1], "│───────┼───────│"); + assert_eq!(lines[2], "│ alpha │ 1 │"); + assert_eq!(lines[3], "│ beta │ 22 │"); + assert!(markdown_output.contains('\u{1b}')); + } + + #[test] + fn streaming_state_waits_for_complete_blocks() { + let renderer = TerminalRenderer::new(); + let mut state = MarkdownStreamState::default(); + + assert_eq!(state.push(&renderer, "# Heading"), None); + let flushed = state + .push(&renderer, "\n\nParagraph\n\n") + .expect("completed block"); + let plain_text = strip_ansi(&flushed); + assert!(plain_text.contains("Heading")); + assert!(plain_text.contains("Paragraph")); + + assert_eq!(state.push(&renderer, "```rust\nfn main() {}\n"), None); + let code = state + .push(&renderer, "```\n") + .expect("closed code fence flushes"); + assert!(strip_ansi(&code).contains("fn main()")); + } + + #[test] + fn spinner_advances_frames() { + let terminal_renderer = TerminalRenderer::new(); + let mut spinner = Spinner::new(); + let mut out = Vec::new(); + spinner + .tick("Working", terminal_renderer.color_theme(), &mut out) + .expect("tick succeeds"); + spinner + .tick("Working", terminal_renderer.color_theme(), &mut out) + .expect("tick succeeds"); + + let output = String::from_utf8_lossy(&out); + assert!(output.contains("Working")); + } +} diff --git a/rust/crates/commands/Cargo.toml b/rust/crates/commands/Cargo.toml index d465bff..2263f7a 100644 --- a/rust/crates/commands/Cargo.toml +++ b/rust/crates/commands/Cargo.toml @@ -9,4 +9,6 @@ publish.workspace = true workspace = true [dependencies] +plugins = { path = "../plugins" } runtime = { path = "../runtime" } +serde_json.workspace = true diff --git a/rust/crates/commands/src/lib.rs b/rust/crates/commands/src/lib.rs index b396bb0..da7f1a4 100644 --- a/rust/crates/commands/src/lib.rs +++ b/rust/crates/commands/src/lib.rs @@ -1,3 +1,12 @@ +use std::collections::BTreeMap; +use std::env; +use std::fs; +use std::io; +use std::path::{Path, PathBuf}; +use std::process::Command; +use std::time::{SystemTime, UNIX_EPOCH}; + +use plugins::{PluginError, PluginManager, PluginSummary}; use runtime::{compact_session, CompactionConfig, Session}; #[derive(Debug, Clone, PartialEq, Eq)] @@ -30,104 +39,263 @@ impl CommandRegistry { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SlashCommandCategory { + Core, + Workspace, + Session, + Git, + Automation, +} + +impl SlashCommandCategory { + const fn title(self) -> &'static str { + match self { + Self::Core => "Core flow", + Self::Workspace => "Workspace & memory", + Self::Session => "Sessions & output", + Self::Git => "Git & GitHub", + Self::Automation => "Automation & discovery", + } + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct SlashCommandSpec { pub name: &'static str, + pub aliases: &'static [&'static str], pub summary: &'static str, pub argument_hint: Option<&'static str>, pub resume_supported: bool, + pub category: SlashCommandCategory, } const SLASH_COMMAND_SPECS: &[SlashCommandSpec] = &[ SlashCommandSpec { name: "help", + aliases: &[], summary: "Show available slash commands", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "status", + aliases: &[], summary: "Show current session status", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "compact", + aliases: &[], summary: "Compact local session history", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "model", + aliases: &[], summary: "Show or switch the active model", argument_hint: Some("[model]"), resume_supported: false, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "permissions", + aliases: &[], summary: "Show or switch the active permission mode", argument_hint: Some("[read-only|workspace-write|danger-full-access]"), resume_supported: false, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "clear", + aliases: &[], summary: "Start a fresh local session", argument_hint: Some("[--confirm]"), resume_supported: true, + category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "cost", + aliases: &[], summary: "Show cumulative token usage for this session", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Core, }, SlashCommandSpec { name: "resume", + aliases: &[], summary: "Load a saved session into the REPL", argument_hint: Some("<session-path>"), resume_supported: false, + category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "config", - summary: "Inspect Claude config files or merged sections", - argument_hint: Some("[env|hooks|model]"), + aliases: &[], + summary: "Inspect Claw config files or merged sections", + argument_hint: Some("[env|hooks|model|plugins]"), resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "memory", - summary: "Inspect loaded Claude instruction memory files", + aliases: &[], + summary: "Inspect loaded Claw instruction memory files", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "init", - summary: "Create a starter CLAUDE.md for this repo", + aliases: &[], + summary: "Create a starter CLAW.md for this repo", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "diff", + aliases: &[], summary: "Show git diff for current workspace changes", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, }, SlashCommandSpec { name: "version", + aliases: &[], summary: "Show CLI version and build information", argument_hint: None, resume_supported: true, + category: SlashCommandCategory::Workspace, + }, + SlashCommandSpec { + name: "bughunter", + aliases: &[], + summary: "Inspect the codebase for likely bugs", + argument_hint: Some("[scope]"), + resume_supported: false, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "branch", + aliases: &[], + summary: "List, create, or switch git branches", + argument_hint: Some("[list|create <name>|switch <name>]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "worktree", + aliases: &[], + summary: "List, add, remove, or prune git worktrees", + argument_hint: Some("[list|add <path> [branch]|remove <path>|prune]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "commit", + aliases: &[], + summary: "Generate a commit message and create a git commit", + argument_hint: None, + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "commit-push-pr", + aliases: &[], + summary: "Commit workspace changes, push the branch, and open a PR", + argument_hint: Some("[context]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "pr", + aliases: &[], + summary: "Draft or create a pull request from the conversation", + argument_hint: Some("[context]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "issue", + aliases: &[], + summary: "Draft or create a GitHub issue from the conversation", + argument_hint: Some("[context]"), + resume_supported: false, + category: SlashCommandCategory::Git, + }, + SlashCommandSpec { + name: "ultraplan", + aliases: &[], + summary: "Run a deep planning prompt with multi-step reasoning", + argument_hint: Some("[task]"), + resume_supported: false, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "teleport", + aliases: &[], + summary: "Jump to a file or symbol by searching the workspace", + argument_hint: Some("<symbol-or-path>"), + resume_supported: false, + category: SlashCommandCategory::Workspace, + }, + SlashCommandSpec { + name: "debug-tool-call", + aliases: &[], + summary: "Replay the last tool call with debug details", + argument_hint: None, + resume_supported: false, + category: SlashCommandCategory::Automation, }, SlashCommandSpec { name: "export", + aliases: &[], summary: "Export the current conversation to a file", argument_hint: Some("[file]"), resume_supported: true, + category: SlashCommandCategory::Session, }, SlashCommandSpec { name: "session", + aliases: &[], summary: "List or switch managed local sessions", argument_hint: Some("[list|switch <session-id>]"), resume_supported: false, + category: SlashCommandCategory::Session, + }, + SlashCommandSpec { + name: "plugin", + aliases: &["plugins", "marketplace"], + summary: "Manage Claw Code plugins", + argument_hint: Some( + "[list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]", + ), + resume_supported: false, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "agents", + aliases: &[], + summary: "List configured agents", + argument_hint: None, + resume_supported: true, + category: SlashCommandCategory::Automation, + }, + SlashCommandSpec { + name: "skills", + aliases: &[], + summary: "List available skills", + argument_hint: None, + resume_supported: true, + category: SlashCommandCategory::Automation, }, ]; @@ -136,6 +304,35 @@ pub enum SlashCommand { Help, Status, Compact, + Branch { + action: Option<String>, + target: Option<String>, + }, + Bughunter { + scope: Option<String>, + }, + Worktree { + action: Option<String>, + path: Option<String>, + branch: Option<String>, + }, + Commit, + CommitPushPr { + context: Option<String>, + }, + Pr { + context: Option<String>, + }, + Issue { + context: Option<String>, + }, + Ultraplan { + task: Option<String>, + }, + Teleport { + target: Option<String>, + }, + DebugToolCall, Model { model: Option<String>, }, @@ -163,6 +360,16 @@ pub enum SlashCommand { action: Option<String>, target: Option<String>, }, + Plugins { + action: Option<String>, + target: Option<String>, + }, + Agents { + args: Option<String>, + }, + Skills { + args: Option<String>, + }, Unknown(String), } @@ -180,6 +387,35 @@ impl SlashCommand { "help" => Self::Help, "status" => Self::Status, "compact" => Self::Compact, + "branch" => Self::Branch { + action: parts.next().map(ToOwned::to_owned), + target: parts.next().map(ToOwned::to_owned), + }, + "bughunter" => Self::Bughunter { + scope: remainder_after_command(trimmed, command), + }, + "worktree" => Self::Worktree { + action: parts.next().map(ToOwned::to_owned), + path: parts.next().map(ToOwned::to_owned), + branch: parts.next().map(ToOwned::to_owned), + }, + "commit" => Self::Commit, + "commit-push-pr" => Self::CommitPushPr { + context: remainder_after_command(trimmed, command), + }, + "pr" => Self::Pr { + context: remainder_after_command(trimmed, command), + }, + "issue" => Self::Issue { + context: remainder_after_command(trimmed, command), + }, + "ultraplan" => Self::Ultraplan { + task: remainder_after_command(trimmed, command), + }, + "teleport" => Self::Teleport { + target: remainder_after_command(trimmed, command), + }, + "debug-tool-call" => Self::DebugToolCall, "model" => Self::Model { model: parts.next().map(ToOwned::to_owned), }, @@ -207,11 +443,33 @@ impl SlashCommand { action: parts.next().map(ToOwned::to_owned), target: parts.next().map(ToOwned::to_owned), }, + "plugin" | "plugins" | "marketplace" => Self::Plugins { + action: parts.next().map(ToOwned::to_owned), + target: { + let remainder = parts.collect::<Vec<_>>().join(" "); + (!remainder.is_empty()).then_some(remainder) + }, + }, + "agents" => Self::Agents { + args: remainder_after_command(trimmed, command), + }, + "skills" => Self::Skills { + args: remainder_after_command(trimmed, command), + }, other => Self::Unknown(other.to_string()), }) } } +fn remainder_after_command(input: &str, command: &str) -> Option<String> { + input + .trim() + .strip_prefix(&format!("/{command}")) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToOwned::to_owned) +} + #[must_use] pub fn slash_command_specs() -> &'static [SlashCommandSpec] { SLASH_COMMAND_SPECS @@ -229,29 +487,1250 @@ pub fn resume_supported_slash_commands() -> Vec<&'static SlashCommandSpec> { pub fn render_slash_command_help() -> String { let mut lines = vec![ "Slash commands".to_string(), - " [resume] means the command also works with --resume SESSION.json".to_string(), + " Tab completes commands inside the REPL.".to_string(), + " [resume] = also available via claw --resume SESSION.json".to_string(), ]; - for spec in slash_command_specs() { - let name = match spec.argument_hint { - Some(argument_hint) => format!("/{} {}", spec.name, argument_hint), - None => format!("/{}", spec.name), - }; - let resume = if spec.resume_supported { - " [resume]" - } else { - "" - }; - lines.push(format!(" {name:<20} {}{}", spec.summary, resume)); + + for category in [ + SlashCommandCategory::Core, + SlashCommandCategory::Workspace, + SlashCommandCategory::Session, + SlashCommandCategory::Git, + SlashCommandCategory::Automation, + ] { + lines.push(String::new()); + lines.push(category.title().to_string()); + lines.extend( + slash_command_specs() + .iter() + .filter(|spec| spec.category == category) + .map(render_slash_command_entry), + ); } + lines.join("\n") } +fn render_slash_command_entry(spec: &SlashCommandSpec) -> String { + let alias_suffix = if spec.aliases.is_empty() { + String::new() + } else { + format!( + " (aliases: {})", + spec.aliases + .iter() + .map(|alias| format!("/{alias}")) + .collect::<Vec<_>>() + .join(", ") + ) + }; + let resume = if spec.resume_supported { + " [resume]" + } else { + "" + }; + format!( + " {name:<46} {}{alias_suffix}{resume}", + spec.summary, + name = render_slash_command_name(spec), + ) +} + +fn render_slash_command_name(spec: &SlashCommandSpec) -> String { + match spec.argument_hint { + Some(argument_hint) => format!("/{} {}", spec.name, argument_hint), + None => format!("/{}", spec.name), + } +} + +fn levenshtein_distance(left: &str, right: &str) -> usize { + if left == right { + return 0; + } + if left.is_empty() { + return right.chars().count(); + } + if right.is_empty() { + return left.chars().count(); + } + + let right_chars = right.chars().collect::<Vec<_>>(); + let mut previous = (0..=right_chars.len()).collect::<Vec<_>>(); + let mut current = vec![0; right_chars.len() + 1]; + + for (left_index, left_char) in left.chars().enumerate() { + current[0] = left_index + 1; + for (right_index, right_char) in right_chars.iter().enumerate() { + let cost = usize::from(left_char != *right_char); + current[right_index + 1] = (previous[right_index + 1] + 1) + .min(current[right_index] + 1) + .min(previous[right_index] + cost); + } + std::mem::swap(&mut previous, &mut current); + } + + previous[right_chars.len()] +} + +#[must_use] +pub fn suggest_slash_commands(input: &str, limit: usize) -> Vec<String> { + let normalized = input.trim().trim_start_matches('/').to_ascii_lowercase(); + if normalized.is_empty() || limit == 0 { + return Vec::new(); + } + + let mut ranked = slash_command_specs() + .iter() + .filter_map(|spec| { + let score = std::iter::once(spec.name) + .chain(spec.aliases.iter().copied()) + .map(str::to_ascii_lowercase) + .filter_map(|alias| { + if alias == normalized { + Some((0_usize, alias.len())) + } else if alias.starts_with(&normalized) { + Some((1, alias.len())) + } else if alias.contains(&normalized) { + Some((2, alias.len())) + } else { + let distance = levenshtein_distance(&alias, &normalized); + (distance <= 2).then_some((3 + distance, alias.len())) + } + }) + .min(); + + score.map(|(bucket, len)| (bucket, len, render_slash_command_name(spec))) + }) + .collect::<Vec<_>>(); + + ranked.sort_by(|left, right| left.cmp(right)); + ranked.dedup_by(|left, right| left.2 == right.2); + ranked + .into_iter() + .take(limit) + .map(|(_, _, display)| display) + .collect() +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct SlashCommandResult { pub message: String, pub session: Session, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginsCommandResult { + pub message: String, + pub reload_runtime: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum DefinitionSource { + ProjectCodex, + ProjectClaw, + UserCodexHome, + UserCodex, + UserClaw, +} + +impl DefinitionSource { + fn label(self) -> &'static str { + match self { + Self::ProjectCodex => "Project (.codex)", + Self::ProjectClaw => "Project (.claw)", + Self::UserCodexHome => "User ($CODEX_HOME)", + Self::UserCodex => "User (~/.codex)", + Self::UserClaw => "User (~/.claw)", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct AgentSummary { + name: String, + description: Option<String>, + model: Option<String>, + reasoning_effort: Option<String>, + source: DefinitionSource, + shadowed_by: Option<DefinitionSource>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SkillSummary { + name: String, + description: Option<String>, + source: DefinitionSource, + shadowed_by: Option<DefinitionSource>, + origin: SkillOrigin, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SkillOrigin { + SkillsDir, + LegacyCommandsDir, +} + +impl SkillOrigin { + fn detail_label(self) -> Option<&'static str> { + match self { + Self::SkillsDir => None, + Self::LegacyCommandsDir => Some("legacy /commands"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct SkillRoot { + source: DefinitionSource, + path: PathBuf, + origin: SkillOrigin, +} + +#[allow(clippy::too_many_lines)] +pub fn handle_plugins_slash_command( + action: Option<&str>, + target: Option<&str>, + manager: &mut PluginManager, +) -> Result<PluginsCommandResult, PluginError> { + match action { + None | Some("list") => Ok(PluginsCommandResult { + message: render_plugins_report(&manager.list_installed_plugins()?), + reload_runtime: false, + }), + Some("install") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins install <path>".to_string(), + reload_runtime: false, + }); + }; + let install = manager.install(target)?; + let plugin = manager + .list_installed_plugins()? + .into_iter() + .find(|plugin| plugin.metadata.id == install.plugin_id); + Ok(PluginsCommandResult { + message: render_plugin_install_report(&install.plugin_id, plugin.as_ref()), + reload_runtime: true, + }) + } + Some("enable") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins enable <name>".to_string(), + reload_runtime: false, + }); + }; + let plugin = resolve_plugin_target(manager, target)?; + manager.enable(&plugin.metadata.id)?; + Ok(PluginsCommandResult { + message: format!( + "Plugins\n Result enabled {}\n Name {}\n Version {}\n Status enabled", + plugin.metadata.id, plugin.metadata.name, plugin.metadata.version + ), + reload_runtime: true, + }) + } + Some("disable") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins disable <name>".to_string(), + reload_runtime: false, + }); + }; + let plugin = resolve_plugin_target(manager, target)?; + manager.disable(&plugin.metadata.id)?; + Ok(PluginsCommandResult { + message: format!( + "Plugins\n Result disabled {}\n Name {}\n Version {}\n Status disabled", + plugin.metadata.id, plugin.metadata.name, plugin.metadata.version + ), + reload_runtime: true, + }) + } + Some("uninstall") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins uninstall <plugin-id>".to_string(), + reload_runtime: false, + }); + }; + manager.uninstall(target)?; + Ok(PluginsCommandResult { + message: format!("Plugins\n Result uninstalled {target}"), + reload_runtime: true, + }) + } + Some("update") => { + let Some(target) = target else { + return Ok(PluginsCommandResult { + message: "Usage: /plugins update <plugin-id>".to_string(), + reload_runtime: false, + }); + }; + let update = manager.update(target)?; + let plugin = manager + .list_installed_plugins()? + .into_iter() + .find(|plugin| plugin.metadata.id == update.plugin_id); + Ok(PluginsCommandResult { + message: format!( + "Plugins\n Result updated {}\n Name {}\n Old version {}\n New version {}\n Status {}", + update.plugin_id, + plugin + .as_ref() + .map_or_else(|| update.plugin_id.clone(), |plugin| plugin.metadata.name.clone()), + update.old_version, + update.new_version, + plugin + .as_ref() + .map_or("unknown", |plugin| if plugin.enabled { "enabled" } else { "disabled" }), + ), + reload_runtime: true, + }) + } + Some(other) => Ok(PluginsCommandResult { + message: format!( + "Unknown /plugins action '{other}'. Use list, install, enable, disable, uninstall, or update." + ), + reload_runtime: false, + }), + } +} + +pub fn handle_agents_slash_command(args: Option<&str>, cwd: &Path) -> std::io::Result<String> { + match normalize_optional_args(args) { + None | Some("list") => { + let roots = discover_definition_roots(cwd, "agents"); + let agents = load_agents_from_roots(&roots)?; + Ok(render_agents_report(&agents)) + } + Some("-h" | "--help" | "help") => Ok(render_agents_usage(None)), + Some(args) => Ok(render_agents_usage(Some(args))), + } +} + +pub fn handle_skills_slash_command(args: Option<&str>, cwd: &Path) -> std::io::Result<String> { + match normalize_optional_args(args) { + None | Some("list") => { + let roots = discover_skill_roots(cwd); + let skills = load_skills_from_roots(&roots)?; + Ok(render_skills_report(&skills)) + } + Some("-h" | "--help" | "help") => Ok(render_skills_usage(None)), + Some(args) => Ok(render_skills_usage(Some(args))), + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CommitPushPrRequest { + pub commit_message: Option<String>, + pub pr_title: String, + pub pr_body: String, + pub branch_name_hint: String, +} + +pub fn handle_branch_slash_command( + action: Option<&str>, + target: Option<&str>, + cwd: &Path, +) -> io::Result<String> { + match normalize_optional_args(action) { + None | Some("list") => { + let branches = git_stdout(cwd, &["branch", "--list", "--verbose"])?; + let trimmed = branches.trim(); + Ok(if trimmed.is_empty() { + "Branch\n Result no branches found".to_string() + } else { + format!("Branch\n Result listed\n\n{}", trimmed) + }) + } + Some("create") => { + let Some(target) = target.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /branch create <name>".to_string()); + }; + git_status_ok(cwd, &["switch", "-c", target])?; + Ok(format!( + "Branch\n Result created and switched\n Branch {target}" + )) + } + Some("switch") => { + let Some(target) = target.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /branch switch <name>".to_string()); + }; + git_status_ok(cwd, &["switch", target])?; + Ok(format!( + "Branch\n Result switched\n Branch {target}" + )) + } + Some(other) => Ok(format!( + "Unknown /branch action '{other}'. Use /branch list, /branch create <name>, or /branch switch <name>." + )), + } +} + +pub fn handle_worktree_slash_command( + action: Option<&str>, + path: Option<&str>, + branch: Option<&str>, + cwd: &Path, +) -> io::Result<String> { + match normalize_optional_args(action) { + None | Some("list") => { + let worktrees = git_stdout(cwd, &["worktree", "list"])?; + let trimmed = worktrees.trim(); + Ok(if trimmed.is_empty() { + "Worktree\n Result no worktrees found".to_string() + } else { + format!("Worktree\n Result listed\n\n{}", trimmed) + }) + } + Some("add") => { + let Some(path) = path.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /worktree add <path> [branch]".to_string()); + }; + if let Some(branch) = branch.filter(|value| !value.trim().is_empty()) { + if branch_exists(cwd, branch) { + git_status_ok(cwd, &["worktree", "add", path, branch])?; + } else { + git_status_ok(cwd, &["worktree", "add", path, "-b", branch])?; + } + Ok(format!( + "Worktree\n Result added\n Path {path}\n Branch {branch}" + )) + } else { + git_status_ok(cwd, &["worktree", "add", path])?; + Ok(format!( + "Worktree\n Result added\n Path {path}" + )) + } + } + Some("remove") => { + let Some(path) = path.filter(|value| !value.trim().is_empty()) else { + return Ok("Usage: /worktree remove <path>".to_string()); + }; + git_status_ok(cwd, &["worktree", "remove", path])?; + Ok(format!( + "Worktree\n Result removed\n Path {path}" + )) + } + Some("prune") => { + git_status_ok(cwd, &["worktree", "prune"])?; + Ok("Worktree\n Result pruned".to_string()) + } + Some(other) => Ok(format!( + "Unknown /worktree action '{other}'. Use /worktree list, /worktree add <path> [branch], /worktree remove <path>, or /worktree prune." + )), + } +} + +pub fn handle_commit_slash_command(message: &str, cwd: &Path) -> io::Result<String> { + let status = git_stdout(cwd, &["status", "--short"])?; + if status.trim().is_empty() { + return Ok( + "Commit\n Result skipped\n Reason no workspace changes" + .to_string(), + ); + } + + let message = message.trim(); + if message.is_empty() { + return Err(io::Error::other("generated commit message was empty")); + } + + git_status_ok(cwd, &["add", "-A"])?; + let path = write_temp_text_file("claw-commit-message", "txt", message)?; + let path_string = path.to_string_lossy().into_owned(); + git_status_ok(cwd, &["commit", "--file", path_string.as_str()])?; + + Ok(format!( + "Commit\n Result created\n Message file {}\n\n{}", + path.display(), + message + )) +} + +pub fn handle_commit_push_pr_slash_command( + request: &CommitPushPrRequest, + cwd: &Path, +) -> io::Result<String> { + if !command_exists("gh") { + return Err(io::Error::other("gh CLI is required for /commit-push-pr")); + } + + let default_branch = detect_default_branch(cwd)?; + let mut branch = current_branch(cwd)?; + let mut created_branch = false; + if branch == default_branch { + let hint = if request.branch_name_hint.trim().is_empty() { + request.pr_title.as_str() + } else { + request.branch_name_hint.as_str() + }; + let next_branch = build_branch_name(hint); + git_status_ok(cwd, &["switch", "-c", next_branch.as_str()])?; + branch = next_branch; + created_branch = true; + } + + let workspace_has_changes = !git_stdout(cwd, &["status", "--short"])?.trim().is_empty(); + let commit_report = if workspace_has_changes { + let Some(message) = request.commit_message.as_deref() else { + return Err(io::Error::other( + "commit message is required when workspace changes are present", + )); + }; + Some(handle_commit_slash_command(message, cwd)?) + } else { + None + }; + + let branch_diff = git_stdout( + cwd, + &["diff", "--stat", &format!("{default_branch}...HEAD")], + )?; + if branch_diff.trim().is_empty() { + return Ok( + "Commit/Push/PR\n Result skipped\n Reason no branch changes to push or open as a pull request" + .to_string(), + ); + } + + git_status_ok(cwd, &["push", "--set-upstream", "origin", branch.as_str()])?; + + let body_path = write_temp_text_file("claw-pr-body", "md", request.pr_body.trim())?; + let body_path_string = body_path.to_string_lossy().into_owned(); + let create = Command::new("gh") + .args([ + "pr", + "create", + "--title", + request.pr_title.as_str(), + "--body-file", + body_path_string.as_str(), + "--base", + default_branch.as_str(), + ]) + .current_dir(cwd) + .output()?; + + let (result, url) = if create.status.success() { + ( + "created", + parse_pr_url(&String::from_utf8_lossy(&create.stdout)) + .unwrap_or_else(|| "<unknown>".to_string()), + ) + } else { + let view = Command::new("gh") + .args(["pr", "view", "--json", "url"]) + .current_dir(cwd) + .output()?; + if !view.status.success() { + return Err(io::Error::other(command_failure( + "gh", + &["pr", "create"], + &create, + ))); + } + ( + "existing", + parse_pr_json_url(&String::from_utf8_lossy(&view.stdout)) + .unwrap_or_else(|| "<unknown>".to_string()), + ) + }; + + let mut lines = vec![ + "Commit/Push/PR".to_string(), + format!(" Result {result}"), + format!(" Branch {branch}"), + format!(" Base {default_branch}"), + format!(" Body file {}", body_path.display()), + format!(" URL {url}"), + ]; + if created_branch { + lines.insert(2, " Branch action created and switched".to_string()); + } + if let Some(report) = commit_report { + lines.push(String::new()); + lines.push(report); + } + Ok(lines.join("\n")) +} + +pub fn detect_default_branch(cwd: &Path) -> io::Result<String> { + if let Ok(reference) = git_stdout(cwd, &["symbolic-ref", "refs/remotes/origin/HEAD"]) { + if let Some(branch) = reference + .trim() + .rsplit('/') + .next() + .filter(|value| !value.is_empty()) + { + return Ok(branch.to_string()); + } + } + + for branch in ["main", "master"] { + if branch_exists(cwd, branch) { + return Ok(branch.to_string()); + } + } + + current_branch(cwd) +} + +fn git_stdout(cwd: &Path, args: &[&str]) -> io::Result<String> { + run_command_stdout("git", args, cwd) +} + +fn git_status_ok(cwd: &Path, args: &[&str]) -> io::Result<()> { + run_command_success("git", args, cwd) +} + +fn run_command_stdout(program: &str, args: &[&str], cwd: &Path) -> io::Result<String> { + let output = Command::new(program).args(args).current_dir(cwd).output()?; + if !output.status.success() { + return Err(io::Error::other(command_failure(program, args, &output))); + } + String::from_utf8(output.stdout) + .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)) +} + +fn run_command_success(program: &str, args: &[&str], cwd: &Path) -> io::Result<()> { + let output = Command::new(program).args(args).current_dir(cwd).output()?; + if !output.status.success() { + return Err(io::Error::other(command_failure(program, args, &output))); + } + Ok(()) +} + +fn command_failure(program: &str, args: &[&str], output: &std::process::Output) -> String { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let detail = if stderr.is_empty() { stdout } else { stderr }; + if detail.is_empty() { + format!("{program} {} failed", args.join(" ")) + } else { + format!("{program} {} failed: {detail}", args.join(" ")) + } +} + +fn branch_exists(cwd: &Path, branch: &str) -> bool { + Command::new("git") + .args([ + "show-ref", + "--verify", + "--quiet", + &format!("refs/heads/{branch}"), + ]) + .current_dir(cwd) + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +fn current_branch(cwd: &Path) -> io::Result<String> { + let branch = git_stdout(cwd, &["branch", "--show-current"])?; + let branch = branch.trim(); + if branch.is_empty() { + Err(io::Error::other("unable to determine current git branch")) + } else { + Ok(branch.to_string()) + } +} + +fn command_exists(name: &str) -> bool { + Command::new(name) + .arg("--version") + .output() + .map(|output| output.status.success()) + .unwrap_or(false) +} + +fn write_temp_text_file(prefix: &str, extension: &str, contents: &str) -> io::Result<PathBuf> { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|duration| duration.as_nanos()) + .unwrap_or_default(); + let path = env::temp_dir().join(format!("{prefix}-{nanos}.{extension}")); + fs::write(&path, contents)?; + Ok(path) +} + +fn build_branch_name(hint: &str) -> String { + let slug = slugify(hint); + let owner = env::var("SAFEUSER") + .ok() + .filter(|value| !value.trim().is_empty()) + .or_else(|| { + env::var("USER") + .ok() + .filter(|value| !value.trim().is_empty()) + }); + match owner { + Some(owner) => format!("{owner}/{slug}"), + None => slug, + } +} + +fn slugify(value: &str) -> String { + let mut slug = String::new(); + let mut last_was_dash = false; + for ch in value.chars() { + if ch.is_ascii_alphanumeric() { + slug.push(ch.to_ascii_lowercase()); + last_was_dash = false; + } else if !last_was_dash { + slug.push('-'); + last_was_dash = true; + } + } + let slug = slug.trim_matches('-').to_string(); + if slug.is_empty() { + "change".to_string() + } else { + slug + } +} + +fn parse_pr_url(stdout: &str) -> Option<String> { + stdout + .lines() + .map(str::trim) + .find(|line| line.starts_with("http://") || line.starts_with("https://")) + .map(ToOwned::to_owned) +} + +fn parse_pr_json_url(stdout: &str) -> Option<String> { + serde_json::from_str::<serde_json::Value>(stdout) + .ok()? + .get("url")? + .as_str() + .map(ToOwned::to_owned) +} + +#[must_use] +pub fn render_plugins_report(plugins: &[PluginSummary]) -> String { + let mut lines = vec!["Plugins".to_string()]; + if plugins.is_empty() { + lines.push(" No plugins installed.".to_string()); + return lines.join("\n"); + } + for plugin in plugins { + let enabled = if plugin.enabled { + "enabled" + } else { + "disabled" + }; + lines.push(format!( + " {name:<20} v{version:<10} {enabled}", + name = plugin.metadata.name, + version = plugin.metadata.version, + )); + } + lines.join("\n") +} + +fn render_plugin_install_report(plugin_id: &str, plugin: Option<&PluginSummary>) -> String { + let name = plugin.map_or(plugin_id, |plugin| plugin.metadata.name.as_str()); + let version = plugin.map_or("unknown", |plugin| plugin.metadata.version.as_str()); + let enabled = plugin.is_some_and(|plugin| plugin.enabled); + format!( + "Plugins\n Result installed {plugin_id}\n Name {name}\n Version {version}\n Status {}", + if enabled { "enabled" } else { "disabled" } + ) +} + +fn resolve_plugin_target( + manager: &PluginManager, + target: &str, +) -> Result<PluginSummary, PluginError> { + let mut matches = manager + .list_installed_plugins()? + .into_iter() + .filter(|plugin| plugin.metadata.id == target || plugin.metadata.name == target) + .collect::<Vec<_>>(); + match matches.len() { + 1 => Ok(matches.remove(0)), + 0 => Err(PluginError::NotFound(format!( + "plugin `{target}` is not installed or discoverable" + ))), + _ => Err(PluginError::InvalidManifest(format!( + "plugin name `{target}` is ambiguous; use the full plugin id" + ))), + } +} + +fn discover_definition_roots(cwd: &Path, leaf: &str) -> Vec<(DefinitionSource, PathBuf)> { + let mut roots = Vec::new(); + + for ancestor in cwd.ancestors() { + push_unique_root( + &mut roots, + DefinitionSource::ProjectCodex, + ancestor.join(".codex").join(leaf), + ); + push_unique_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join(leaf), + ); + } + + if let Ok(codex_home) = env::var("CODEX_HOME") { + push_unique_root( + &mut roots, + DefinitionSource::UserCodexHome, + PathBuf::from(codex_home).join(leaf), + ); + } + + if let Some(home) = env::var_os("HOME") { + let home = PathBuf::from(home); + push_unique_root( + &mut roots, + DefinitionSource::UserCodex, + home.join(".codex").join(leaf), + ); + push_unique_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join(leaf), + ); + } + + roots +} + +fn discover_skill_roots(cwd: &Path) -> Vec<SkillRoot> { + let mut roots = Vec::new(); + + for ancestor in cwd.ancestors() { + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectCodex, + ancestor.join(".codex").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectCodex, + ancestor.join(".codex").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::ProjectClaw, + ancestor.join(".claw").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + } + + if let Ok(codex_home) = env::var("CODEX_HOME") { + let codex_home = PathBuf::from(codex_home); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodexHome, + codex_home.join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodexHome, + codex_home.join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + } + + if let Some(home) = env::var_os("HOME") { + let home = PathBuf::from(home); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodex, + home.join(".codex").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserCodex, + home.join(".codex").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join("skills"), + SkillOrigin::SkillsDir, + ); + push_unique_skill_root( + &mut roots, + DefinitionSource::UserClaw, + home.join(".claw").join("commands"), + SkillOrigin::LegacyCommandsDir, + ); + } + + roots +} + +fn push_unique_root( + roots: &mut Vec<(DefinitionSource, PathBuf)>, + source: DefinitionSource, + path: PathBuf, +) { + if path.is_dir() && !roots.iter().any(|(_, existing)| existing == &path) { + roots.push((source, path)); + } +} + +fn push_unique_skill_root( + roots: &mut Vec<SkillRoot>, + source: DefinitionSource, + path: PathBuf, + origin: SkillOrigin, +) { + if path.is_dir() && !roots.iter().any(|existing| existing.path == path) { + roots.push(SkillRoot { + source, + path, + origin, + }); + } +} + +fn load_agents_from_roots( + roots: &[(DefinitionSource, PathBuf)], +) -> std::io::Result<Vec<AgentSummary>> { + let mut agents = Vec::new(); + let mut active_sources = BTreeMap::<String, DefinitionSource>::new(); + + for (source, root) in roots { + let mut root_agents = Vec::new(); + for entry in fs::read_dir(root)? { + let entry = entry?; + if entry.path().extension().is_none_or(|ext| ext != "toml") { + continue; + } + let contents = fs::read_to_string(entry.path())?; + let fallback_name = entry.path().file_stem().map_or_else( + || entry.file_name().to_string_lossy().to_string(), + |stem| stem.to_string_lossy().to_string(), + ); + root_agents.push(AgentSummary { + name: parse_toml_string(&contents, "name").unwrap_or(fallback_name), + description: parse_toml_string(&contents, "description"), + model: parse_toml_string(&contents, "model"), + reasoning_effort: parse_toml_string(&contents, "model_reasoning_effort"), + source: *source, + shadowed_by: None, + }); + } + root_agents.sort_by(|left, right| left.name.cmp(&right.name)); + + for mut agent in root_agents { + let key = agent.name.to_ascii_lowercase(); + if let Some(existing) = active_sources.get(&key) { + agent.shadowed_by = Some(*existing); + } else { + active_sources.insert(key, agent.source); + } + agents.push(agent); + } + } + + Ok(agents) +} + +fn load_skills_from_roots(roots: &[SkillRoot]) -> std::io::Result<Vec<SkillSummary>> { + let mut skills = Vec::new(); + let mut active_sources = BTreeMap::<String, DefinitionSource>::new(); + + for root in roots { + let mut root_skills = Vec::new(); + for entry in fs::read_dir(&root.path)? { + let entry = entry?; + match root.origin { + SkillOrigin::SkillsDir => { + if !entry.path().is_dir() { + continue; + } + let skill_path = entry.path().join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + let contents = fs::read_to_string(skill_path)?; + let (name, description) = parse_skill_frontmatter(&contents); + root_skills.push(SkillSummary { + name: name + .unwrap_or_else(|| entry.file_name().to_string_lossy().to_string()), + description, + source: root.source, + shadowed_by: None, + origin: root.origin, + }); + } + SkillOrigin::LegacyCommandsDir => { + let path = entry.path(); + let markdown_path = if path.is_dir() { + let skill_path = path.join("SKILL.md"); + if !skill_path.is_file() { + continue; + } + skill_path + } else if path + .extension() + .is_some_and(|ext| ext.to_string_lossy().eq_ignore_ascii_case("md")) + { + path + } else { + continue; + }; + + let contents = fs::read_to_string(&markdown_path)?; + let fallback_name = markdown_path.file_stem().map_or_else( + || entry.file_name().to_string_lossy().to_string(), + |stem| stem.to_string_lossy().to_string(), + ); + let (name, description) = parse_skill_frontmatter(&contents); + root_skills.push(SkillSummary { + name: name.unwrap_or(fallback_name), + description, + source: root.source, + shadowed_by: None, + origin: root.origin, + }); + } + } + } + root_skills.sort_by(|left, right| left.name.cmp(&right.name)); + + for mut skill in root_skills { + let key = skill.name.to_ascii_lowercase(); + if let Some(existing) = active_sources.get(&key) { + skill.shadowed_by = Some(*existing); + } else { + active_sources.insert(key, skill.source); + } + skills.push(skill); + } + } + + Ok(skills) +} + +fn parse_toml_string(contents: &str, key: &str) -> Option<String> { + let prefix = format!("{key} ="); + for line in contents.lines() { + let trimmed = line.trim(); + if trimmed.starts_with('#') { + continue; + } + let Some(value) = trimmed.strip_prefix(&prefix) else { + continue; + }; + let value = value.trim(); + let Some(value) = value + .strip_prefix('"') + .and_then(|value| value.strip_suffix('"')) + else { + continue; + }; + if !value.is_empty() { + return Some(value.to_string()); + } + } + None +} + +fn parse_skill_frontmatter(contents: &str) -> (Option<String>, Option<String>) { + let mut lines = contents.lines(); + if lines.next().map(str::trim) != Some("---") { + return (None, None); + } + + let mut name = None; + let mut description = None; + for line in lines { + let trimmed = line.trim(); + if trimmed == "---" { + break; + } + if let Some(value) = trimmed.strip_prefix("name:") { + let value = unquote_frontmatter_value(value.trim()); + if !value.is_empty() { + name = Some(value); + } + continue; + } + if let Some(value) = trimmed.strip_prefix("description:") { + let value = unquote_frontmatter_value(value.trim()); + if !value.is_empty() { + description = Some(value); + } + } + } + + (name, description) +} + +fn unquote_frontmatter_value(value: &str) -> String { + value + .strip_prefix('"') + .and_then(|trimmed| trimmed.strip_suffix('"')) + .or_else(|| { + value + .strip_prefix('\'') + .and_then(|trimmed| trimmed.strip_suffix('\'')) + }) + .unwrap_or(value) + .trim() + .to_string() +} + +fn render_agents_report(agents: &[AgentSummary]) -> String { + if agents.is_empty() { + return "No agents found.".to_string(); + } + + let total_active = agents + .iter() + .filter(|agent| agent.shadowed_by.is_none()) + .count(); + let mut lines = vec![ + "Agents".to_string(), + format!(" {total_active} active agents"), + String::new(), + ]; + + for source in [ + DefinitionSource::ProjectCodex, + DefinitionSource::ProjectClaw, + DefinitionSource::UserCodexHome, + DefinitionSource::UserCodex, + DefinitionSource::UserClaw, + ] { + let group = agents + .iter() + .filter(|agent| agent.source == source) + .collect::<Vec<_>>(); + if group.is_empty() { + continue; + } + + lines.push(format!("{}:", source.label())); + for agent in group { + let detail = agent_detail(agent); + match agent.shadowed_by { + Some(winner) => lines.push(format!(" (shadowed by {}) {detail}", winner.label())), + None => lines.push(format!(" {detail}")), + } + } + lines.push(String::new()); + } + + lines.join("\n").trim_end().to_string() +} + +fn agent_detail(agent: &AgentSummary) -> String { + let mut parts = vec![agent.name.clone()]; + if let Some(description) = &agent.description { + parts.push(description.clone()); + } + if let Some(model) = &agent.model { + parts.push(model.clone()); + } + if let Some(reasoning) = &agent.reasoning_effort { + parts.push(reasoning.clone()); + } + parts.join(" · ") +} + +fn render_skills_report(skills: &[SkillSummary]) -> String { + if skills.is_empty() { + return "No skills found.".to_string(); + } + + let total_active = skills + .iter() + .filter(|skill| skill.shadowed_by.is_none()) + .count(); + let mut lines = vec![ + "Skills".to_string(), + format!(" {total_active} available skills"), + String::new(), + ]; + + for source in [ + DefinitionSource::ProjectCodex, + DefinitionSource::ProjectClaw, + DefinitionSource::UserCodexHome, + DefinitionSource::UserCodex, + DefinitionSource::UserClaw, + ] { + let group = skills + .iter() + .filter(|skill| skill.source == source) + .collect::<Vec<_>>(); + if group.is_empty() { + continue; + } + + lines.push(format!("{}:", source.label())); + for skill in group { + let mut parts = vec![skill.name.clone()]; + if let Some(description) = &skill.description { + parts.push(description.clone()); + } + if let Some(detail) = skill.origin.detail_label() { + parts.push(detail.to_string()); + } + let detail = parts.join(" · "); + match skill.shadowed_by { + Some(winner) => lines.push(format!(" (shadowed by {}) {detail}", winner.label())), + None => lines.push(format!(" {detail}")), + } + } + lines.push(String::new()); + } + + lines.join("\n").trim_end().to_string() +} + +fn normalize_optional_args(args: Option<&str>) -> Option<&str> { + args.map(str::trim).filter(|value| !value.is_empty()) +} + +fn render_agents_usage(unexpected: Option<&str>) -> String { + let mut lines = vec![ + "Agents".to_string(), + " Usage /agents".to_string(), + " Direct CLI claw agents".to_string(), + " Sources .codex/agents, .claw/agents, $CODEX_HOME/agents".to_string(), + ]; + if let Some(args) = unexpected { + lines.push(format!(" Unexpected {args}")); + } + lines.join("\n") +} + +fn render_skills_usage(unexpected: Option<&str>) -> String { + let mut lines = vec![ + "Skills".to_string(), + " Usage /skills".to_string(), + " Direct CLI claw skills".to_string(), + " Sources .codex/skills, .claw/skills, legacy /commands".to_string(), + ]; + if let Some(args) = unexpected { + lines.push(format!(" Unexpected {args}")); + } + lines.join("\n") +} + #[must_use] pub fn handle_slash_command( input: &str, @@ -279,6 +1758,16 @@ pub fn handle_slash_command( session: session.clone(), }), SlashCommand::Status + | SlashCommand::Branch { .. } + | SlashCommand::Bughunter { .. } + | SlashCommand::Worktree { .. } + | SlashCommand::Commit + | SlashCommand::CommitPushPr { .. } + | SlashCommand::Pr { .. } + | SlashCommand::Issue { .. } + | SlashCommand::Ultraplan { .. } + | SlashCommand::Teleport { .. } + | SlashCommand::DebugToolCall | SlashCommand::Model { .. } | SlashCommand::Permissions { .. } | SlashCommand::Clear { .. } @@ -291,6 +1780,9 @@ pub fn handle_slash_command( | SlashCommand::Version | SlashCommand::Export { .. } | SlashCommand::Session { .. } + | SlashCommand::Plugins { .. } + | SlashCommand::Agents { .. } + | SlashCommand::Skills { .. } | SlashCommand::Unknown(_) => None, } } @@ -298,19 +1790,237 @@ pub fn handle_slash_command( #[cfg(test)] mod tests { use super::{ - handle_slash_command, render_slash_command_help, resume_supported_slash_commands, - slash_command_specs, SlashCommand, + handle_branch_slash_command, handle_commit_push_pr_slash_command, + handle_commit_slash_command, handle_plugins_slash_command, handle_slash_command, + handle_worktree_slash_command, load_agents_from_roots, load_skills_from_roots, + render_agents_report, render_plugins_report, render_skills_report, + render_slash_command_help, resume_supported_slash_commands, slash_command_specs, + suggest_slash_commands, CommitPushPrRequest, DefinitionSource, SkillOrigin, SkillRoot, + SlashCommand, }; + use plugins::{PluginKind, PluginManager, PluginManagerConfig, PluginMetadata, PluginSummary}; use runtime::{CompactionConfig, ContentBlock, ConversationMessage, MessageRole, Session}; + use std::env; + use std::fs; + use std::path::{Path, PathBuf}; + use std::process::Command; + use std::sync::{Mutex, OnceLock}; + use std::time::{SystemTime, UNIX_EPOCH}; + #[cfg(unix)] + use std::os::unix::fs::PermissionsExt; + + fn temp_dir(label: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("commands-plugin-{label}-{nanos}")) + } + + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock") + } + + fn run_command(cwd: &Path, program: &str, args: &[&str]) -> String { + let output = Command::new(program) + .args(args) + .current_dir(cwd) + .output() + .expect("command should run"); + assert!( + output.status.success(), + "{} {} failed: {}", + program, + args.join(" "), + String::from_utf8_lossy(&output.stderr) + ); + String::from_utf8(output.stdout).expect("stdout should be utf8") + } + + fn init_git_repo(label: &str) -> PathBuf { + let root = temp_dir(label); + fs::create_dir_all(&root).expect("repo root"); + + let init = Command::new("git") + .args(["init", "-b", "main"]) + .current_dir(&root) + .output() + .expect("git init should run"); + if !init.status.success() { + let fallback = Command::new("git") + .arg("init") + .current_dir(&root) + .output() + .expect("fallback git init should run"); + assert!( + fallback.status.success(), + "fallback git init should succeed" + ); + let rename = Command::new("git") + .args(["branch", "-m", "main"]) + .current_dir(&root) + .output() + .expect("git branch -m should run"); + assert!(rename.status.success(), "git branch -m main should succeed"); + } + + run_command(&root, "git", &["config", "user.name", "Claw Tests"]); + run_command(&root, "git", &["config", "user.email", "claw@example.com"]); + fs::write(root.join("README.md"), "seed\n").expect("seed file"); + run_command(&root, "git", &["add", "README.md"]); + run_command(&root, "git", &["commit", "-m", "chore: seed repo"]); + root + } + + fn init_bare_repo(label: &str) -> PathBuf { + let root = temp_dir(label); + let output = Command::new("git") + .args(["init", "--bare"]) + .arg(&root) + .output() + .expect("bare repo should initialize"); + assert!(output.status.success(), "git init --bare should succeed"); + root + } + + #[cfg(unix)] + fn write_fake_gh(bin_dir: &Path, log_path: &Path, url: &str) { + fs::create_dir_all(bin_dir).expect("bin dir"); + let script = format!( + "#!/bin/sh\nif [ \"$1\" = \"--version\" ]; then\n echo 'gh 1.0.0'\n exit 0\nfi\nprintf '%s\\n' \"$*\" >> \"{}\"\nif [ \"$1\" = \"pr\" ] && [ \"$2\" = \"create\" ]; then\n echo '{}'\n exit 0\nfi\nif [ \"$1\" = \"pr\" ] && [ \"$2\" = \"view\" ]; then\n echo '{{\"url\":\"{}\"}}'\n exit 0\nfi\nexit 0\n", + log_path.display(), + url, + url, + ); + let path = bin_dir.join("gh"); + fs::write(&path, script).expect("gh stub"); + let mut permissions = fs::metadata(&path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&path, permissions).expect("chmod"); + } + + fn write_external_plugin(root: &Path, name: &str, version: &str) { + fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fs::write( + root.join(".claw-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"commands plugin\"\n}}" + ), + ) + .expect("write manifest"); + } + + fn write_bundled_plugin(root: &Path, name: &str, version: &str, default_enabled: bool) { + fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fs::write( + root.join(".claw-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"bundled commands plugin\",\n \"defaultEnabled\": {}\n}}", + if default_enabled { "true" } else { "false" } + ), + ) + .expect("write bundled manifest"); + } + + fn write_agent(root: &Path, name: &str, description: &str, model: &str, reasoning: &str) { + fs::create_dir_all(root).expect("agent root"); + fs::write( + root.join(format!("{name}.toml")), + format!( + "name = \"{name}\"\ndescription = \"{description}\"\nmodel = \"{model}\"\nmodel_reasoning_effort = \"{reasoning}\"\n" + ), + ) + .expect("write agent"); + } + + fn write_skill(root: &Path, name: &str, description: &str) { + let skill_root = root.join(name); + fs::create_dir_all(&skill_root).expect("skill root"); + fs::write( + skill_root.join("SKILL.md"), + format!("---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"), + ) + .expect("write skill"); + } + + fn write_legacy_command(root: &Path, name: &str, description: &str) { + fs::create_dir_all(root).expect("commands root"); + fs::write( + root.join(format!("{name}.md")), + format!("---\nname: {name}\ndescription: {description}\n---\n\n# {name}\n"), + ) + .expect("write command"); + } + + #[allow(clippy::too_many_lines)] #[test] fn parses_supported_slash_commands() { assert_eq!(SlashCommand::parse("/help"), Some(SlashCommand::Help)); assert_eq!(SlashCommand::parse(" /status "), Some(SlashCommand::Status)); assert_eq!( - SlashCommand::parse("/model claude-opus"), + SlashCommand::parse("/bughunter runtime"), + Some(SlashCommand::Bughunter { + scope: Some("runtime".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/branch create feature/demo"), + Some(SlashCommand::Branch { + action: Some("create".to_string()), + target: Some("feature/demo".to_string()), + }) + ); + assert_eq!( + SlashCommand::parse("/worktree add ../demo wt-demo"), + Some(SlashCommand::Worktree { + action: Some("add".to_string()), + path: Some("../demo".to_string()), + branch: Some("wt-demo".to_string()), + }) + ); + assert_eq!(SlashCommand::parse("/commit"), Some(SlashCommand::Commit)); + assert_eq!( + SlashCommand::parse("/commit-push-pr ready for review"), + Some(SlashCommand::CommitPushPr { + context: Some("ready for review".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/pr ready for review"), + Some(SlashCommand::Pr { + context: Some("ready for review".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/issue flaky test"), + Some(SlashCommand::Issue { + context: Some("flaky test".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/ultraplan ship both features"), + Some(SlashCommand::Ultraplan { + task: Some("ship both features".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/teleport conversation.rs"), + Some(SlashCommand::Teleport { + target: Some("conversation.rs".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/debug-tool-call"), + Some(SlashCommand::DebugToolCall) + ); + assert_eq!( + SlashCommand::parse("/model opus"), Some(SlashCommand::Model { - model: Some("claude-opus".to_string()), + model: Some("opus".to_string()), }) ); assert_eq!( @@ -365,29 +2075,85 @@ mod tests { target: Some("abc123".to_string()) }) ); + assert_eq!( + SlashCommand::parse("/plugins install demo"), + Some(SlashCommand::Plugins { + action: Some("install".to_string()), + target: Some("demo".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/plugins list"), + Some(SlashCommand::Plugins { + action: Some("list".to_string()), + target: None + }) + ); + assert_eq!( + SlashCommand::parse("/plugins enable demo"), + Some(SlashCommand::Plugins { + action: Some("enable".to_string()), + target: Some("demo".to_string()) + }) + ); + assert_eq!( + SlashCommand::parse("/plugins disable demo"), + Some(SlashCommand::Plugins { + action: Some("disable".to_string()), + target: Some("demo".to_string()) + }) + ); } #[test] fn renders_help_from_shared_specs() { let help = render_slash_command_help(); - assert!(help.contains("works with --resume SESSION.json")); + assert!(help.contains("available via claw --resume SESSION.json")); + assert!(help.contains("Core flow")); + assert!(help.contains("Workspace & memory")); + assert!(help.contains("Sessions & output")); + assert!(help.contains("Git & GitHub")); + assert!(help.contains("Automation & discovery")); assert!(help.contains("/help")); assert!(help.contains("/status")); assert!(help.contains("/compact")); + assert!(help.contains("/bughunter [scope]")); + assert!(help.contains("/branch [list|create <name>|switch <name>]")); + assert!(help.contains("/worktree [list|add <path> [branch]|remove <path>|prune]")); + assert!(help.contains("/commit")); + assert!(help.contains("/commit-push-pr [context]")); + assert!(help.contains("/pr [context]")); + assert!(help.contains("/issue [context]")); + assert!(help.contains("/ultraplan [task]")); + assert!(help.contains("/teleport <symbol-or-path>")); + assert!(help.contains("/debug-tool-call")); assert!(help.contains("/model [model]")); assert!(help.contains("/permissions [read-only|workspace-write|danger-full-access]")); assert!(help.contains("/clear [--confirm]")); assert!(help.contains("/cost")); assert!(help.contains("/resume <session-path>")); - assert!(help.contains("/config [env|hooks|model]")); + assert!(help.contains("/config [env|hooks|model|plugins]")); assert!(help.contains("/memory")); assert!(help.contains("/init")); assert!(help.contains("/diff")); assert!(help.contains("/version")); assert!(help.contains("/export [file]")); assert!(help.contains("/session [list|switch <session-id>]")); - assert_eq!(slash_command_specs().len(), 15); - assert_eq!(resume_supported_slash_commands().len(), 11); + assert!(help.contains( + "/plugin [list|install <path>|enable <name>|disable <name>|uninstall <id>|update <id>]" + )); + assert!(help.contains("aliases: /plugins, /marketplace")); + assert!(help.contains("/agents")); + assert!(help.contains("/skills")); + assert_eq!(slash_command_specs().len(), 28); + assert_eq!(resume_supported_slash_commands().len(), 13); + } + + #[test] + fn suggests_close_slash_commands() { + let suggestions = suggest_slash_commands("stats", 3); + assert!(!suggestions.is_empty()); + assert_eq!(suggestions[0], "/status"); } #[test] @@ -435,7 +2201,35 @@ mod tests { assert!(handle_slash_command("/unknown", &session, CompactionConfig::default()).is_none()); assert!(handle_slash_command("/status", &session, CompactionConfig::default()).is_none()); assert!( - handle_slash_command("/model claude", &session, CompactionConfig::default()).is_none() + handle_slash_command("/branch list", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/bughunter", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/worktree list", &session, CompactionConfig::default()).is_none() + ); + assert!(handle_slash_command("/commit", &session, CompactionConfig::default()).is_none()); + assert!(handle_slash_command( + "/commit-push-pr review notes", + &session, + CompactionConfig::default() + ) + .is_none()); + assert!(handle_slash_command("/pr", &session, CompactionConfig::default()).is_none()); + assert!(handle_slash_command("/issue", &session, CompactionConfig::default()).is_none()); + assert!( + handle_slash_command("/ultraplan", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/teleport foo", &session, CompactionConfig::default()).is_none() + ); + assert!( + handle_slash_command("/debug-tool-call", &session, CompactionConfig::default()) + .is_none() + ); + assert!( + handle_slash_command("/model sonnet", &session, CompactionConfig::default()).is_none() ); assert!(handle_slash_command( "/permissions read-only", @@ -468,5 +2262,406 @@ mod tests { assert!( handle_slash_command("/session list", &session, CompactionConfig::default()).is_none() ); + assert!( + handle_slash_command("/plugins list", &session, CompactionConfig::default()).is_none() + ); + } + + #[test] + fn renders_plugins_report_with_name_version_and_status() { + let rendered = render_plugins_report(&[ + PluginSummary { + metadata: PluginMetadata { + id: "demo@external".to_string(), + name: "demo".to_string(), + version: "1.2.3".to_string(), + description: "demo plugin".to_string(), + kind: PluginKind::External, + source: "demo".to_string(), + default_enabled: false, + root: None, + }, + enabled: true, + }, + PluginSummary { + metadata: PluginMetadata { + id: "sample@external".to_string(), + name: "sample".to_string(), + version: "0.9.0".to_string(), + description: "sample plugin".to_string(), + kind: PluginKind::External, + source: "sample".to_string(), + default_enabled: false, + root: None, + }, + enabled: false, + }, + ]); + + assert!(rendered.contains("demo")); + assert!(rendered.contains("v1.2.3")); + assert!(rendered.contains("enabled")); + assert!(rendered.contains("sample")); + assert!(rendered.contains("v0.9.0")); + assert!(rendered.contains("disabled")); + } + + #[test] + fn lists_agents_from_project_and_user_roots() { + let workspace = temp_dir("agents-workspace"); + let project_agents = workspace.join(".codex").join("agents"); + let user_home = temp_dir("agents-home"); + let user_agents = user_home.join(".codex").join("agents"); + + write_agent( + &project_agents, + "planner", + "Project planner", + "gpt-5.4", + "medium", + ); + write_agent( + &user_agents, + "planner", + "User planner", + "gpt-5.4-mini", + "high", + ); + write_agent( + &user_agents, + "verifier", + "Verification agent", + "gpt-5.4-mini", + "high", + ); + + let roots = vec![ + (DefinitionSource::ProjectCodex, project_agents), + (DefinitionSource::UserCodex, user_agents), + ]; + let report = + render_agents_report(&load_agents_from_roots(&roots).expect("agent roots should load")); + + assert!(report.contains("Agents")); + assert!(report.contains("2 active agents")); + assert!(report.contains("Project (.codex):")); + assert!(report.contains("planner · Project planner · gpt-5.4 · medium")); + assert!(report.contains("User (~/.codex):")); + assert!(report.contains("(shadowed by Project (.codex)) planner · User planner")); + assert!(report.contains("verifier · Verification agent · gpt-5.4-mini · high")); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(user_home); + } + + #[test] + fn lists_skills_from_project_and_user_roots() { + let workspace = temp_dir("skills-workspace"); + let project_skills = workspace.join(".codex").join("skills"); + let project_commands = workspace.join(".claw").join("commands"); + let user_home = temp_dir("skills-home"); + let user_skills = user_home.join(".codex").join("skills"); + + write_skill(&project_skills, "plan", "Project planning guidance"); + write_legacy_command(&project_commands, "deploy", "Legacy deployment guidance"); + write_skill(&user_skills, "plan", "User planning guidance"); + write_skill(&user_skills, "help", "Help guidance"); + + let roots = vec![ + SkillRoot { + source: DefinitionSource::ProjectCodex, + path: project_skills, + origin: SkillOrigin::SkillsDir, + }, + SkillRoot { + source: DefinitionSource::ProjectClaw, + path: project_commands, + origin: SkillOrigin::LegacyCommandsDir, + }, + SkillRoot { + source: DefinitionSource::UserCodex, + path: user_skills, + origin: SkillOrigin::SkillsDir, + }, + ]; + let report = + render_skills_report(&load_skills_from_roots(&roots).expect("skill roots should load")); + + assert!(report.contains("Skills")); + assert!(report.contains("3 available skills")); + assert!(report.contains("Project (.codex):")); + assert!(report.contains("plan · Project planning guidance")); + assert!(report.contains("Project (.claw):")); + assert!(report.contains("deploy · Legacy deployment guidance · legacy /commands")); + assert!(report.contains("User (~/.codex):")); + assert!(report.contains("(shadowed by Project (.codex)) plan · User planning guidance")); + assert!(report.contains("help · Help guidance")); + + let _ = fs::remove_dir_all(workspace); + let _ = fs::remove_dir_all(user_home); + } + + #[test] + fn agents_and_skills_usage_support_help_and_unexpected_args() { + let cwd = temp_dir("slash-usage"); + + let agents_help = + super::handle_agents_slash_command(Some("help"), &cwd).expect("agents help"); + assert!(agents_help.contains("Usage /agents")); + assert!(agents_help.contains("Direct CLI claw agents")); + + let agents_unexpected = + super::handle_agents_slash_command(Some("show planner"), &cwd).expect("agents usage"); + assert!(agents_unexpected.contains("Unexpected show planner")); + + let skills_help = + super::handle_skills_slash_command(Some("--help"), &cwd).expect("skills help"); + assert!(skills_help.contains("Usage /skills")); + assert!(skills_help.contains("legacy /commands")); + + let skills_unexpected = + super::handle_skills_slash_command(Some("show help"), &cwd).expect("skills usage"); + assert!(skills_unexpected.contains("Unexpected show help")); + + let _ = fs::remove_dir_all(cwd); + } + + #[test] + fn parses_quoted_skill_frontmatter_values() { + let contents = "---\nname: \"hud\"\ndescription: 'Quoted description'\n---\n"; + let (name, description) = super::parse_skill_frontmatter(contents); + assert_eq!(name.as_deref(), Some("hud")); + assert_eq!(description.as_deref(), Some("Quoted description")); + } + + #[test] + fn installs_plugin_from_path_and_lists_it() { + let config_home = temp_dir("home"); + let source_root = temp_dir("source"); + write_external_plugin(&source_root, "demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install = handle_plugins_slash_command( + Some("install"), + Some(source_root.to_str().expect("utf8 path")), + &mut manager, + ) + .expect("install command should succeed"); + assert!(install.reload_runtime); + assert!(install.message.contains("installed demo@external")); + assert!(install.message.contains("Name demo")); + assert!(install.message.contains("Version 1.0.0")); + assert!(install.message.contains("Status enabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(!list.reload_runtime); + assert!(list.message.contains("demo")); + assert!(list.message.contains("v1.0.0")); + assert!(list.message.contains("enabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn enables_and_disables_plugin_by_name() { + let config_home = temp_dir("toggle-home"); + let source_root = temp_dir("toggle-source"); + write_external_plugin(&source_root, "demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + handle_plugins_slash_command( + Some("install"), + Some(source_root.to_str().expect("utf8 path")), + &mut manager, + ) + .expect("install command should succeed"); + + let disable = handle_plugins_slash_command(Some("disable"), Some("demo"), &mut manager) + .expect("disable command should succeed"); + assert!(disable.reload_runtime); + assert!(disable.message.contains("disabled demo@external")); + assert!(disable.message.contains("Name demo")); + assert!(disable.message.contains("Status disabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(list.message.contains("demo")); + assert!(list.message.contains("disabled")); + + let enable = handle_plugins_slash_command(Some("enable"), Some("demo"), &mut manager) + .expect("enable command should succeed"); + assert!(enable.reload_runtime); + assert!(enable.message.contains("enabled demo@external")); + assert!(enable.message.contains("Name demo")); + assert!(enable.message.contains("Status enabled")); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(list.message.contains("demo")); + assert!(list.message.contains("enabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn lists_auto_installed_bundled_plugins_with_status() { + let config_home = temp_dir("bundled-home"); + let bundled_root = temp_dir("bundled-root"); + let bundled_plugin = bundled_root.join("starter"); + write_bundled_plugin(&bundled_plugin, "starter", "0.1.0", false); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let mut manager = PluginManager::new(config); + + let list = handle_plugins_slash_command(Some("list"), None, &mut manager) + .expect("list command should succeed"); + assert!(!list.reload_runtime); + assert!(list.message.contains("starter")); + assert!(list.message.contains("v0.1.0")); + assert!(list.message.contains("disabled")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn branch_and_worktree_commands_manage_git_state() { + // given + let repo = init_git_repo("branch-worktree"); + let worktree_path = repo + .parent() + .expect("repo should have parent") + .join("branch-worktree-linked"); + + // when + let branch_list = + handle_branch_slash_command(Some("list"), None, &repo).expect("branch list succeeds"); + let created = handle_branch_slash_command(Some("create"), Some("feature/demo"), &repo) + .expect("branch create succeeds"); + let switched = handle_branch_slash_command(Some("switch"), Some("main"), &repo) + .expect("branch switch succeeds"); + let added = handle_worktree_slash_command( + Some("add"), + Some(worktree_path.to_str().expect("utf8 path")), + Some("wt-demo"), + &repo, + ) + .expect("worktree add succeeds"); + let listed_worktrees = + handle_worktree_slash_command(Some("list"), None, None, &repo).expect("list succeeds"); + let removed = handle_worktree_slash_command( + Some("remove"), + Some(worktree_path.to_str().expect("utf8 path")), + None, + &repo, + ) + .expect("remove succeeds"); + + // then + assert!(branch_list.contains("main")); + assert!(created.contains("feature/demo")); + assert!(switched.contains("main")); + assert!(added.contains("wt-demo")); + assert!(listed_worktrees.contains(worktree_path.to_str().expect("utf8 path"))); + assert!(removed.contains("Result removed")); + + let _ = fs::remove_dir_all(repo); + let _ = fs::remove_dir_all(worktree_path); + } + + #[test] + fn commit_command_stages_and_commits_changes() { + // given + let repo = init_git_repo("commit-command"); + fs::write(repo.join("notes.txt"), "hello\n").expect("write notes"); + + // when + let report = + handle_commit_slash_command("feat: add notes", &repo).expect("commit succeeds"); + let status = run_command(&repo, "git", &["status", "--short"]); + let message = run_command(&repo, "git", &["log", "-1", "--pretty=%B"]); + + // then + assert!(report.contains("Result created")); + assert!(status.trim().is_empty()); + assert_eq!(message.trim(), "feat: add notes"); + + let _ = fs::remove_dir_all(repo); + } + + #[cfg(unix)] + #[test] + fn commit_push_pr_command_commits_pushes_and_creates_pr() { + // given + let _guard = env_lock(); + let repo = init_git_repo("commit-push-pr"); + let remote = init_bare_repo("commit-push-pr-remote"); + run_command( + &repo, + "git", + &[ + "remote", + "add", + "origin", + remote.to_str().expect("utf8 remote"), + ], + ); + run_command(&repo, "git", &["push", "-u", "origin", "main"]); + fs::write(repo.join("feature.txt"), "feature\n").expect("write feature file"); + + let fake_bin = temp_dir("fake-gh-bin"); + let gh_log = fake_bin.join("gh.log"); + write_fake_gh(&fake_bin, &gh_log, "https://example.com/pr/123"); + + let previous_path = env::var_os("PATH"); + let mut new_path = fake_bin.display().to_string(); + if let Some(path) = &previous_path { + new_path.push(':'); + new_path.push_str(&path.to_string_lossy()); + } + env::set_var("PATH", &new_path); + let previous_safeuser = env::var_os("SAFEUSER"); + env::set_var("SAFEUSER", "tester"); + + let request = CommitPushPrRequest { + commit_message: Some("feat: add feature file".to_string()), + pr_title: "Add feature file".to_string(), + pr_body: "## Summary\n- add feature file".to_string(), + branch_name_hint: "Add feature file".to_string(), + }; + + // when + let report = + handle_commit_push_pr_slash_command(&request, &repo).expect("commit-push-pr succeeds"); + let branch = run_command(&repo, "git", &["branch", "--show-current"]); + let message = run_command(&repo, "git", &["log", "-1", "--pretty=%B"]); + let gh_invocations = fs::read_to_string(&gh_log).expect("gh log should exist"); + + // then + assert!(report.contains("Result created")); + assert!(report.contains("URL https://example.com/pr/123")); + assert_eq!(branch.trim(), "tester/add-feature-file"); + assert_eq!(message.trim(), "feat: add feature file"); + assert!(gh_invocations.contains("pr create")); + assert!(gh_invocations.contains("--base main")); + + if let Some(path) = previous_path { + env::set_var("PATH", path); + } else { + env::remove_var("PATH"); + } + if let Some(safeuser) = previous_safeuser { + env::set_var("SAFEUSER", safeuser); + } else { + env::remove_var("SAFEUSER"); + } + + let _ = fs::remove_dir_all(repo); + let _ = fs::remove_dir_all(remote); + let _ = fs::remove_dir_all(fake_bin); } } diff --git a/rust/crates/lsp/Cargo.toml b/rust/crates/lsp/Cargo.toml new file mode 100644 index 0000000..a2f1aec --- /dev/null +++ b/rust/crates/lsp/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "lsp" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +lsp-types.workspace = true +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true +tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "sync", "time"] } +url = "2" + +[lints] +workspace = true diff --git a/rust/crates/lsp/src/client.rs b/rust/crates/lsp/src/client.rs new file mode 100644 index 0000000..7ec663b --- /dev/null +++ b/rust/crates/lsp/src/client.rs @@ -0,0 +1,463 @@ +use std::collections::BTreeMap; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::sync::Arc; +use std::sync::atomic::{AtomicI64, Ordering}; + +use lsp_types::{ + Diagnostic, GotoDefinitionResponse, Location, LocationLink, Position, PublishDiagnosticsParams, +}; +use serde_json::{json, Value}; +use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::{oneshot, Mutex}; + +use crate::error::LspError; +use crate::types::{LspServerConfig, SymbolLocation}; + +pub(crate) struct LspClient { + config: LspServerConfig, + writer: Mutex<BufWriter<ChildStdin>>, + child: Mutex<Child>, + pending_requests: Arc<Mutex<BTreeMap<i64, oneshot::Sender<Result<Value, LspError>>>>>, + diagnostics: Arc<Mutex<BTreeMap<String, Vec<Diagnostic>>>>, + open_documents: Mutex<BTreeMap<PathBuf, i32>>, + next_request_id: AtomicI64, +} + +impl LspClient { + pub(crate) async fn connect(config: LspServerConfig) -> Result<Self, LspError> { + let mut command = Command::new(&config.command); + command + .args(&config.args) + .current_dir(&config.workspace_root) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .envs(config.env.clone()); + + let mut child = command.spawn()?; + let stdin = child + .stdin + .take() + .ok_or_else(|| LspError::Protocol("missing LSP stdin pipe".to_string()))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| LspError::Protocol("missing LSP stdout pipe".to_string()))?; + let stderr = child.stderr.take(); + + let client = Self { + config, + writer: Mutex::new(BufWriter::new(stdin)), + child: Mutex::new(child), + pending_requests: Arc::new(Mutex::new(BTreeMap::new())), + diagnostics: Arc::new(Mutex::new(BTreeMap::new())), + open_documents: Mutex::new(BTreeMap::new()), + next_request_id: AtomicI64::new(1), + }; + + client.spawn_reader(stdout); + if let Some(stderr) = stderr { + client.spawn_stderr_drain(stderr); + } + client.initialize().await?; + Ok(client) + } + + pub(crate) async fn ensure_document_open(&self, path: &Path) -> Result<(), LspError> { + if self.is_document_open(path).await { + return Ok(()); + } + + let contents = std::fs::read_to_string(path)?; + self.open_document(path, &contents).await + } + + pub(crate) async fn open_document(&self, path: &Path, text: &str) -> Result<(), LspError> { + let uri = file_url(path)?; + let language_id = self + .config + .language_id_for(path) + .ok_or_else(|| LspError::UnsupportedDocument(path.to_path_buf()))?; + + self.notify( + "textDocument/didOpen", + json!({ + "textDocument": { + "uri": uri, + "languageId": language_id, + "version": 1, + "text": text, + } + }), + ) + .await?; + + self.open_documents + .lock() + .await + .insert(path.to_path_buf(), 1); + Ok(()) + } + + pub(crate) async fn change_document(&self, path: &Path, text: &str) -> Result<(), LspError> { + if !self.is_document_open(path).await { + return self.open_document(path, text).await; + } + + let uri = file_url(path)?; + let next_version = { + let mut open_documents = self.open_documents.lock().await; + let version = open_documents + .entry(path.to_path_buf()) + .and_modify(|value| *value += 1) + .or_insert(1); + *version + }; + + self.notify( + "textDocument/didChange", + json!({ + "textDocument": { + "uri": uri, + "version": next_version, + }, + "contentChanges": [{ + "text": text, + }], + }), + ) + .await + } + + pub(crate) async fn save_document(&self, path: &Path) -> Result<(), LspError> { + if !self.is_document_open(path).await { + return Ok(()); + } + + self.notify( + "textDocument/didSave", + json!({ + "textDocument": { + "uri": file_url(path)?, + } + }), + ) + .await + } + + pub(crate) async fn close_document(&self, path: &Path) -> Result<(), LspError> { + if !self.is_document_open(path).await { + return Ok(()); + } + + self.notify( + "textDocument/didClose", + json!({ + "textDocument": { + "uri": file_url(path)?, + } + }), + ) + .await?; + + self.open_documents.lock().await.remove(path); + Ok(()) + } + + pub(crate) async fn is_document_open(&self, path: &Path) -> bool { + self.open_documents.lock().await.contains_key(path) + } + + pub(crate) async fn go_to_definition( + &self, + path: &Path, + position: Position, + ) -> Result<Vec<SymbolLocation>, LspError> { + self.ensure_document_open(path).await?; + let response = self + .request::<Option<GotoDefinitionResponse>>( + "textDocument/definition", + json!({ + "textDocument": { "uri": file_url(path)? }, + "position": position, + }), + ) + .await?; + + Ok(match response { + Some(GotoDefinitionResponse::Scalar(location)) => { + location_to_symbol_locations(vec![location]) + } + Some(GotoDefinitionResponse::Array(locations)) => location_to_symbol_locations(locations), + Some(GotoDefinitionResponse::Link(links)) => location_links_to_symbol_locations(links), + None => Vec::new(), + }) + } + + pub(crate) async fn find_references( + &self, + path: &Path, + position: Position, + include_declaration: bool, + ) -> Result<Vec<SymbolLocation>, LspError> { + self.ensure_document_open(path).await?; + let response = self + .request::<Option<Vec<Location>>>( + "textDocument/references", + json!({ + "textDocument": { "uri": file_url(path)? }, + "position": position, + "context": { + "includeDeclaration": include_declaration, + }, + }), + ) + .await?; + + Ok(location_to_symbol_locations(response.unwrap_or_default())) + } + + pub(crate) async fn diagnostics_snapshot(&self) -> BTreeMap<String, Vec<Diagnostic>> { + self.diagnostics.lock().await.clone() + } + + pub(crate) async fn shutdown(&self) -> Result<(), LspError> { + let _ = self.request::<Value>("shutdown", json!({})).await; + let _ = self.notify("exit", Value::Null).await; + + let mut child = self.child.lock().await; + if child.kill().await.is_err() { + let _ = child.wait().await; + return Ok(()); + } + let _ = child.wait().await; + Ok(()) + } + + fn spawn_reader(&self, stdout: ChildStdout) { + let diagnostics = &self.diagnostics; + let pending_requests = &self.pending_requests; + + let diagnostics = diagnostics.clone(); + let pending_requests = pending_requests.clone(); + tokio::spawn(async move { + let mut reader = BufReader::new(stdout); + let result = async { + while let Some(message) = read_message(&mut reader).await? { + if let Some(id) = message.get("id").and_then(Value::as_i64) { + let response = if let Some(error) = message.get("error") { + Err(LspError::Protocol(error.to_string())) + } else { + Ok(message.get("result").cloned().unwrap_or(Value::Null)) + }; + + if let Some(sender) = pending_requests.lock().await.remove(&id) { + let _ = sender.send(response); + } + continue; + } + + let Some(method) = message.get("method").and_then(Value::as_str) else { + continue; + }; + if method != "textDocument/publishDiagnostics" { + continue; + } + + let params = message.get("params").cloned().unwrap_or(Value::Null); + let notification = serde_json::from_value::<PublishDiagnosticsParams>(params)?; + let mut diagnostics_map = diagnostics.lock().await; + if notification.diagnostics.is_empty() { + diagnostics_map.remove(¬ification.uri.to_string()); + } else { + diagnostics_map.insert(notification.uri.to_string(), notification.diagnostics); + } + } + Ok::<(), LspError>(()) + } + .await; + + if let Err(error) = result { + let mut pending = pending_requests.lock().await; + let drained = pending + .iter() + .map(|(id, _)| *id) + .collect::<Vec<_>>(); + for id in drained { + if let Some(sender) = pending.remove(&id) { + let _ = sender.send(Err(LspError::Protocol(error.to_string()))); + } + } + } + }); + } + + fn spawn_stderr_drain<R>(&self, stderr: R) + where + R: AsyncRead + Unpin + Send + 'static, + { + tokio::spawn(async move { + let mut reader = BufReader::new(stderr); + let mut sink = Vec::new(); + let _ = reader.read_to_end(&mut sink).await; + }); + } + + async fn initialize(&self) -> Result<(), LspError> { + let workspace_uri = file_url(&self.config.workspace_root)?; + let _ = self + .request::<Value>( + "initialize", + json!({ + "processId": std::process::id(), + "rootUri": workspace_uri, + "rootPath": self.config.workspace_root, + "workspaceFolders": [{ + "uri": workspace_uri, + "name": self.config.name, + }], + "initializationOptions": self.config.initialization_options.clone().unwrap_or(Value::Null), + "capabilities": { + "textDocument": { + "publishDiagnostics": { + "relatedInformation": true, + }, + "definition": { + "linkSupport": true, + }, + "references": {} + }, + "workspace": { + "configuration": false, + "workspaceFolders": true, + }, + "general": { + "positionEncodings": ["utf-16"], + } + } + }), + ) + .await?; + self.notify("initialized", json!({})).await + } + + async fn request<T>(&self, method: &str, params: Value) -> Result<T, LspError> + where + T: for<'de> serde::Deserialize<'de>, + { + let id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + let (sender, receiver) = oneshot::channel(); + self.pending_requests.lock().await.insert(id, sender); + + if let Err(error) = self + .send_message(&json!({ + "jsonrpc": "2.0", + "id": id, + "method": method, + "params": params, + })) + .await + { + self.pending_requests.lock().await.remove(&id); + return Err(error); + } + + let response = receiver + .await + .map_err(|_| LspError::Protocol(format!("request channel closed for {method}")))??; + Ok(serde_json::from_value(response)?) + } + + async fn notify(&self, method: &str, params: Value) -> Result<(), LspError> { + self.send_message(&json!({ + "jsonrpc": "2.0", + "method": method, + "params": params, + })) + .await + } + + async fn send_message(&self, payload: &Value) -> Result<(), LspError> { + let body = serde_json::to_vec(payload)?; + let mut writer = self.writer.lock().await; + writer + .write_all(format!("Content-Length: {}\r\n\r\n", body.len()).as_bytes()) + .await?; + writer.write_all(&body).await?; + writer.flush().await?; + Ok(()) + } +} + +async fn read_message<R>(reader: &mut BufReader<R>) -> Result<Option<Value>, LspError> +where + R: AsyncRead + Unpin, +{ + let mut content_length = None; + + loop { + let mut line = String::new(); + let read = reader.read_line(&mut line).await?; + if read == 0 { + return Ok(None); + } + + if line == "\r\n" { + break; + } + + let trimmed = line.trim_end_matches(['\r', '\n']); + if let Some((name, value)) = trimmed.split_once(':') { + if name.eq_ignore_ascii_case("Content-Length") { + let value = value.trim().to_string(); + content_length = Some( + value + .parse::<usize>() + .map_err(|_| LspError::InvalidContentLength(value.clone()))?, + ); + } + } else { + return Err(LspError::InvalidHeader(trimmed.to_string())); + } + } + + let content_length = content_length.ok_or(LspError::MissingContentLength)?; + let mut body = vec![0_u8; content_length]; + reader.read_exact(&mut body).await?; + Ok(Some(serde_json::from_slice(&body)?)) +} + +fn file_url(path: &Path) -> Result<String, LspError> { + url::Url::from_file_path(path) + .map(|url| url.to_string()) + .map_err(|()| LspError::PathToUrl(path.to_path_buf())) +} + +fn location_to_symbol_locations(locations: Vec<Location>) -> Vec<SymbolLocation> { + locations + .into_iter() + .filter_map(|location| { + uri_to_path(&location.uri.to_string()).map(|path| SymbolLocation { + path, + range: location.range, + }) + }) + .collect() +} + +fn location_links_to_symbol_locations(links: Vec<LocationLink>) -> Vec<SymbolLocation> { + links.into_iter() + .filter_map(|link| { + uri_to_path(&link.target_uri.to_string()).map(|path| SymbolLocation { + path, + range: link.target_selection_range, + }) + }) + .collect() +} + +fn uri_to_path(uri: &str) -> Option<PathBuf> { + url::Url::parse(uri).ok()?.to_file_path().ok() +} diff --git a/rust/crates/lsp/src/error.rs b/rust/crates/lsp/src/error.rs new file mode 100644 index 0000000..6be1413 --- /dev/null +++ b/rust/crates/lsp/src/error.rs @@ -0,0 +1,62 @@ +use std::fmt::{Display, Formatter}; +use std::path::PathBuf; + +#[derive(Debug)] +pub enum LspError { + Io(std::io::Error), + Json(serde_json::Error), + InvalidHeader(String), + MissingContentLength, + InvalidContentLength(String), + UnsupportedDocument(PathBuf), + UnknownServer(String), + DuplicateExtension { + extension: String, + existing_server: String, + new_server: String, + }, + PathToUrl(PathBuf), + Protocol(String), +} + +impl Display for LspError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Io(error) => write!(f, "{error}"), + Self::Json(error) => write!(f, "{error}"), + Self::InvalidHeader(header) => write!(f, "invalid LSP header: {header}"), + Self::MissingContentLength => write!(f, "missing LSP Content-Length header"), + Self::InvalidContentLength(value) => { + write!(f, "invalid LSP Content-Length value: {value}") + } + Self::UnsupportedDocument(path) => { + write!(f, "no LSP server configured for {}", path.display()) + } + Self::UnknownServer(name) => write!(f, "unknown LSP server: {name}"), + Self::DuplicateExtension { + extension, + existing_server, + new_server, + } => write!( + f, + "duplicate LSP extension mapping for {extension}: {existing_server} and {new_server}" + ), + Self::PathToUrl(path) => write!(f, "failed to convert path to file URL: {}", path.display()), + Self::Protocol(message) => write!(f, "LSP protocol error: {message}"), + } + } +} + +impl std::error::Error for LspError {} + +impl From<std::io::Error> for LspError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} + +impl From<serde_json::Error> for LspError { + fn from(value: serde_json::Error) -> Self { + Self::Json(value) + } +} diff --git a/rust/crates/lsp/src/lib.rs b/rust/crates/lsp/src/lib.rs new file mode 100644 index 0000000..9b1b099 --- /dev/null +++ b/rust/crates/lsp/src/lib.rs @@ -0,0 +1,283 @@ +mod client; +mod error; +mod manager; +mod types; + +pub use error::LspError; +pub use manager::LspManager; +pub use types::{ + FileDiagnostics, LspContextEnrichment, LspServerConfig, SymbolLocation, WorkspaceDiagnostics, +}; + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + use std::fs; + use std::path::PathBuf; + use std::process::Command; + use std::time::{Duration, SystemTime, UNIX_EPOCH}; + + use lsp_types::{DiagnosticSeverity, Position}; + + use crate::{LspManager, LspServerConfig}; + + fn temp_dir(label: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("lsp-{label}-{nanos}")) + } + + fn python3_path() -> Option<String> { + let candidates = ["python3", "/usr/bin/python3"]; + candidates.iter().find_map(|candidate| { + Command::new(candidate) + .arg("--version") + .output() + .ok() + .filter(|output| output.status.success()) + .map(|_| (*candidate).to_string()) + }) + } + + fn write_mock_server_script(root: &std::path::Path) -> PathBuf { + let script_path = root.join("mock_lsp_server.py"); + fs::write( + &script_path, + r#"import json +import sys + + +def read_message(): + headers = {} + while True: + line = sys.stdin.buffer.readline() + if not line: + return None + if line == b"\r\n": + break + key, value = line.decode("utf-8").split(":", 1) + headers[key.lower()] = value.strip() + length = int(headers["content-length"]) + body = sys.stdin.buffer.read(length) + return json.loads(body) + + +def write_message(payload): + raw = json.dumps(payload).encode("utf-8") + sys.stdout.buffer.write(f"Content-Length: {len(raw)}\r\n\r\n".encode("utf-8")) + sys.stdout.buffer.write(raw) + sys.stdout.buffer.flush() + + +while True: + message = read_message() + if message is None: + break + + method = message.get("method") + if method == "initialize": + write_message({ + "jsonrpc": "2.0", + "id": message["id"], + "result": { + "capabilities": { + "definitionProvider": True, + "referencesProvider": True, + "textDocumentSync": 1, + } + }, + }) + elif method == "initialized": + continue + elif method == "textDocument/didOpen": + document = message["params"]["textDocument"] + write_message({ + "jsonrpc": "2.0", + "method": "textDocument/publishDiagnostics", + "params": { + "uri": document["uri"], + "diagnostics": [ + { + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 0, "character": 3}, + }, + "severity": 1, + "source": "mock-server", + "message": "mock error", + } + ], + }, + }) + elif method == "textDocument/didChange": + continue + elif method == "textDocument/didSave": + continue + elif method == "textDocument/definition": + uri = message["params"]["textDocument"]["uri"] + write_message({ + "jsonrpc": "2.0", + "id": message["id"], + "result": [ + { + "uri": uri, + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 0, "character": 3}, + }, + } + ], + }) + elif method == "textDocument/references": + uri = message["params"]["textDocument"]["uri"] + write_message({ + "jsonrpc": "2.0", + "id": message["id"], + "result": [ + { + "uri": uri, + "range": { + "start": {"line": 0, "character": 0}, + "end": {"line": 0, "character": 3}, + }, + }, + { + "uri": uri, + "range": { + "start": {"line": 1, "character": 4}, + "end": {"line": 1, "character": 7}, + }, + }, + ], + }) + elif method == "shutdown": + write_message({"jsonrpc": "2.0", "id": message["id"], "result": None}) + elif method == "exit": + break +"#, + ) + .expect("mock server should be written"); + script_path + } + + async fn wait_for_diagnostics(manager: &LspManager) { + tokio::time::timeout(Duration::from_secs(2), async { + loop { + if manager + .collect_workspace_diagnostics() + .await + .expect("diagnostics snapshot should load") + .total_diagnostics() + > 0 + { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + }) + .await + .expect("diagnostics should arrive from mock server"); + } + + #[tokio::test(flavor = "current_thread")] + async fn collects_diagnostics_and_symbol_navigation_from_mock_server() { + let Some(python) = python3_path() else { + return; + }; + + // given + let root = temp_dir("manager"); + fs::create_dir_all(root.join("src")).expect("workspace root should exist"); + let script_path = write_mock_server_script(&root); + let source_path = root.join("src").join("main.rs"); + fs::write(&source_path, "fn main() {}\nlet value = 1;\n").expect("source file should exist"); + let manager = LspManager::new(vec![LspServerConfig { + name: "rust-analyzer".to_string(), + command: python, + args: vec![script_path.display().to_string()], + env: BTreeMap::new(), + workspace_root: root.clone(), + initialization_options: None, + extension_to_language: BTreeMap::from([(".rs".to_string(), "rust".to_string())]), + }]) + .expect("manager should build"); + manager + .open_document(&source_path, &fs::read_to_string(&source_path).expect("source read should succeed")) + .await + .expect("document should open"); + wait_for_diagnostics(&manager).await; + + // when + let diagnostics = manager + .collect_workspace_diagnostics() + .await + .expect("diagnostics should be available"); + let definitions = manager + .go_to_definition(&source_path, Position::new(0, 0)) + .await + .expect("definition request should succeed"); + let references = manager + .find_references(&source_path, Position::new(0, 0), true) + .await + .expect("references request should succeed"); + + // then + assert_eq!(diagnostics.files.len(), 1); + assert_eq!(diagnostics.total_diagnostics(), 1); + assert_eq!(diagnostics.files[0].diagnostics[0].severity, Some(DiagnosticSeverity::ERROR)); + assert_eq!(definitions.len(), 1); + assert_eq!(definitions[0].start_line(), 1); + assert_eq!(references.len(), 2); + + manager.shutdown().await.expect("shutdown should succeed"); + fs::remove_dir_all(root).expect("temp workspace should be removed"); + } + + #[tokio::test(flavor = "current_thread")] + async fn renders_runtime_context_enrichment_for_prompt_usage() { + let Some(python) = python3_path() else { + return; + }; + + // given + let root = temp_dir("prompt"); + fs::create_dir_all(root.join("src")).expect("workspace root should exist"); + let script_path = write_mock_server_script(&root); + let source_path = root.join("src").join("lib.rs"); + fs::write(&source_path, "pub fn answer() -> i32 { 42 }\n").expect("source file should exist"); + let manager = LspManager::new(vec![LspServerConfig { + name: "rust-analyzer".to_string(), + command: python, + args: vec![script_path.display().to_string()], + env: BTreeMap::new(), + workspace_root: root.clone(), + initialization_options: None, + extension_to_language: BTreeMap::from([(".rs".to_string(), "rust".to_string())]), + }]) + .expect("manager should build"); + manager + .open_document(&source_path, &fs::read_to_string(&source_path).expect("source read should succeed")) + .await + .expect("document should open"); + wait_for_diagnostics(&manager).await; + + // when + let enrichment = manager + .context_enrichment(&source_path, Position::new(0, 0)) + .await + .expect("context enrichment should succeed"); + let rendered = enrichment.render_prompt_section(); + + // then + assert!(rendered.contains("# LSP context")); + assert!(rendered.contains("Workspace diagnostics: 1 across 1 file(s)")); + assert!(rendered.contains("Definitions:")); + assert!(rendered.contains("References:")); + assert!(rendered.contains("mock error")); + + manager.shutdown().await.expect("shutdown should succeed"); + fs::remove_dir_all(root).expect("temp workspace should be removed"); + } +} diff --git a/rust/crates/lsp/src/manager.rs b/rust/crates/lsp/src/manager.rs new file mode 100644 index 0000000..3c99f96 --- /dev/null +++ b/rust/crates/lsp/src/manager.rs @@ -0,0 +1,191 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::path::Path; +use std::sync::Arc; + +use lsp_types::Position; +use tokio::sync::Mutex; + +use crate::client::LspClient; +use crate::error::LspError; +use crate::types::{ + normalize_extension, FileDiagnostics, LspContextEnrichment, LspServerConfig, SymbolLocation, + WorkspaceDiagnostics, +}; + +pub struct LspManager { + server_configs: BTreeMap<String, LspServerConfig>, + extension_map: BTreeMap<String, String>, + clients: Mutex<BTreeMap<String, Arc<LspClient>>>, +} + +impl LspManager { + pub fn new(server_configs: Vec<LspServerConfig>) -> Result<Self, LspError> { + let mut configs_by_name = BTreeMap::new(); + let mut extension_map = BTreeMap::new(); + + for config in server_configs { + for extension in config.extension_to_language.keys() { + let normalized = normalize_extension(extension); + if let Some(existing_server) = extension_map.insert(normalized.clone(), config.name.clone()) { + return Err(LspError::DuplicateExtension { + extension: normalized, + existing_server, + new_server: config.name.clone(), + }); + } + } + configs_by_name.insert(config.name.clone(), config); + } + + Ok(Self { + server_configs: configs_by_name, + extension_map, + clients: Mutex::new(BTreeMap::new()), + }) + } + + #[must_use] + pub fn supports_path(&self, path: &Path) -> bool { + path.extension().is_some_and(|extension| { + let normalized = normalize_extension(extension.to_string_lossy().as_ref()); + self.extension_map.contains_key(&normalized) + }) + } + + pub async fn open_document(&self, path: &Path, text: &str) -> Result<(), LspError> { + self.client_for_path(path).await?.open_document(path, text).await + } + + pub async fn sync_document_from_disk(&self, path: &Path) -> Result<(), LspError> { + let contents = std::fs::read_to_string(path)?; + self.change_document(path, &contents).await?; + self.save_document(path).await + } + + pub async fn change_document(&self, path: &Path, text: &str) -> Result<(), LspError> { + self.client_for_path(path).await?.change_document(path, text).await + } + + pub async fn save_document(&self, path: &Path) -> Result<(), LspError> { + self.client_for_path(path).await?.save_document(path).await + } + + pub async fn close_document(&self, path: &Path) -> Result<(), LspError> { + self.client_for_path(path).await?.close_document(path).await + } + + pub async fn go_to_definition( + &self, + path: &Path, + position: Position, + ) -> Result<Vec<SymbolLocation>, LspError> { + let mut locations = self.client_for_path(path).await?.go_to_definition(path, position).await?; + dedupe_locations(&mut locations); + Ok(locations) + } + + pub async fn find_references( + &self, + path: &Path, + position: Position, + include_declaration: bool, + ) -> Result<Vec<SymbolLocation>, LspError> { + let mut locations = self + .client_for_path(path) + .await? + .find_references(path, position, include_declaration) + .await?; + dedupe_locations(&mut locations); + Ok(locations) + } + + pub async fn collect_workspace_diagnostics(&self) -> Result<WorkspaceDiagnostics, LspError> { + let clients = self.clients.lock().await.values().cloned().collect::<Vec<_>>(); + let mut files = Vec::new(); + + for client in clients { + for (uri, diagnostics) in client.diagnostics_snapshot().await { + let Ok(path) = url::Url::parse(&uri) + .and_then(|url| url.to_file_path().map_err(|()| url::ParseError::RelativeUrlWithoutBase)) + else { + continue; + }; + if diagnostics.is_empty() { + continue; + } + files.push(FileDiagnostics { + path, + uri, + diagnostics, + }); + } + } + + files.sort_by(|left, right| left.path.cmp(&right.path)); + Ok(WorkspaceDiagnostics { files }) + } + + pub async fn context_enrichment( + &self, + path: &Path, + position: Position, + ) -> Result<LspContextEnrichment, LspError> { + Ok(LspContextEnrichment { + file_path: path.to_path_buf(), + diagnostics: self.collect_workspace_diagnostics().await?, + definitions: self.go_to_definition(path, position).await?, + references: self.find_references(path, position, true).await?, + }) + } + + pub async fn shutdown(&self) -> Result<(), LspError> { + let mut clients = self.clients.lock().await; + let drained = clients.values().cloned().collect::<Vec<_>>(); + clients.clear(); + drop(clients); + + for client in drained { + client.shutdown().await?; + } + Ok(()) + } + + async fn client_for_path(&self, path: &Path) -> Result<Arc<LspClient>, LspError> { + let extension = path + .extension() + .map(|extension| normalize_extension(extension.to_string_lossy().as_ref())) + .ok_or_else(|| LspError::UnsupportedDocument(path.to_path_buf()))?; + let server_name = self + .extension_map + .get(&extension) + .cloned() + .ok_or_else(|| LspError::UnsupportedDocument(path.to_path_buf()))?; + + let mut clients = self.clients.lock().await; + if let Some(client) = clients.get(&server_name) { + return Ok(client.clone()); + } + + let config = self + .server_configs + .get(&server_name) + .cloned() + .ok_or_else(|| LspError::UnknownServer(server_name.clone()))?; + let client = Arc::new(LspClient::connect(config).await?); + clients.insert(server_name, client.clone()); + Ok(client) + } +} + +fn dedupe_locations(locations: &mut Vec<SymbolLocation>) { + let mut seen = BTreeSet::new(); + locations.retain(|location| { + seen.insert(( + location.path.clone(), + location.range.start.line, + location.range.start.character, + location.range.end.line, + location.range.end.character, + )) + }); +} diff --git a/rust/crates/lsp/src/types.rs b/rust/crates/lsp/src/types.rs new file mode 100644 index 0000000..ab2573f --- /dev/null +++ b/rust/crates/lsp/src/types.rs @@ -0,0 +1,186 @@ +use std::collections::BTreeMap; +use std::fmt::{Display, Formatter}; +use std::path::{Path, PathBuf}; + +use lsp_types::{Diagnostic, Range}; +use serde_json::Value; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct LspServerConfig { + pub name: String, + pub command: String, + pub args: Vec<String>, + pub env: BTreeMap<String, String>, + pub workspace_root: PathBuf, + pub initialization_options: Option<Value>, + pub extension_to_language: BTreeMap<String, String>, +} + +impl LspServerConfig { + #[must_use] + pub fn language_id_for(&self, path: &Path) -> Option<&str> { + let extension = normalize_extension(path.extension()?.to_string_lossy().as_ref()); + self.extension_to_language + .get(&extension) + .map(String::as_str) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct FileDiagnostics { + pub path: PathBuf, + pub uri: String, + pub diagnostics: Vec<Diagnostic>, +} + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct WorkspaceDiagnostics { + pub files: Vec<FileDiagnostics>, +} + +impl WorkspaceDiagnostics { + #[must_use] + pub fn is_empty(&self) -> bool { + self.files.is_empty() + } + + #[must_use] + pub fn total_diagnostics(&self) -> usize { + self.files.iter().map(|file| file.diagnostics.len()).sum() + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SymbolLocation { + pub path: PathBuf, + pub range: Range, +} + +impl SymbolLocation { + #[must_use] + pub fn start_line(&self) -> u32 { + self.range.start.line + 1 + } + + #[must_use] + pub fn start_character(&self) -> u32 { + self.range.start.character + 1 + } +} + +impl Display for SymbolLocation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}:{}:{}", + self.path.display(), + self.start_line(), + self.start_character() + ) + } +} + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct LspContextEnrichment { + pub file_path: PathBuf, + pub diagnostics: WorkspaceDiagnostics, + pub definitions: Vec<SymbolLocation>, + pub references: Vec<SymbolLocation>, +} + +impl LspContextEnrichment { + #[must_use] + pub fn is_empty(&self) -> bool { + self.diagnostics.is_empty() && self.definitions.is_empty() && self.references.is_empty() + } + + #[must_use] + pub fn render_prompt_section(&self) -> String { + const MAX_RENDERED_DIAGNOSTICS: usize = 12; + const MAX_RENDERED_LOCATIONS: usize = 12; + + let mut lines = vec!["# LSP context".to_string()]; + lines.push(format!(" - Focus file: {}", self.file_path.display())); + lines.push(format!( + " - Workspace diagnostics: {} across {} file(s)", + self.diagnostics.total_diagnostics(), + self.diagnostics.files.len() + )); + + if !self.diagnostics.files.is_empty() { + lines.push(String::new()); + lines.push("Diagnostics:".to_string()); + let mut rendered = 0usize; + for file in &self.diagnostics.files { + for diagnostic in &file.diagnostics { + if rendered == MAX_RENDERED_DIAGNOSTICS { + lines.push(" - Additional diagnostics omitted for brevity.".to_string()); + break; + } + let severity = diagnostic_severity_label(diagnostic.severity); + lines.push(format!( + " - {}:{}:{} [{}] {}", + file.path.display(), + diagnostic.range.start.line + 1, + diagnostic.range.start.character + 1, + severity, + diagnostic.message.replace('\n', " ") + )); + rendered += 1; + } + if rendered == MAX_RENDERED_DIAGNOSTICS { + break; + } + } + } + + if !self.definitions.is_empty() { + lines.push(String::new()); + lines.push("Definitions:".to_string()); + lines.extend( + self.definitions + .iter() + .take(MAX_RENDERED_LOCATIONS) + .map(|location| format!(" - {location}")), + ); + if self.definitions.len() > MAX_RENDERED_LOCATIONS { + lines.push(" - Additional definitions omitted for brevity.".to_string()); + } + } + + if !self.references.is_empty() { + lines.push(String::new()); + lines.push("References:".to_string()); + lines.extend( + self.references + .iter() + .take(MAX_RENDERED_LOCATIONS) + .map(|location| format!(" - {location}")), + ); + if self.references.len() > MAX_RENDERED_LOCATIONS { + lines.push(" - Additional references omitted for brevity.".to_string()); + } + } + + lines.join("\n") + } +} + +#[must_use] +pub(crate) fn normalize_extension(extension: &str) -> String { + if extension.starts_with('.') { + extension.to_ascii_lowercase() + } else { + format!(".{}", extension.to_ascii_lowercase()) + } +} + +fn diagnostic_severity_label(severity: Option<lsp_types::DiagnosticSeverity>) -> &'static str { + match severity { + Some(lsp_types::DiagnosticSeverity::ERROR) => "error", + Some(lsp_types::DiagnosticSeverity::WARNING) => "warning", + Some(lsp_types::DiagnosticSeverity::INFORMATION) => "info", + Some(lsp_types::DiagnosticSeverity::HINT) => "hint", + _ => "unknown", + } +} diff --git a/rust/crates/plugins/Cargo.toml b/rust/crates/plugins/Cargo.toml new file mode 100644 index 0000000..11213b5 --- /dev/null +++ b/rust/crates/plugins/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "plugins" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true + +[lints] +workspace = true diff --git a/rust/crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json b/rust/crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json new file mode 100644 index 0000000..81a4220 --- /dev/null +++ b/rust/crates/plugins/bundled/example-bundled/.claw-plugin/plugin.json @@ -0,0 +1,10 @@ +{ + "name": "example-bundled", + "version": "0.1.0", + "description": "Example bundled plugin scaffold for the Rust plugin system", + "defaultEnabled": false, + "hooks": { + "PreToolUse": ["./hooks/pre.sh"], + "PostToolUse": ["./hooks/post.sh"] + } +} diff --git a/rust/crates/plugins/bundled/example-bundled/hooks/post.sh b/rust/crates/plugins/bundled/example-bundled/hooks/post.sh new file mode 100644 index 0000000..c9eb66f --- /dev/null +++ b/rust/crates/plugins/bundled/example-bundled/hooks/post.sh @@ -0,0 +1,2 @@ +#!/bin/sh +printf '%s\n' 'example bundled post hook' diff --git a/rust/crates/plugins/bundled/example-bundled/hooks/pre.sh b/rust/crates/plugins/bundled/example-bundled/hooks/pre.sh new file mode 100644 index 0000000..af6b46b --- /dev/null +++ b/rust/crates/plugins/bundled/example-bundled/hooks/pre.sh @@ -0,0 +1,2 @@ +#!/bin/sh +printf '%s\n' 'example bundled pre hook' diff --git a/rust/crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json b/rust/crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json new file mode 100644 index 0000000..555f5df --- /dev/null +++ b/rust/crates/plugins/bundled/sample-hooks/.claw-plugin/plugin.json @@ -0,0 +1,10 @@ +{ + "name": "sample-hooks", + "version": "0.1.0", + "description": "Bundled sample plugin scaffold for hook integration tests.", + "defaultEnabled": false, + "hooks": { + "PreToolUse": ["./hooks/pre.sh"], + "PostToolUse": ["./hooks/post.sh"] + } +} diff --git a/rust/crates/plugins/bundled/sample-hooks/hooks/post.sh b/rust/crates/plugins/bundled/sample-hooks/hooks/post.sh new file mode 100644 index 0000000..c968e6d --- /dev/null +++ b/rust/crates/plugins/bundled/sample-hooks/hooks/post.sh @@ -0,0 +1,2 @@ +#!/bin/sh +printf 'sample bundled post hook' diff --git a/rust/crates/plugins/bundled/sample-hooks/hooks/pre.sh b/rust/crates/plugins/bundled/sample-hooks/hooks/pre.sh new file mode 100644 index 0000000..9560881 --- /dev/null +++ b/rust/crates/plugins/bundled/sample-hooks/hooks/pre.sh @@ -0,0 +1,2 @@ +#!/bin/sh +printf 'sample bundled pre hook' diff --git a/rust/crates/plugins/src/hooks.rs b/rust/crates/plugins/src/hooks.rs new file mode 100644 index 0000000..fde23e8 --- /dev/null +++ b/rust/crates/plugins/src/hooks.rs @@ -0,0 +1,395 @@ +use std::ffi::OsStr; +use std::path::Path; +use std::process::Command; + +use serde_json::json; + +use crate::{PluginError, PluginHooks, PluginRegistry}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HookEvent { + PreToolUse, + PostToolUse, +} + +impl HookEvent { + fn as_str(self) -> &'static str { + match self { + Self::PreToolUse => "PreToolUse", + Self::PostToolUse => "PostToolUse", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HookRunResult { + denied: bool, + messages: Vec<String>, +} + +impl HookRunResult { + #[must_use] + pub fn allow(messages: Vec<String>) -> Self { + Self { + denied: false, + messages, + } + } + + #[must_use] + pub fn is_denied(&self) -> bool { + self.denied + } + + #[must_use] + pub fn messages(&self) -> &[String] { + &self.messages + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct HookRunner { + hooks: PluginHooks, +} + +impl HookRunner { + #[must_use] + pub fn new(hooks: PluginHooks) -> Self { + Self { hooks } + } + + pub fn from_registry(plugin_registry: &PluginRegistry) -> Result<Self, PluginError> { + Ok(Self::new(plugin_registry.aggregated_hooks()?)) + } + + #[must_use] + pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { + self.run_commands( + HookEvent::PreToolUse, + &self.hooks.pre_tool_use, + tool_name, + tool_input, + None, + false, + ) + } + + #[must_use] + pub fn run_post_tool_use( + &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + ) -> HookRunResult { + self.run_commands( + HookEvent::PostToolUse, + &self.hooks.post_tool_use, + tool_name, + tool_input, + Some(tool_output), + is_error, + ) + } + + fn run_commands( + &self, + event: HookEvent, + commands: &[String], + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, + ) -> HookRunResult { + if commands.is_empty() { + return HookRunResult::allow(Vec::new()); + } + + let payload = json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }) + .to_string(); + + let mut messages = Vec::new(); + + for command in commands { + match self.run_command( + command, + event, + tool_name, + tool_input, + tool_output, + is_error, + &payload, + ) { + HookCommandOutcome::Allow { message } => { + if let Some(message) = message { + messages.push(message); + } + } + HookCommandOutcome::Deny { message } => { + messages.push(message.unwrap_or_else(|| { + format!("{} hook denied tool `{tool_name}`", event.as_str()) + })); + return HookRunResult { + denied: true, + messages, + }; + } + HookCommandOutcome::Warn { message } => messages.push(message), + } + } + + HookRunResult::allow(messages) + } + + #[allow(clippy::too_many_arguments, clippy::unused_self)] + fn run_command( + &self, + command: &str, + event: HookEvent, + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, + payload: &str, + ) -> HookCommandOutcome { + let mut child = shell_command(command); + child.stdin(std::process::Stdio::piped()); + child.stdout(std::process::Stdio::piped()); + child.stderr(std::process::Stdio::piped()); + child.env("HOOK_EVENT", event.as_str()); + child.env("HOOK_TOOL_NAME", tool_name); + child.env("HOOK_TOOL_INPUT", tool_input); + child.env("HOOK_TOOL_IS_ERROR", if is_error { "1" } else { "0" }); + if let Some(tool_output) = tool_output { + child.env("HOOK_TOOL_OUTPUT", tool_output); + } + + match child.output_with_stdin(payload.as_bytes()) { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let message = (!stdout.is_empty()).then_some(stdout); + match output.status.code() { + Some(0) => HookCommandOutcome::Allow { message }, + Some(2) => HookCommandOutcome::Deny { message }, + Some(code) => HookCommandOutcome::Warn { + message: format_hook_warning( + command, + code, + message.as_deref(), + stderr.as_str(), + ), + }, + None => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` terminated by signal while handling `{tool_name}`", + event.as_str() + ), + }, + } + } + Err(error) => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` failed to start for `{tool_name}`: {error}", + event.as_str() + ), + }, + } + } +} + +enum HookCommandOutcome { + Allow { message: Option<String> }, + Deny { message: Option<String> }, + Warn { message: String }, +} + +fn parse_tool_input(tool_input: &str) -> serde_json::Value { + serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) +} + +fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { + let mut message = + format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); + if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { + message.push_str(": "); + message.push_str(stdout); + } else if !stderr.is_empty() { + message.push_str(": "); + message.push_str(stderr); + } + message +} + +fn shell_command(command: &str) -> CommandWithStdin { + #[cfg(windows)] + let command_builder = { + let mut command_builder = Command::new("cmd"); + command_builder.arg("/C").arg(command); + CommandWithStdin::new(command_builder) + }; + + #[cfg(not(windows))] + let command_builder = if Path::new(command).exists() { + let mut command_builder = Command::new("sh"); + command_builder.arg(command); + CommandWithStdin::new(command_builder) + } else { + let mut command_builder = Command::new("sh"); + command_builder.arg("-lc").arg(command); + CommandWithStdin::new(command_builder) + }; + + command_builder +} + +struct CommandWithStdin { + command: Command, +} + +impl CommandWithStdin { + fn new(command: Command) -> Self { + Self { command } + } + + fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdin(cfg); + self + } + + fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdout(cfg); + self + } + + fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stderr(cfg); + self + } + + fn env<K, V>(&mut self, key: K, value: V) -> &mut Self + where + K: AsRef<OsStr>, + V: AsRef<OsStr>, + { + self.command.env(key, value); + self + } + + fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> { + let mut child = self.command.spawn()?; + if let Some(mut child_stdin) = child.stdin.take() { + use std::io::Write as _; + child_stdin.write_all(stdin)?; + } + child.wait_with_output() + } +} + +#[cfg(test)] +mod tests { + use super::{HookRunResult, HookRunner}; + use crate::{PluginManager, PluginManagerConfig}; + use std::fs; + use std::path::{Path, PathBuf}; + use std::time::{SystemTime, UNIX_EPOCH}; + + fn temp_dir(label: &str) -> PathBuf { + let nanos = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("plugins-hook-runner-{label}-{nanos}")) + } + + fn write_hook_plugin(root: &Path, name: &str, pre_message: &str, post_message: &str) { + fs::create_dir_all(root.join(".claw-plugin")).expect("manifest dir"); + fs::create_dir_all(root.join("hooks")).expect("hooks dir"); + fs::write( + root.join("hooks").join("pre.sh"), + format!("#!/bin/sh\nprintf '%s\\n' '{pre_message}'\n"), + ) + .expect("write pre hook"); + fs::write( + root.join("hooks").join("post.sh"), + format!("#!/bin/sh\nprintf '%s\\n' '{post_message}'\n"), + ) + .expect("write post hook"); + fs::write( + root.join(".claw-plugin").join("plugin.json"), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"hook plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" + ), + ) + .expect("write plugin manifest"); + } + + #[test] + fn collects_and_runs_hooks_from_enabled_plugins() { + let config_home = temp_dir("config"); + let first_source_root = temp_dir("source-a"); + let second_source_root = temp_dir("source-b"); + write_hook_plugin( + &first_source_root, + "first", + "plugin pre one", + "plugin post one", + ); + write_hook_plugin( + &second_source_root, + "second", + "plugin pre two", + "plugin post two", + ); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + manager + .install(first_source_root.to_str().expect("utf8 path")) + .expect("first plugin install should succeed"); + manager + .install(second_source_root.to_str().expect("utf8 path")) + .expect("second plugin install should succeed"); + let registry = manager.plugin_registry().expect("registry should build"); + + let runner = HookRunner::from_registry(®istry).expect("plugin hooks should load"); + + assert_eq!( + runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#), + HookRunResult::allow(vec![ + "plugin pre one".to_string(), + "plugin pre two".to_string(), + ]) + ); + assert_eq!( + runner.run_post_tool_use("Read", r#"{"path":"README.md"}"#, "ok", false), + HookRunResult::allow(vec![ + "plugin post one".to_string(), + "plugin post two".to_string(), + ]) + ); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(first_source_root); + let _ = fs::remove_dir_all(second_source_root); + } + + #[test] + fn pre_tool_use_denies_when_plugin_hook_exits_two() { + let runner = HookRunner::new(crate::PluginHooks { + pre_tool_use: vec!["printf 'blocked by plugin'; exit 2".to_string()], + post_tool_use: Vec::new(), + }); + + let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + + assert!(result.is_denied()); + assert_eq!(result.messages(), &["blocked by plugin".to_string()]); + } +} diff --git a/rust/crates/plugins/src/lib.rs b/rust/crates/plugins/src/lib.rs new file mode 100644 index 0000000..6105ad9 --- /dev/null +++ b/rust/crates/plugins/src/lib.rs @@ -0,0 +1,2943 @@ +mod hooks; + +use std::collections::{BTreeMap, BTreeSet}; +use std::fmt::{Display, Formatter}; +use std::fs; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; + +pub use hooks::{HookEvent, HookRunResult, HookRunner}; + +const EXTERNAL_MARKETPLACE: &str = "external"; +const BUILTIN_MARKETPLACE: &str = "builtin"; +const BUNDLED_MARKETPLACE: &str = "bundled"; +const SETTINGS_FILE_NAME: &str = "settings.json"; +const REGISTRY_FILE_NAME: &str = "installed.json"; +const MANIFEST_FILE_NAME: &str = "plugin.json"; +const MANIFEST_RELATIVE_PATH: &str = ".claw-plugin/plugin.json"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum PluginKind { + Builtin, + Bundled, + External, +} + +impl Display for PluginKind { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Builtin => write!(f, "builtin"), + Self::Bundled => write!(f, "bundled"), + Self::External => write!(f, "external"), + } + } +} + +impl PluginKind { + #[must_use] + fn marketplace(self) -> &'static str { + match self { + Self::Builtin => BUILTIN_MARKETPLACE, + Self::Bundled => BUNDLED_MARKETPLACE, + Self::External => EXTERNAL_MARKETPLACE, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginMetadata { + pub id: String, + pub name: String, + pub version: String, + pub description: String, + pub kind: PluginKind, + pub source: String, + pub default_enabled: bool, + pub root: Option<PathBuf>, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PluginHooks { + #[serde(rename = "PreToolUse", default)] + pub pre_tool_use: Vec<String>, + #[serde(rename = "PostToolUse", default)] + pub post_tool_use: Vec<String>, +} + +impl PluginHooks { + #[must_use] + pub fn is_empty(&self) -> bool { + self.pre_tool_use.is_empty() && self.post_tool_use.is_empty() + } + + #[must_use] + pub fn merged_with(&self, other: &Self) -> Self { + let mut merged = self.clone(); + merged + .pre_tool_use + .extend(other.pre_tool_use.iter().cloned()); + merged + .post_tool_use + .extend(other.post_tool_use.iter().cloned()); + merged + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct PluginLifecycle { + #[serde(rename = "Init", default)] + pub init: Vec<String>, + #[serde(rename = "Shutdown", default)] + pub shutdown: Vec<String>, +} + +impl PluginLifecycle { + #[must_use] + pub fn is_empty(&self) -> bool { + self.init.is_empty() && self.shutdown.is_empty() + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PluginManifest { + pub name: String, + pub version: String, + pub description: String, + pub permissions: Vec<PluginPermission>, + #[serde(rename = "defaultEnabled", default)] + pub default_enabled: bool, + #[serde(default)] + pub hooks: PluginHooks, + #[serde(default)] + pub lifecycle: PluginLifecycle, + #[serde(default)] + pub tools: Vec<PluginToolManifest>, + #[serde(default)] + pub commands: Vec<PluginCommandManifest>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum PluginPermission { + Read, + Write, + Execute, +} + +impl PluginPermission { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::Read => "read", + Self::Write => "write", + Self::Execute => "execute", + } + } + + fn parse(value: &str) -> Option<Self> { + match value { + "read" => Some(Self::Read), + "write" => Some(Self::Write), + "execute" => Some(Self::Execute), + _ => None, + } + } +} + +impl AsRef<str> for PluginPermission { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PluginToolManifest { + pub name: String, + pub description: String, + #[serde(rename = "inputSchema")] + pub input_schema: Value, + pub command: String, + #[serde(default)] + pub args: Vec<String>, + pub required_permission: PluginToolPermission, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum PluginToolPermission { + ReadOnly, + WorkspaceWrite, + DangerFullAccess, +} + +impl PluginToolPermission { + #[must_use] + pub fn as_str(self) -> &'static str { + match self { + Self::ReadOnly => "read-only", + Self::WorkspaceWrite => "workspace-write", + Self::DangerFullAccess => "danger-full-access", + } + } + + fn parse(value: &str) -> Option<Self> { + match value { + "read-only" => Some(Self::ReadOnly), + "workspace-write" => Some(Self::WorkspaceWrite), + "danger-full-access" => Some(Self::DangerFullAccess), + _ => None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct PluginToolDefinition { + pub name: String, + #[serde(default)] + pub description: Option<String>, + #[serde(rename = "inputSchema")] + pub input_schema: Value, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PluginCommandManifest { + pub name: String, + pub description: String, + pub command: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct RawPluginManifest { + pub name: String, + pub version: String, + pub description: String, + #[serde(default)] + pub permissions: Vec<String>, + #[serde(rename = "defaultEnabled", default)] + pub default_enabled: bool, + #[serde(default)] + pub hooks: PluginHooks, + #[serde(default)] + pub lifecycle: PluginLifecycle, + #[serde(default)] + pub tools: Vec<RawPluginToolManifest>, + #[serde(default)] + pub commands: Vec<PluginCommandManifest>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +struct RawPluginToolManifest { + pub name: String, + pub description: String, + #[serde(rename = "inputSchema")] + pub input_schema: Value, + pub command: String, + #[serde(default)] + pub args: Vec<String>, + #[serde( + rename = "requiredPermission", + default = "default_tool_permission_label" + )] + pub required_permission: String, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct PluginTool { + plugin_id: String, + plugin_name: String, + definition: PluginToolDefinition, + command: String, + args: Vec<String>, + required_permission: PluginToolPermission, + root: Option<PathBuf>, +} + +impl PluginTool { + #[must_use] + pub fn new( + plugin_id: impl Into<String>, + plugin_name: impl Into<String>, + definition: PluginToolDefinition, + command: impl Into<String>, + args: Vec<String>, + required_permission: PluginToolPermission, + root: Option<PathBuf>, + ) -> Self { + Self { + plugin_id: plugin_id.into(), + plugin_name: plugin_name.into(), + definition, + command: command.into(), + args, + required_permission, + root, + } + } + + #[must_use] + pub fn plugin_id(&self) -> &str { + &self.plugin_id + } + + #[must_use] + pub fn definition(&self) -> &PluginToolDefinition { + &self.definition + } + + #[must_use] + pub fn required_permission(&self) -> &str { + self.required_permission.as_str() + } + + pub fn execute(&self, input: &Value) -> Result<String, PluginError> { + let input_json = input.to_string(); + let mut process = Command::new(&self.command); + process + .args(&self.args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .env("CLAW_PLUGIN_ID", &self.plugin_id) + .env("CLAW_PLUGIN_NAME", &self.plugin_name) + .env("CLAW_TOOL_NAME", &self.definition.name) + .env("CLAW_TOOL_INPUT", &input_json); + if let Some(root) = &self.root { + process + .current_dir(root) + .env("CLAW_PLUGIN_ROOT", root.display().to_string()); + } + + let mut child = process.spawn()?; + if let Some(stdin) = child.stdin.as_mut() { + use std::io::Write as _; + stdin.write_all(input_json.as_bytes())?; + } + + let output = child.wait_with_output()?; + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).trim().to_string()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + Err(PluginError::CommandFailed(format!( + "plugin tool `{}` from `{}` failed for `{}`: {}", + self.definition.name, + self.plugin_id, + self.command, + if stderr.is_empty() { + format!("exit status {}", output.status) + } else { + stderr + } + ))) + } + } +} + +fn default_tool_permission_label() -> String { + "danger-full-access".to_string() +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum PluginInstallSource { + LocalPath { path: PathBuf }, + GitUrl { url: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct InstalledPluginRecord { + #[serde(default = "default_plugin_kind")] + pub kind: PluginKind, + pub id: String, + pub name: String, + pub version: String, + pub description: String, + pub install_path: PathBuf, + pub source: PluginInstallSource, + pub installed_at_unix_ms: u128, + pub updated_at_unix_ms: u128, +} + +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct InstalledPluginRegistry { + #[serde(default)] + pub plugins: BTreeMap<String, InstalledPluginRecord>, +} + +fn default_plugin_kind() -> PluginKind { + PluginKind::External +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BuiltinPlugin { + metadata: PluginMetadata, + hooks: PluginHooks, + lifecycle: PluginLifecycle, + tools: Vec<PluginTool>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BundledPlugin { + metadata: PluginMetadata, + hooks: PluginHooks, + lifecycle: PluginLifecycle, + tools: Vec<PluginTool>, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct ExternalPlugin { + metadata: PluginMetadata, + hooks: PluginHooks, + lifecycle: PluginLifecycle, + tools: Vec<PluginTool>, +} + +pub trait Plugin { + fn metadata(&self) -> &PluginMetadata; + fn hooks(&self) -> &PluginHooks; + fn lifecycle(&self) -> &PluginLifecycle; + fn tools(&self) -> &[PluginTool]; + fn validate(&self) -> Result<(), PluginError>; + fn initialize(&self) -> Result<(), PluginError>; + fn shutdown(&self) -> Result<(), PluginError>; +} + +#[derive(Debug, Clone, PartialEq)] +pub enum PluginDefinition { + Builtin(BuiltinPlugin), + Bundled(BundledPlugin), + External(ExternalPlugin), +} + +impl Plugin for BuiltinPlugin { + fn metadata(&self) -> &PluginMetadata { + &self.metadata + } + + fn hooks(&self) -> &PluginHooks { + &self.hooks + } + + fn lifecycle(&self) -> &PluginLifecycle { + &self.lifecycle + } + + fn tools(&self) -> &[PluginTool] { + &self.tools + } + + fn validate(&self) -> Result<(), PluginError> { + Ok(()) + } + + fn initialize(&self) -> Result<(), PluginError> { + Ok(()) + } + + fn shutdown(&self) -> Result<(), PluginError> { + Ok(()) + } +} + +impl Plugin for BundledPlugin { + fn metadata(&self) -> &PluginMetadata { + &self.metadata + } + + fn hooks(&self) -> &PluginHooks { + &self.hooks + } + + fn lifecycle(&self) -> &PluginLifecycle { + &self.lifecycle + } + + fn tools(&self) -> &[PluginTool] { + &self.tools + } + + fn validate(&self) -> Result<(), PluginError> { + validate_hook_paths(self.metadata.root.as_deref(), &self.hooks)?; + validate_lifecycle_paths(self.metadata.root.as_deref(), &self.lifecycle)?; + validate_tool_paths(self.metadata.root.as_deref(), &self.tools) + } + + fn initialize(&self) -> Result<(), PluginError> { + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "init", + &self.lifecycle.init, + ) + } + + fn shutdown(&self) -> Result<(), PluginError> { + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "shutdown", + &self.lifecycle.shutdown, + ) + } +} + +impl Plugin for ExternalPlugin { + fn metadata(&self) -> &PluginMetadata { + &self.metadata + } + + fn hooks(&self) -> &PluginHooks { + &self.hooks + } + + fn lifecycle(&self) -> &PluginLifecycle { + &self.lifecycle + } + + fn tools(&self) -> &[PluginTool] { + &self.tools + } + + fn validate(&self) -> Result<(), PluginError> { + validate_hook_paths(self.metadata.root.as_deref(), &self.hooks)?; + validate_lifecycle_paths(self.metadata.root.as_deref(), &self.lifecycle)?; + validate_tool_paths(self.metadata.root.as_deref(), &self.tools) + } + + fn initialize(&self) -> Result<(), PluginError> { + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "init", + &self.lifecycle.init, + ) + } + + fn shutdown(&self) -> Result<(), PluginError> { + run_lifecycle_commands( + self.metadata(), + self.lifecycle(), + "shutdown", + &self.lifecycle.shutdown, + ) + } +} + +impl Plugin for PluginDefinition { + fn metadata(&self) -> &PluginMetadata { + match self { + Self::Builtin(plugin) => plugin.metadata(), + Self::Bundled(plugin) => plugin.metadata(), + Self::External(plugin) => plugin.metadata(), + } + } + + fn hooks(&self) -> &PluginHooks { + match self { + Self::Builtin(plugin) => plugin.hooks(), + Self::Bundled(plugin) => plugin.hooks(), + Self::External(plugin) => plugin.hooks(), + } + } + + fn lifecycle(&self) -> &PluginLifecycle { + match self { + Self::Builtin(plugin) => plugin.lifecycle(), + Self::Bundled(plugin) => plugin.lifecycle(), + Self::External(plugin) => plugin.lifecycle(), + } + } + + fn tools(&self) -> &[PluginTool] { + match self { + Self::Builtin(plugin) => plugin.tools(), + Self::Bundled(plugin) => plugin.tools(), + Self::External(plugin) => plugin.tools(), + } + } + + fn validate(&self) -> Result<(), PluginError> { + match self { + Self::Builtin(plugin) => plugin.validate(), + Self::Bundled(plugin) => plugin.validate(), + Self::External(plugin) => plugin.validate(), + } + } + + fn initialize(&self) -> Result<(), PluginError> { + match self { + Self::Builtin(plugin) => plugin.initialize(), + Self::Bundled(plugin) => plugin.initialize(), + Self::External(plugin) => plugin.initialize(), + } + } + + fn shutdown(&self) -> Result<(), PluginError> { + match self { + Self::Builtin(plugin) => plugin.shutdown(), + Self::Bundled(plugin) => plugin.shutdown(), + Self::External(plugin) => plugin.shutdown(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct RegisteredPlugin { + definition: PluginDefinition, + enabled: bool, +} + +impl RegisteredPlugin { + #[must_use] + pub fn new(definition: PluginDefinition, enabled: bool) -> Self { + Self { + definition, + enabled, + } + } + + #[must_use] + pub fn metadata(&self) -> &PluginMetadata { + self.definition.metadata() + } + + #[must_use] + pub fn hooks(&self) -> &PluginHooks { + self.definition.hooks() + } + + #[must_use] + pub fn tools(&self) -> &[PluginTool] { + self.definition.tools() + } + + #[must_use] + pub fn is_enabled(&self) -> bool { + self.enabled + } + + pub fn validate(&self) -> Result<(), PluginError> { + self.definition.validate() + } + + pub fn initialize(&self) -> Result<(), PluginError> { + self.definition.initialize() + } + + pub fn shutdown(&self) -> Result<(), PluginError> { + self.definition.shutdown() + } + + #[must_use] + pub fn summary(&self) -> PluginSummary { + PluginSummary { + metadata: self.metadata().clone(), + enabled: self.enabled, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginSummary { + pub metadata: PluginMetadata, + pub enabled: bool, +} + +#[derive(Debug, Clone, Default, PartialEq)] +pub struct PluginRegistry { + plugins: Vec<RegisteredPlugin>, +} + +impl PluginRegistry { + #[must_use] + pub fn new(mut plugins: Vec<RegisteredPlugin>) -> Self { + plugins.sort_by(|left, right| left.metadata().id.cmp(&right.metadata().id)); + Self { plugins } + } + + #[must_use] + pub fn plugins(&self) -> &[RegisteredPlugin] { + &self.plugins + } + + #[must_use] + pub fn get(&self, plugin_id: &str) -> Option<&RegisteredPlugin> { + self.plugins + .iter() + .find(|plugin| plugin.metadata().id == plugin_id) + } + + #[must_use] + pub fn contains(&self, plugin_id: &str) -> bool { + self.get(plugin_id).is_some() + } + + #[must_use] + pub fn summaries(&self) -> Vec<PluginSummary> { + self.plugins.iter().map(RegisteredPlugin::summary).collect() + } + + pub fn aggregated_hooks(&self) -> Result<PluginHooks, PluginError> { + self.plugins + .iter() + .filter(|plugin| plugin.is_enabled()) + .try_fold(PluginHooks::default(), |acc, plugin| { + plugin.validate()?; + Ok(acc.merged_with(plugin.hooks())) + }) + } + + pub fn aggregated_tools(&self) -> Result<Vec<PluginTool>, PluginError> { + let mut tools = Vec::new(); + let mut seen_names = BTreeMap::new(); + for plugin in self.plugins.iter().filter(|plugin| plugin.is_enabled()) { + plugin.validate()?; + for tool in plugin.tools() { + if let Some(existing_plugin) = + seen_names.insert(tool.definition().name.clone(), tool.plugin_id().to_string()) + { + return Err(PluginError::InvalidManifest(format!( + "plugin tool `{}` is defined by both `{existing_plugin}` and `{}`", + tool.definition().name, + tool.plugin_id() + ))); + } + tools.push(tool.clone()); + } + } + Ok(tools) + } + + pub fn initialize(&self) -> Result<(), PluginError> { + for plugin in self.plugins.iter().filter(|plugin| plugin.is_enabled()) { + plugin.validate()?; + plugin.initialize()?; + } + Ok(()) + } + + pub fn shutdown(&self) -> Result<(), PluginError> { + for plugin in self + .plugins + .iter() + .rev() + .filter(|plugin| plugin.is_enabled()) + { + plugin.shutdown()?; + } + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginManagerConfig { + pub config_home: PathBuf, + pub enabled_plugins: BTreeMap<String, bool>, + pub external_dirs: Vec<PathBuf>, + pub install_root: Option<PathBuf>, + pub registry_path: Option<PathBuf>, + pub bundled_root: Option<PathBuf>, +} + +impl PluginManagerConfig { + #[must_use] + pub fn new(config_home: impl Into<PathBuf>) -> Self { + Self { + config_home: config_home.into(), + enabled_plugins: BTreeMap::new(), + external_dirs: Vec::new(), + install_root: None, + registry_path: None, + bundled_root: None, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PluginManager { + config: PluginManagerConfig, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct InstallOutcome { + pub plugin_id: String, + pub version: String, + pub install_path: PathBuf, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UpdateOutcome { + pub plugin_id: String, + pub old_version: String, + pub new_version: String, + pub install_path: PathBuf, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PluginManifestValidationError { + EmptyField { + field: &'static str, + }, + EmptyEntryField { + kind: &'static str, + field: &'static str, + name: Option<String>, + }, + InvalidPermission { + permission: String, + }, + DuplicatePermission { + permission: String, + }, + DuplicateEntry { + kind: &'static str, + name: String, + }, + MissingPath { + kind: &'static str, + path: PathBuf, + }, + InvalidToolInputSchema { + tool_name: String, + }, + InvalidToolRequiredPermission { + tool_name: String, + permission: String, + }, +} + +impl Display for PluginManifestValidationError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::EmptyField { field } => { + write!(f, "plugin manifest {field} cannot be empty") + } + Self::EmptyEntryField { kind, field, name } => match name { + Some(name) if !name.is_empty() => { + write!(f, "plugin {kind} `{name}` {field} cannot be empty") + } + _ => write!(f, "plugin {kind} {field} cannot be empty"), + }, + Self::InvalidPermission { permission } => { + write!( + f, + "plugin manifest permission `{permission}` must be one of read, write, or execute" + ) + } + Self::DuplicatePermission { permission } => { + write!(f, "plugin manifest permission `{permission}` is duplicated") + } + Self::DuplicateEntry { kind, name } => { + write!(f, "plugin {kind} `{name}` is duplicated") + } + Self::MissingPath { kind, path } => { + write!(f, "{kind} path `{}` does not exist", path.display()) + } + Self::InvalidToolInputSchema { tool_name } => { + write!( + f, + "plugin tool `{tool_name}` inputSchema must be a JSON object" + ) + } + Self::InvalidToolRequiredPermission { + tool_name, + permission, + } => write!( + f, + "plugin tool `{tool_name}` requiredPermission `{permission}` must be read-only, workspace-write, or danger-full-access" + ), + } + } +} + +#[derive(Debug)] +pub enum PluginError { + Io(std::io::Error), + Json(serde_json::Error), + ManifestValidation(Vec<PluginManifestValidationError>), + InvalidManifest(String), + NotFound(String), + CommandFailed(String), +} + +impl Display for PluginError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Io(error) => write!(f, "{error}"), + Self::Json(error) => write!(f, "{error}"), + Self::ManifestValidation(errors) => { + for (index, error) in errors.iter().enumerate() { + if index > 0 { + write!(f, "; ")?; + } + write!(f, "{error}")?; + } + Ok(()) + } + Self::InvalidManifest(message) + | Self::NotFound(message) + | Self::CommandFailed(message) => write!(f, "{message}"), + } + } +} + +impl std::error::Error for PluginError {} + +impl From<std::io::Error> for PluginError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} + +impl From<serde_json::Error> for PluginError { + fn from(value: serde_json::Error) -> Self { + Self::Json(value) + } +} + +impl PluginManager { + #[must_use] + pub fn new(config: PluginManagerConfig) -> Self { + Self { config } + } + + #[must_use] + pub fn bundled_root() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("bundled") + } + + #[must_use] + pub fn install_root(&self) -> PathBuf { + self.config + .install_root + .clone() + .unwrap_or_else(|| self.config.config_home.join("plugins").join("installed")) + } + + #[must_use] + pub fn registry_path(&self) -> PathBuf { + self.config.registry_path.clone().unwrap_or_else(|| { + self.config + .config_home + .join("plugins") + .join(REGISTRY_FILE_NAME) + }) + } + + #[must_use] + pub fn settings_path(&self) -> PathBuf { + self.config.config_home.join(SETTINGS_FILE_NAME) + } + + pub fn plugin_registry(&self) -> Result<PluginRegistry, PluginError> { + Ok(PluginRegistry::new( + self.discover_plugins()? + .into_iter() + .map(|plugin| { + let enabled = self.is_enabled(plugin.metadata()); + RegisteredPlugin::new(plugin, enabled) + }) + .collect(), + )) + } + + pub fn list_plugins(&self) -> Result<Vec<PluginSummary>, PluginError> { + Ok(self.plugin_registry()?.summaries()) + } + + pub fn list_installed_plugins(&self) -> Result<Vec<PluginSummary>, PluginError> { + Ok(self.installed_plugin_registry()?.summaries()) + } + + pub fn discover_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> { + self.sync_bundled_plugins()?; + let mut plugins = builtin_plugins(); + plugins.extend(self.discover_installed_plugins()?); + plugins.extend(self.discover_external_directory_plugins(&plugins)?); + Ok(plugins) + } + + pub fn aggregated_hooks(&self) -> Result<PluginHooks, PluginError> { + self.plugin_registry()?.aggregated_hooks() + } + + pub fn aggregated_tools(&self) -> Result<Vec<PluginTool>, PluginError> { + self.plugin_registry()?.aggregated_tools() + } + + pub fn validate_plugin_source(&self, source: &str) -> Result<PluginManifest, PluginError> { + let path = resolve_local_source(source)?; + load_plugin_from_directory(&path) + } + + pub fn install(&mut self, source: &str) -> Result<InstallOutcome, PluginError> { + let install_source = parse_install_source(source)?; + let temp_root = self.install_root().join(".tmp"); + let staged_source = materialize_source(&install_source, &temp_root)?; + let cleanup_source = matches!(install_source, PluginInstallSource::GitUrl { .. }); + let manifest = load_plugin_from_directory(&staged_source)?; + + let plugin_id = plugin_id(&manifest.name, EXTERNAL_MARKETPLACE); + let install_path = self.install_root().join(sanitize_plugin_id(&plugin_id)); + if install_path.exists() { + fs::remove_dir_all(&install_path)?; + } + copy_dir_all(&staged_source, &install_path)?; + if cleanup_source { + let _ = fs::remove_dir_all(&staged_source); + } + + let now = unix_time_ms(); + let record = InstalledPluginRecord { + kind: PluginKind::External, + id: plugin_id.clone(), + name: manifest.name, + version: manifest.version.clone(), + description: manifest.description, + install_path: install_path.clone(), + source: install_source, + installed_at_unix_ms: now, + updated_at_unix_ms: now, + }; + + let mut registry = self.load_registry()?; + registry.plugins.insert(plugin_id.clone(), record); + self.store_registry(®istry)?; + self.write_enabled_state(&plugin_id, Some(true))?; + self.config.enabled_plugins.insert(plugin_id.clone(), true); + + Ok(InstallOutcome { + plugin_id, + version: manifest.version, + install_path, + }) + } + + pub fn enable(&mut self, plugin_id: &str) -> Result<(), PluginError> { + self.ensure_known_plugin(plugin_id)?; + self.write_enabled_state(plugin_id, Some(true))?; + self.config + .enabled_plugins + .insert(plugin_id.to_string(), true); + Ok(()) + } + + pub fn disable(&mut self, plugin_id: &str) -> Result<(), PluginError> { + self.ensure_known_plugin(plugin_id)?; + self.write_enabled_state(plugin_id, Some(false))?; + self.config + .enabled_plugins + .insert(plugin_id.to_string(), false); + Ok(()) + } + + pub fn uninstall(&mut self, plugin_id: &str) -> Result<(), PluginError> { + let mut registry = self.load_registry()?; + let record = registry.plugins.remove(plugin_id).ok_or_else(|| { + PluginError::NotFound(format!("plugin `{plugin_id}` is not installed")) + })?; + if record.kind == PluginKind::Bundled { + registry.plugins.insert(plugin_id.to_string(), record); + return Err(PluginError::CommandFailed(format!( + "plugin `{plugin_id}` is bundled and managed automatically; disable it instead" + ))); + } + if record.install_path.exists() { + fs::remove_dir_all(&record.install_path)?; + } + self.store_registry(®istry)?; + self.write_enabled_state(plugin_id, None)?; + self.config.enabled_plugins.remove(plugin_id); + Ok(()) + } + + pub fn update(&mut self, plugin_id: &str) -> Result<UpdateOutcome, PluginError> { + let mut registry = self.load_registry()?; + let record = registry.plugins.get(plugin_id).cloned().ok_or_else(|| { + PluginError::NotFound(format!("plugin `{plugin_id}` is not installed")) + })?; + + let temp_root = self.install_root().join(".tmp"); + let staged_source = materialize_source(&record.source, &temp_root)?; + let cleanup_source = matches!(record.source, PluginInstallSource::GitUrl { .. }); + let manifest = load_plugin_from_directory(&staged_source)?; + + if record.install_path.exists() { + fs::remove_dir_all(&record.install_path)?; + } + copy_dir_all(&staged_source, &record.install_path)?; + if cleanup_source { + let _ = fs::remove_dir_all(&staged_source); + } + + let updated_record = InstalledPluginRecord { + version: manifest.version.clone(), + description: manifest.description, + updated_at_unix_ms: unix_time_ms(), + ..record.clone() + }; + registry + .plugins + .insert(plugin_id.to_string(), updated_record); + self.store_registry(®istry)?; + + Ok(UpdateOutcome { + plugin_id: plugin_id.to_string(), + old_version: record.version, + new_version: manifest.version, + install_path: record.install_path, + }) + } + + fn discover_installed_plugins(&self) -> Result<Vec<PluginDefinition>, PluginError> { + let mut registry = self.load_registry()?; + let mut plugins = Vec::new(); + let mut seen_ids = BTreeSet::<String>::new(); + let mut seen_paths = BTreeSet::<PathBuf>::new(); + let mut stale_registry_ids = Vec::new(); + + for install_path in discover_plugin_dirs(&self.install_root())? { + let matched_record = registry + .plugins + .values() + .find(|record| record.install_path == install_path); + let kind = matched_record.map_or(PluginKind::External, |record| record.kind); + let source = matched_record.map_or_else( + || install_path.display().to_string(), + |record| describe_install_source(&record.source), + ); + let plugin = load_plugin_definition(&install_path, kind, source, kind.marketplace())?; + if seen_ids.insert(plugin.metadata().id.clone()) { + seen_paths.insert(install_path); + plugins.push(plugin); + } + } + + for record in registry.plugins.values() { + if seen_paths.contains(&record.install_path) { + continue; + } + if !record.install_path.exists() || plugin_manifest_path(&record.install_path).is_err() + { + stale_registry_ids.push(record.id.clone()); + continue; + } + let plugin = load_plugin_definition( + &record.install_path, + record.kind, + describe_install_source(&record.source), + record.kind.marketplace(), + )?; + if seen_ids.insert(plugin.metadata().id.clone()) { + seen_paths.insert(record.install_path.clone()); + plugins.push(plugin); + } + } + + if !stale_registry_ids.is_empty() { + for plugin_id in stale_registry_ids { + registry.plugins.remove(&plugin_id); + } + self.store_registry(®istry)?; + } + + Ok(plugins) + } + + fn discover_external_directory_plugins( + &self, + existing_plugins: &[PluginDefinition], + ) -> Result<Vec<PluginDefinition>, PluginError> { + let mut plugins = Vec::new(); + + for directory in &self.config.external_dirs { + for root in discover_plugin_dirs(directory)? { + let plugin = load_plugin_definition( + &root, + PluginKind::External, + root.display().to_string(), + EXTERNAL_MARKETPLACE, + )?; + if existing_plugins + .iter() + .chain(plugins.iter()) + .all(|existing| existing.metadata().id != plugin.metadata().id) + { + plugins.push(plugin); + } + } + } + + Ok(plugins) + } + + fn installed_plugin_registry(&self) -> Result<PluginRegistry, PluginError> { + self.sync_bundled_plugins()?; + Ok(PluginRegistry::new( + self.discover_installed_plugins()? + .into_iter() + .map(|plugin| { + let enabled = self.is_enabled(plugin.metadata()); + RegisteredPlugin::new(plugin, enabled) + }) + .collect(), + )) + } + + fn sync_bundled_plugins(&self) -> Result<(), PluginError> { + let bundled_root = self + .config + .bundled_root + .clone() + .unwrap_or_else(Self::bundled_root); + let bundled_plugins = discover_plugin_dirs(&bundled_root)?; + let mut registry = self.load_registry()?; + let mut changed = false; + let install_root = self.install_root(); + let mut active_bundled_ids = BTreeSet::new(); + + for source_root in bundled_plugins { + let manifest = load_plugin_from_directory(&source_root)?; + let plugin_id = plugin_id(&manifest.name, BUNDLED_MARKETPLACE); + active_bundled_ids.insert(plugin_id.clone()); + let install_path = install_root.join(sanitize_plugin_id(&plugin_id)); + let now = unix_time_ms(); + let existing_record = registry.plugins.get(&plugin_id); + let installed_copy_is_valid = + install_path.exists() && load_plugin_from_directory(&install_path).is_ok(); + let needs_sync = existing_record.is_none_or(|record| { + record.kind != PluginKind::Bundled + || record.version != manifest.version + || record.name != manifest.name + || record.description != manifest.description + || record.install_path != install_path + || !record.install_path.exists() + || !installed_copy_is_valid + }); + + if !needs_sync { + continue; + } + + if install_path.exists() { + fs::remove_dir_all(&install_path)?; + } + copy_dir_all(&source_root, &install_path)?; + + let installed_at_unix_ms = + existing_record.map_or(now, |record| record.installed_at_unix_ms); + registry.plugins.insert( + plugin_id.clone(), + InstalledPluginRecord { + kind: PluginKind::Bundled, + id: plugin_id, + name: manifest.name, + version: manifest.version, + description: manifest.description, + install_path, + source: PluginInstallSource::LocalPath { path: source_root }, + installed_at_unix_ms, + updated_at_unix_ms: now, + }, + ); + changed = true; + } + + let stale_bundled_ids = registry + .plugins + .iter() + .filter_map(|(plugin_id, record)| { + (record.kind == PluginKind::Bundled && !active_bundled_ids.contains(plugin_id)) + .then_some(plugin_id.clone()) + }) + .collect::<Vec<_>>(); + + for plugin_id in stale_bundled_ids { + if let Some(record) = registry.plugins.remove(&plugin_id) { + if record.install_path.exists() { + fs::remove_dir_all(&record.install_path)?; + } + changed = true; + } + } + + if changed { + self.store_registry(®istry)?; + } + + Ok(()) + } + + fn is_enabled(&self, metadata: &PluginMetadata) -> bool { + self.config + .enabled_plugins + .get(&metadata.id) + .copied() + .unwrap_or(match metadata.kind { + PluginKind::External => false, + PluginKind::Builtin | PluginKind::Bundled => metadata.default_enabled, + }) + } + + fn ensure_known_plugin(&self, plugin_id: &str) -> Result<(), PluginError> { + if self.plugin_registry()?.contains(plugin_id) { + Ok(()) + } else { + Err(PluginError::NotFound(format!( + "plugin `{plugin_id}` is not installed or discoverable" + ))) + } + } + + fn load_registry(&self) -> Result<InstalledPluginRegistry, PluginError> { + let path = self.registry_path(); + match fs::read_to_string(&path) { + Ok(contents) if contents.trim().is_empty() => Ok(InstalledPluginRegistry::default()), + Ok(contents) => Ok(serde_json::from_str(&contents)?), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => { + Ok(InstalledPluginRegistry::default()) + } + Err(error) => Err(PluginError::Io(error)), + } + } + + fn store_registry(&self, registry: &InstalledPluginRegistry) -> Result<(), PluginError> { + let path = self.registry_path(); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + fs::write(path, serde_json::to_string_pretty(registry)?)?; + Ok(()) + } + + fn write_enabled_state( + &self, + plugin_id: &str, + enabled: Option<bool>, + ) -> Result<(), PluginError> { + update_settings_json(&self.settings_path(), |root| { + let enabled_plugins = ensure_object(root, "enabledPlugins"); + match enabled { + Some(value) => { + enabled_plugins.insert(plugin_id.to_string(), Value::Bool(value)); + } + None => { + enabled_plugins.remove(plugin_id); + } + } + }) + } +} + +#[must_use] +pub fn builtin_plugins() -> Vec<PluginDefinition> { + vec![PluginDefinition::Builtin(BuiltinPlugin { + metadata: PluginMetadata { + id: plugin_id("example-builtin", BUILTIN_MARKETPLACE), + name: "example-builtin".to_string(), + version: "0.1.0".to_string(), + description: "Example built-in plugin scaffold for the Rust plugin system".to_string(), + kind: PluginKind::Builtin, + source: BUILTIN_MARKETPLACE.to_string(), + default_enabled: false, + root: None, + }, + hooks: PluginHooks::default(), + lifecycle: PluginLifecycle::default(), + tools: Vec::new(), + })] +} + +fn load_plugin_definition( + root: &Path, + kind: PluginKind, + source: String, + marketplace: &str, +) -> Result<PluginDefinition, PluginError> { + let manifest = load_plugin_from_directory(root)?; + let metadata = PluginMetadata { + id: plugin_id(&manifest.name, marketplace), + name: manifest.name, + version: manifest.version, + description: manifest.description, + kind, + source, + default_enabled: manifest.default_enabled, + root: Some(root.to_path_buf()), + }; + let hooks = resolve_hooks(root, &manifest.hooks); + let lifecycle = resolve_lifecycle(root, &manifest.lifecycle); + let tools = resolve_tools(root, &metadata.id, &metadata.name, &manifest.tools); + Ok(match kind { + PluginKind::Builtin => PluginDefinition::Builtin(BuiltinPlugin { + metadata, + hooks, + lifecycle, + tools, + }), + PluginKind::Bundled => PluginDefinition::Bundled(BundledPlugin { + metadata, + hooks, + lifecycle, + tools, + }), + PluginKind::External => PluginDefinition::External(ExternalPlugin { + metadata, + hooks, + lifecycle, + tools, + }), + }) +} + +pub fn load_plugin_from_directory(root: &Path) -> Result<PluginManifest, PluginError> { + load_manifest_from_directory(root) +} + +fn load_manifest_from_directory(root: &Path) -> Result<PluginManifest, PluginError> { + let manifest_path = plugin_manifest_path(root)?; + load_manifest_from_path(root, &manifest_path) +} + +fn load_manifest_from_path( + root: &Path, + manifest_path: &Path, +) -> Result<PluginManifest, PluginError> { + let contents = fs::read_to_string(manifest_path).map_err(|error| { + PluginError::NotFound(format!( + "plugin manifest not found at {}: {error}", + manifest_path.display() + )) + })?; + let raw_manifest: RawPluginManifest = serde_json::from_str(&contents)?; + build_plugin_manifest(root, raw_manifest) +} + +fn plugin_manifest_path(root: &Path) -> Result<PathBuf, PluginError> { + let direct_path = root.join(MANIFEST_FILE_NAME); + if direct_path.exists() { + return Ok(direct_path); + } + + let packaged_path = root.join(MANIFEST_RELATIVE_PATH); + if packaged_path.exists() { + return Ok(packaged_path); + } + + Err(PluginError::NotFound(format!( + "plugin manifest not found at {} or {}", + direct_path.display(), + packaged_path.display() + ))) +} + +fn build_plugin_manifest( + root: &Path, + raw: RawPluginManifest, +) -> Result<PluginManifest, PluginError> { + let mut errors = Vec::new(); + + validate_required_manifest_field("name", &raw.name, &mut errors); + validate_required_manifest_field("version", &raw.version, &mut errors); + validate_required_manifest_field("description", &raw.description, &mut errors); + + let permissions = build_manifest_permissions(&raw.permissions, &mut errors); + validate_command_entries(root, raw.hooks.pre_tool_use.iter(), "hook", &mut errors); + validate_command_entries(root, raw.hooks.post_tool_use.iter(), "hook", &mut errors); + validate_command_entries( + root, + raw.lifecycle.init.iter(), + "lifecycle command", + &mut errors, + ); + validate_command_entries( + root, + raw.lifecycle.shutdown.iter(), + "lifecycle command", + &mut errors, + ); + let tools = build_manifest_tools(root, raw.tools, &mut errors); + let commands = build_manifest_commands(root, raw.commands, &mut errors); + + if !errors.is_empty() { + return Err(PluginError::ManifestValidation(errors)); + } + + Ok(PluginManifest { + name: raw.name, + version: raw.version, + description: raw.description, + permissions, + default_enabled: raw.default_enabled, + hooks: raw.hooks, + lifecycle: raw.lifecycle, + tools, + commands, + }) +} + +fn validate_required_manifest_field( + field: &'static str, + value: &str, + errors: &mut Vec<PluginManifestValidationError>, +) { + if value.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyField { field }); + } +} + +fn build_manifest_permissions( + permissions: &[String], + errors: &mut Vec<PluginManifestValidationError>, +) -> Vec<PluginPermission> { + let mut seen = BTreeSet::new(); + let mut validated = Vec::new(); + + for permission in permissions { + let permission = permission.trim(); + if permission.is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "permission", + field: "value", + name: None, + }); + continue; + } + if !seen.insert(permission.to_string()) { + errors.push(PluginManifestValidationError::DuplicatePermission { + permission: permission.to_string(), + }); + continue; + } + match PluginPermission::parse(permission) { + Some(permission) => validated.push(permission), + None => errors.push(PluginManifestValidationError::InvalidPermission { + permission: permission.to_string(), + }), + } + } + + validated +} + +fn build_manifest_tools( + root: &Path, + tools: Vec<RawPluginToolManifest>, + errors: &mut Vec<PluginManifestValidationError>, +) -> Vec<PluginToolManifest> { + let mut seen = BTreeSet::new(); + let mut validated = Vec::new(); + + for tool in tools { + let name = tool.name.trim().to_string(); + if name.is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "tool", + field: "name", + name: None, + }); + continue; + } + if !seen.insert(name.clone()) { + errors.push(PluginManifestValidationError::DuplicateEntry { kind: "tool", name }); + continue; + } + if tool.description.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "tool", + field: "description", + name: Some(name.clone()), + }); + } + if tool.command.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "tool", + field: "command", + name: Some(name.clone()), + }); + } else { + validate_command_entry(root, &tool.command, "tool", errors); + } + if !tool.input_schema.is_object() { + errors.push(PluginManifestValidationError::InvalidToolInputSchema { + tool_name: name.clone(), + }); + } + let Some(required_permission) = + PluginToolPermission::parse(tool.required_permission.trim()) + else { + errors.push( + PluginManifestValidationError::InvalidToolRequiredPermission { + tool_name: name.clone(), + permission: tool.required_permission.trim().to_string(), + }, + ); + continue; + }; + + validated.push(PluginToolManifest { + name, + description: tool.description, + input_schema: tool.input_schema, + command: tool.command, + args: tool.args, + required_permission, + }); + } + + validated +} + +fn build_manifest_commands( + root: &Path, + commands: Vec<PluginCommandManifest>, + errors: &mut Vec<PluginManifestValidationError>, +) -> Vec<PluginCommandManifest> { + let mut seen = BTreeSet::new(); + let mut validated = Vec::new(); + + for command in commands { + let name = command.name.trim().to_string(); + if name.is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "command", + field: "name", + name: None, + }); + continue; + } + if !seen.insert(name.clone()) { + errors.push(PluginManifestValidationError::DuplicateEntry { + kind: "command", + name, + }); + continue; + } + if command.description.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "command", + field: "description", + name: Some(name.clone()), + }); + } + if command.command.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind: "command", + field: "command", + name: Some(name.clone()), + }); + } else { + validate_command_entry(root, &command.command, "command", errors); + } + validated.push(command); + } + + validated +} + +fn validate_command_entries<'a>( + root: &Path, + entries: impl Iterator<Item = &'a String>, + kind: &'static str, + errors: &mut Vec<PluginManifestValidationError>, +) { + for entry in entries { + validate_command_entry(root, entry, kind, errors); + } +} + +fn validate_command_entry( + root: &Path, + entry: &str, + kind: &'static str, + errors: &mut Vec<PluginManifestValidationError>, +) { + if entry.trim().is_empty() { + errors.push(PluginManifestValidationError::EmptyEntryField { + kind, + field: "command", + name: None, + }); + return; + } + if is_literal_command(entry) { + return; + } + + let path = if Path::new(entry).is_absolute() { + PathBuf::from(entry) + } else { + root.join(entry) + }; + if !path.exists() { + errors.push(PluginManifestValidationError::MissingPath { kind, path }); + } +} + +fn resolve_hooks(root: &Path, hooks: &PluginHooks) -> PluginHooks { + PluginHooks { + pre_tool_use: hooks + .pre_tool_use + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + post_tool_use: hooks + .post_tool_use + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + } +} + +fn resolve_lifecycle(root: &Path, lifecycle: &PluginLifecycle) -> PluginLifecycle { + PluginLifecycle { + init: lifecycle + .init + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + shutdown: lifecycle + .shutdown + .iter() + .map(|entry| resolve_hook_entry(root, entry)) + .collect(), + } +} + +fn resolve_tools( + root: &Path, + plugin_id: &str, + plugin_name: &str, + tools: &[PluginToolManifest], +) -> Vec<PluginTool> { + tools + .iter() + .map(|tool| { + PluginTool::new( + plugin_id, + plugin_name, + PluginToolDefinition { + name: tool.name.clone(), + description: Some(tool.description.clone()), + input_schema: tool.input_schema.clone(), + }, + resolve_hook_entry(root, &tool.command), + tool.args.clone(), + tool.required_permission, + Some(root.to_path_buf()), + ) + }) + .collect() +} + +fn validate_hook_paths(root: Option<&Path>, hooks: &PluginHooks) -> Result<(), PluginError> { + let Some(root) = root else { + return Ok(()); + }; + for entry in hooks.pre_tool_use.iter().chain(hooks.post_tool_use.iter()) { + validate_command_path(root, entry, "hook")?; + } + Ok(()) +} + +fn validate_lifecycle_paths( + root: Option<&Path>, + lifecycle: &PluginLifecycle, +) -> Result<(), PluginError> { + let Some(root) = root else { + return Ok(()); + }; + for entry in lifecycle.init.iter().chain(lifecycle.shutdown.iter()) { + validate_command_path(root, entry, "lifecycle command")?; + } + Ok(()) +} + +fn validate_tool_paths(root: Option<&Path>, tools: &[PluginTool]) -> Result<(), PluginError> { + let Some(root) = root else { + return Ok(()); + }; + for tool in tools { + validate_command_path(root, &tool.command, "tool")?; + } + Ok(()) +} + +fn validate_command_path(root: &Path, entry: &str, kind: &str) -> Result<(), PluginError> { + if is_literal_command(entry) { + return Ok(()); + } + let path = if Path::new(entry).is_absolute() { + PathBuf::from(entry) + } else { + root.join(entry) + }; + if !path.exists() { + return Err(PluginError::InvalidManifest(format!( + "{kind} path `{}` does not exist", + path.display() + ))); + } + Ok(()) +} + +fn resolve_hook_entry(root: &Path, entry: &str) -> String { + if is_literal_command(entry) { + entry.to_string() + } else { + root.join(entry).display().to_string() + } +} + +fn is_literal_command(entry: &str) -> bool { + !entry.starts_with("./") && !entry.starts_with("../") && !Path::new(entry).is_absolute() +} + +fn run_lifecycle_commands( + metadata: &PluginMetadata, + lifecycle: &PluginLifecycle, + phase: &str, + commands: &[String], +) -> Result<(), PluginError> { + if lifecycle.is_empty() || commands.is_empty() { + return Ok(()); + } + + for command in commands { + let mut process = if Path::new(command).exists() { + if cfg!(windows) { + let mut process = Command::new("cmd"); + process.arg("/C").arg(command); + process + } else { + let mut process = Command::new("sh"); + process.arg(command); + process + } + } else if cfg!(windows) { + let mut process = Command::new("cmd"); + process.arg("/C").arg(command); + process + } else { + let mut process = Command::new("sh"); + process.arg("-lc").arg(command); + process + }; + if let Some(root) = &metadata.root { + process.current_dir(root); + } + let output = process.output()?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + return Err(PluginError::CommandFailed(format!( + "plugin `{}` {} failed for `{}`: {}", + metadata.id, + phase, + command, + if stderr.is_empty() { + format!("exit status {}", output.status) + } else { + stderr + } + ))); + } + } + + Ok(()) +} + +fn resolve_local_source(source: &str) -> Result<PathBuf, PluginError> { + let path = PathBuf::from(source); + if path.exists() { + Ok(path) + } else { + Err(PluginError::NotFound(format!( + "plugin source `{source}` was not found" + ))) + } +} + +fn parse_install_source(source: &str) -> Result<PluginInstallSource, PluginError> { + if source.starts_with("http://") + || source.starts_with("https://") + || source.starts_with("git@") + || Path::new(source) + .extension() + .is_some_and(|extension| extension.eq_ignore_ascii_case("git")) + { + Ok(PluginInstallSource::GitUrl { + url: source.to_string(), + }) + } else { + Ok(PluginInstallSource::LocalPath { + path: resolve_local_source(source)?, + }) + } +} + +fn materialize_source( + source: &PluginInstallSource, + temp_root: &Path, +) -> Result<PathBuf, PluginError> { + fs::create_dir_all(temp_root)?; + match source { + PluginInstallSource::LocalPath { path } => Ok(path.clone()), + PluginInstallSource::GitUrl { url } => { + let destination = temp_root.join(format!("plugin-{}", unix_time_ms())); + let output = Command::new("git") + .arg("clone") + .arg("--depth") + .arg("1") + .arg(url) + .arg(&destination) + .output()?; + if !output.status.success() { + return Err(PluginError::CommandFailed(format!( + "git clone failed for `{url}`: {}", + String::from_utf8_lossy(&output.stderr).trim() + ))); + } + Ok(destination) + } + } +} + +fn discover_plugin_dirs(root: &Path) -> Result<Vec<PathBuf>, PluginError> { + match fs::read_dir(root) { + Ok(entries) => { + let mut paths = Vec::new(); + for entry in entries { + let path = entry?.path(); + if path.is_dir() && plugin_manifest_path(&path).is_ok() { + paths.push(path); + } + } + paths.sort(); + Ok(paths) + } + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(Vec::new()), + Err(error) => Err(PluginError::Io(error)), + } +} + +fn plugin_id(name: &str, marketplace: &str) -> String { + format!("{name}@{marketplace}") +} + +fn sanitize_plugin_id(plugin_id: &str) -> String { + plugin_id + .chars() + .map(|ch| match ch { + '/' | '\\' | '@' | ':' => '-', + other => other, + }) + .collect() +} + +fn describe_install_source(source: &PluginInstallSource) -> String { + match source { + PluginInstallSource::LocalPath { path } => path.display().to_string(), + PluginInstallSource::GitUrl { url } => url.clone(), + } +} + +fn unix_time_ms() -> u128 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time should be after epoch") + .as_millis() +} + +fn copy_dir_all(source: &Path, destination: &Path) -> Result<(), PluginError> { + fs::create_dir_all(destination)?; + for entry in fs::read_dir(source)? { + let entry = entry?; + let target = destination.join(entry.file_name()); + if entry.file_type()?.is_dir() { + copy_dir_all(&entry.path(), &target)?; + } else { + fs::copy(entry.path(), target)?; + } + } + Ok(()) +} + +fn update_settings_json( + path: &Path, + mut update: impl FnMut(&mut Map<String, Value>), +) -> Result<(), PluginError> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + let mut root = match fs::read_to_string(path) { + Ok(contents) if !contents.trim().is_empty() => serde_json::from_str::<Value>(&contents)?, + Ok(_) => Value::Object(Map::new()), + Err(error) if error.kind() == std::io::ErrorKind::NotFound => Value::Object(Map::new()), + Err(error) => return Err(PluginError::Io(error)), + }; + + let object = root.as_object_mut().ok_or_else(|| { + PluginError::InvalidManifest(format!( + "settings file {} must contain a JSON object", + path.display() + )) + })?; + update(object); + fs::write(path, serde_json::to_string_pretty(&root)?)?; + Ok(()) +} + +fn ensure_object<'a>(root: &'a mut Map<String, Value>, key: &str) -> &'a mut Map<String, Value> { + if !root.get(key).is_some_and(Value::is_object) { + root.insert(key.to_string(), Value::Object(Map::new())); + } + root.get_mut(key) + .and_then(Value::as_object_mut) + .expect("object should exist") +} + +#[cfg(test)] +mod tests { + use super::*; + + fn temp_dir(label: &str) -> PathBuf { + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("time should be after epoch") + .as_nanos(); + std::env::temp_dir().join(format!("plugins-{label}-{nanos}")) + } + + fn write_file(path: &Path, contents: &str) { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).expect("parent dir"); + } + fs::write(path, contents).expect("write file"); + } + + fn write_loader_plugin(root: &Path) { + write_file( + root.join("hooks").join("pre.sh").as_path(), + "#!/bin/sh\nprintf 'pre'\n", + ); + write_file( + root.join("tools").join("echo-tool.sh").as_path(), + "#!/bin/sh\ncat\n", + ); + write_file( + root.join("commands").join("sync.sh").as_path(), + "#!/bin/sh\nprintf 'sync'\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "loader-demo", + "version": "1.2.3", + "description": "Manifest loader test plugin", + "permissions": ["read", "write"], + "hooks": { + "PreToolUse": ["./hooks/pre.sh"] + }, + "tools": [ + { + "name": "echo_tool", + "description": "Echoes JSON input", + "inputSchema": { + "type": "object" + }, + "command": "./tools/echo-tool.sh", + "requiredPermission": "workspace-write" + } + ], + "commands": [ + { + "name": "sync", + "description": "Sync command", + "command": "./commands/sync.sh" + } + ] +}"#, + ); + } + + fn write_external_plugin(root: &Path, name: &str, version: &str) { + write_file( + root.join("hooks").join("pre.sh").as_path(), + "#!/bin/sh\nprintf 'pre'\n", + ); + write_file( + root.join("hooks").join("post.sh").as_path(), + "#!/bin/sh\nprintf 'post'\n", + ); + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"test plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/pre.sh\"],\n \"PostToolUse\": [\"./hooks/post.sh\"]\n }}\n}}" + ) + .as_str(), + ); + } + + fn write_broken_plugin(root: &Path, name: &str) { + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"1.0.0\",\n \"description\": \"broken plugin\",\n \"hooks\": {{\n \"PreToolUse\": [\"./hooks/missing.sh\"]\n }}\n}}" + ) + .as_str(), + ); + } + + fn write_lifecycle_plugin(root: &Path, name: &str, version: &str) -> PathBuf { + let log_path = root.join("lifecycle.log"); + write_file( + root.join("lifecycle").join("init.sh").as_path(), + "#!/bin/sh\nprintf 'init\\n' >> lifecycle.log\n", + ); + write_file( + root.join("lifecycle").join("shutdown.sh").as_path(), + "#!/bin/sh\nprintf 'shutdown\\n' >> lifecycle.log\n", + ); + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"lifecycle plugin\",\n \"lifecycle\": {{\n \"Init\": [\"./lifecycle/init.sh\"],\n \"Shutdown\": [\"./lifecycle/shutdown.sh\"]\n }}\n}}" + ) + .as_str(), + ); + log_path + } + + fn write_tool_plugin(root: &Path, name: &str, version: &str) { + write_tool_plugin_with_name(root, name, version, "plugin_echo"); + } + + fn write_tool_plugin_with_name(root: &Path, name: &str, version: &str, tool_name: &str) { + let script_path = root.join("tools").join("echo-json.sh"); + write_file( + &script_path, + "#!/bin/sh\nINPUT=$(cat)\nprintf '{\"plugin\":\"%s\",\"tool\":\"%s\",\"input\":%s}\\n' \"$CLAW_PLUGIN_ID\" \"$CLAW_TOOL_NAME\" \"$INPUT\"\n", + ); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + + let mut permissions = fs::metadata(&script_path).expect("metadata").permissions(); + permissions.set_mode(0o755); + fs::set_permissions(&script_path, permissions).expect("chmod"); + } + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"tool plugin\",\n \"tools\": [\n {{\n \"name\": \"{tool_name}\",\n \"description\": \"Echo JSON input\",\n \"inputSchema\": {{\"type\": \"object\", \"properties\": {{\"message\": {{\"type\": \"string\"}}}}, \"required\": [\"message\"], \"additionalProperties\": false}},\n \"command\": \"./tools/echo-json.sh\",\n \"requiredPermission\": \"workspace-write\"\n }}\n ]\n}}" + ) + .as_str(), + ); + } + + fn write_bundled_plugin(root: &Path, name: &str, version: &str, default_enabled: bool) { + write_file( + root.join(MANIFEST_RELATIVE_PATH).as_path(), + format!( + "{{\n \"name\": \"{name}\",\n \"version\": \"{version}\",\n \"description\": \"bundled plugin\",\n \"defaultEnabled\": {}\n}}", + if default_enabled { "true" } else { "false" } + ) + .as_str(), + ); + } + + fn load_enabled_plugins(path: &Path) -> BTreeMap<String, bool> { + let contents = fs::read_to_string(path).expect("settings should exist"); + let root: Value = serde_json::from_str(&contents).expect("settings json"); + root.get("enabledPlugins") + .and_then(Value::as_object) + .map(|enabled_plugins| { + enabled_plugins + .iter() + .map(|(plugin_id, value)| { + ( + plugin_id.clone(), + value.as_bool().expect("plugin state should be a bool"), + ) + }) + .collect() + }) + .unwrap_or_default() + } + + #[test] + fn load_plugin_from_directory_validates_required_fields() { + let root = temp_dir("manifest-required"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{"name":"","version":"1.0.0","description":"desc"}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("empty name should fail"); + assert!(error.to_string().contains("name cannot be empty")); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_reads_root_manifest_and_validates_entries() { + let root = temp_dir("manifest-root"); + write_loader_plugin(&root); + + let manifest = load_plugin_from_directory(&root).expect("manifest should load"); + assert_eq!(manifest.name, "loader-demo"); + assert_eq!(manifest.version, "1.2.3"); + assert_eq!( + manifest + .permissions + .iter() + .map(|permission| permission.as_str()) + .collect::<Vec<_>>(), + vec!["read", "write"] + ); + assert_eq!(manifest.hooks.pre_tool_use, vec!["./hooks/pre.sh"]); + assert_eq!(manifest.tools.len(), 1); + assert_eq!(manifest.tools[0].name, "echo_tool"); + assert_eq!( + manifest.tools[0].required_permission, + PluginToolPermission::WorkspaceWrite + ); + assert_eq!(manifest.commands.len(), 1); + assert_eq!(manifest.commands[0].name, "sync"); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_supports_packaged_manifest_path() { + let root = temp_dir("manifest-packaged"); + write_external_plugin(&root, "packaged-demo", "1.0.0"); + + let manifest = load_plugin_from_directory(&root).expect("packaged manifest should load"); + assert_eq!(manifest.name, "packaged-demo"); + assert!(manifest.tools.is_empty()); + assert!(manifest.commands.is_empty()); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_defaults_optional_fields() { + let root = temp_dir("manifest-defaults"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "minimal", + "version": "0.1.0", + "description": "Minimal manifest" +}"#, + ); + + let manifest = load_plugin_from_directory(&root).expect("minimal manifest should load"); + assert!(manifest.permissions.is_empty()); + assert!(manifest.hooks.is_empty()); + assert!(manifest.tools.is_empty()); + assert!(manifest.commands.is_empty()); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_duplicate_permissions_and_commands() { + let root = temp_dir("manifest-duplicates"); + write_file( + root.join("commands").join("sync.sh").as_path(), + "#!/bin/sh\nprintf 'sync'\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "duplicate-manifest", + "version": "1.0.0", + "description": "Duplicate validation", + "permissions": ["read", "read"], + "commands": [ + {"name": "sync", "description": "Sync one", "command": "./commands/sync.sh"}, + {"name": "sync", "description": "Sync two", "command": "./commands/sync.sh"} + ] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("duplicates should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::DuplicatePermission { permission } + if permission == "read" + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::DuplicateEntry { kind, name } + if *kind == "command" && name == "sync" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_missing_tool_or_command_paths() { + let root = temp_dir("manifest-paths"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "missing-paths", + "version": "1.0.0", + "description": "Missing path validation", + "tools": [ + { + "name": "tool_one", + "description": "Missing tool script", + "inputSchema": {"type": "object"}, + "command": "./tools/missing.sh" + } + ] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("missing paths should fail"); + assert!(error.to_string().contains("does not exist")); + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_invalid_permissions() { + let root = temp_dir("manifest-invalid-permissions"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "invalid-permissions", + "version": "1.0.0", + "description": "Invalid permission validation", + "permissions": ["admin"] +}"#, + ); + + let error = load_plugin_from_directory(&root).expect_err("invalid permissions should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::InvalidPermission { permission } + if permission == "admin" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_rejects_invalid_tool_required_permission() { + let root = temp_dir("manifest-invalid-tool-permission"); + write_file( + root.join("tools").join("echo.sh").as_path(), + "#!/bin/sh\ncat\n", + ); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "invalid-tool-permission", + "version": "1.0.0", + "description": "Invalid tool permission validation", + "tools": [ + { + "name": "echo_tool", + "description": "Echo tool", + "inputSchema": {"type": "object"}, + "command": "./tools/echo.sh", + "requiredPermission": "admin" + } + ] +}"#, + ); + + let error = + load_plugin_from_directory(&root).expect_err("invalid tool permission should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::InvalidToolRequiredPermission { + tool_name, + permission + } if tool_name == "echo_tool" && permission == "admin" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn load_plugin_from_directory_accumulates_multiple_validation_errors() { + let root = temp_dir("manifest-multi-error"); + write_file( + root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "", + "version": "1.0.0", + "description": "", + "permissions": ["admin"], + "commands": [ + {"name": "", "description": "", "command": "./commands/missing.sh"} + ] +}"#, + ); + + let error = + load_plugin_from_directory(&root).expect_err("multiple manifest errors should fail"); + match error { + PluginError::ManifestValidation(errors) => { + assert!(errors.len() >= 4); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::EmptyField { field } if *field == "name" + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::EmptyField { field } + if *field == "description" + ))); + assert!(errors.iter().any(|error| matches!( + error, + PluginManifestValidationError::InvalidPermission { permission } + if permission == "admin" + ))); + } + other => panic!("expected manifest validation errors, got {other}"), + } + + let _ = fs::remove_dir_all(root); + } + + #[test] + fn discovers_builtin_and_bundled_plugins() { + let manager = PluginManager::new(PluginManagerConfig::new(temp_dir("discover"))); + let plugins = manager.list_plugins().expect("plugins should list"); + assert!(plugins + .iter() + .any(|plugin| plugin.metadata.kind == PluginKind::Builtin)); + assert!(plugins + .iter() + .any(|plugin| plugin.metadata.kind == PluginKind::Bundled)); + } + + #[test] + fn installs_enables_updates_and_uninstalls_external_plugins() { + let config_home = temp_dir("home"); + let source_root = temp_dir("source"); + write_external_plugin(&source_root, "demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install = manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + assert_eq!(install.plugin_id, "demo@external"); + assert!(manager + .list_plugins() + .expect("list plugins") + .iter() + .any(|plugin| plugin.metadata.id == "demo@external" && plugin.enabled)); + + let hooks = manager.aggregated_hooks().expect("hooks should aggregate"); + assert_eq!(hooks.pre_tool_use.len(), 1); + assert!(hooks.pre_tool_use[0].contains("pre.sh")); + + manager + .disable("demo@external") + .expect("disable should work"); + assert!(manager + .aggregated_hooks() + .expect("hooks after disable") + .is_empty()); + manager.enable("demo@external").expect("enable should work"); + + write_external_plugin(&source_root, "demo", "2.0.0"); + let update = manager.update("demo@external").expect("update should work"); + assert_eq!(update.old_version, "1.0.0"); + assert_eq!(update.new_version, "2.0.0"); + + manager + .uninstall("demo@external") + .expect("uninstall should work"); + assert!(!manager + .list_plugins() + .expect("list plugins") + .iter() + .any(|plugin| plugin.metadata.id == "demo@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn auto_installs_bundled_plugins_into_the_registry() { + let config_home = temp_dir("bundled-home"); + let bundled_root = temp_dir("bundled-root"); + write_bundled_plugin(&bundled_root.join("starter"), "starter", "0.1.0", false); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let manager = PluginManager::new(config); + + let installed = manager + .list_installed_plugins() + .expect("bundled plugins should auto-install"); + assert!(installed.iter().any(|plugin| { + plugin.metadata.id == "starter@bundled" + && plugin.metadata.kind == PluginKind::Bundled + && !plugin.enabled + })); + + let registry = manager.load_registry().expect("registry should exist"); + let record = registry + .plugins + .get("starter@bundled") + .expect("bundled plugin should be recorded"); + assert_eq!(record.kind, PluginKind::Bundled); + assert!(record.install_path.exists()); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn default_bundled_root_loads_repo_bundles_as_installed_plugins() { + let config_home = temp_dir("default-bundled-home"); + let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + + let installed = manager + .list_installed_plugins() + .expect("default bundled plugins should auto-install"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "example-bundled@bundled")); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "sample-hooks@bundled")); + + let _ = fs::remove_dir_all(config_home); + } + + #[test] + fn bundled_sync_prunes_removed_bundled_registry_entries() { + let config_home = temp_dir("bundled-prune-home"); + let bundled_root = temp_dir("bundled-prune-root"); + let stale_install_path = config_home + .join("plugins") + .join("installed") + .join("stale-bundled-external"); + write_bundled_plugin(&bundled_root.join("active"), "active", "0.1.0", false); + write_file( + stale_install_path.join(MANIFEST_RELATIVE_PATH).as_path(), + r#"{ + "name": "stale", + "version": "0.1.0", + "description": "stale bundled plugin" +}"#, + ); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(config_home.join("plugins").join("installed")); + let manager = PluginManager::new(config); + + let mut registry = InstalledPluginRegistry::default(); + registry.plugins.insert( + "stale@bundled".to_string(), + InstalledPluginRecord { + kind: PluginKind::Bundled, + id: "stale@bundled".to_string(), + name: "stale".to_string(), + version: "0.1.0".to_string(), + description: "stale bundled plugin".to_string(), + install_path: stale_install_path.clone(), + source: PluginInstallSource::LocalPath { + path: bundled_root.join("stale"), + }, + installed_at_unix_ms: 1, + updated_at_unix_ms: 1, + }, + ); + manager.store_registry(®istry).expect("store registry"); + manager + .write_enabled_state("stale@bundled", Some(true)) + .expect("seed bundled enabled state"); + + let installed = manager + .list_installed_plugins() + .expect("bundled sync should succeed"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "active@bundled")); + assert!(!installed + .iter() + .any(|plugin| plugin.metadata.id == "stale@bundled")); + + let registry = manager.load_registry().expect("load registry"); + assert!(!registry.plugins.contains_key("stale@bundled")); + assert!(!stale_install_path.exists()); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn installed_plugin_discovery_keeps_registry_entries_outside_install_root() { + let config_home = temp_dir("registry-fallback-home"); + let bundled_root = temp_dir("registry-fallback-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let external_install_path = temp_dir("registry-fallback-external"); + write_file( + external_install_path.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "registry-fallback", + "version": "1.0.0", + "description": "Registry fallback plugin" +}"#, + ); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root.clone()); + let manager = PluginManager::new(config); + + let mut registry = InstalledPluginRegistry::default(); + registry.plugins.insert( + "registry-fallback@external".to_string(), + InstalledPluginRecord { + kind: PluginKind::External, + id: "registry-fallback@external".to_string(), + name: "registry-fallback".to_string(), + version: "1.0.0".to_string(), + description: "Registry fallback plugin".to_string(), + install_path: external_install_path.clone(), + source: PluginInstallSource::LocalPath { + path: external_install_path.clone(), + }, + installed_at_unix_ms: 1, + updated_at_unix_ms: 1, + }, + ); + manager.store_registry(®istry).expect("store registry"); + manager + .write_enabled_state("stale-external@external", Some(true)) + .expect("seed stale external enabled state"); + + let installed = manager + .list_installed_plugins() + .expect("registry fallback plugin should load"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "registry-fallback@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + let _ = fs::remove_dir_all(external_install_path); + } + + #[test] + fn installed_plugin_discovery_prunes_stale_registry_entries() { + let config_home = temp_dir("registry-prune-home"); + let bundled_root = temp_dir("registry-prune-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let missing_install_path = temp_dir("registry-prune-missing"); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root); + let manager = PluginManager::new(config); + + let mut registry = InstalledPluginRegistry::default(); + registry.plugins.insert( + "stale-external@external".to_string(), + InstalledPluginRecord { + kind: PluginKind::External, + id: "stale-external@external".to_string(), + name: "stale-external".to_string(), + version: "1.0.0".to_string(), + description: "stale external plugin".to_string(), + install_path: missing_install_path.clone(), + source: PluginInstallSource::LocalPath { + path: missing_install_path.clone(), + }, + installed_at_unix_ms: 1, + updated_at_unix_ms: 1, + }, + ); + manager.store_registry(®istry).expect("store registry"); + + let installed = manager + .list_installed_plugins() + .expect("stale registry entries should be pruned"); + assert!(!installed + .iter() + .any(|plugin| plugin.metadata.id == "stale-external@external")); + + let registry = manager.load_registry().expect("load registry"); + assert!(!registry.plugins.contains_key("stale-external@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn persists_bundled_plugin_enable_state_across_reloads() { + let config_home = temp_dir("bundled-state-home"); + let bundled_root = temp_dir("bundled-state-root"); + write_bundled_plugin(&bundled_root.join("starter"), "starter", "0.1.0", false); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let mut manager = PluginManager::new(config.clone()); + + manager + .enable("starter@bundled") + .expect("enable bundled plugin should succeed"); + assert_eq!( + load_enabled_plugins(&manager.settings_path()).get("starter@bundled"), + Some(&true) + ); + + let mut reloaded_config = PluginManagerConfig::new(&config_home); + reloaded_config.bundled_root = Some(bundled_root.clone()); + reloaded_config.enabled_plugins = load_enabled_plugins(&manager.settings_path()); + let reloaded_manager = PluginManager::new(reloaded_config); + let reloaded = reloaded_manager + .list_installed_plugins() + .expect("bundled plugins should still be listed"); + assert!(reloaded + .iter() + .any(|plugin| { plugin.metadata.id == "starter@bundled" && plugin.enabled })); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn persists_bundled_plugin_disable_state_across_reloads() { + let config_home = temp_dir("bundled-disabled-home"); + let bundled_root = temp_dir("bundled-disabled-root"); + write_bundled_plugin(&bundled_root.join("starter"), "starter", "0.1.0", true); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + let mut manager = PluginManager::new(config); + + manager + .disable("starter@bundled") + .expect("disable bundled plugin should succeed"); + assert_eq!( + load_enabled_plugins(&manager.settings_path()).get("starter@bundled"), + Some(&false) + ); + + let mut reloaded_config = PluginManagerConfig::new(&config_home); + reloaded_config.bundled_root = Some(bundled_root.clone()); + reloaded_config.enabled_plugins = load_enabled_plugins(&manager.settings_path()); + let reloaded_manager = PluginManager::new(reloaded_config); + let reloaded = reloaded_manager + .list_installed_plugins() + .expect("bundled plugins should still be listed"); + assert!(reloaded + .iter() + .any(|plugin| { plugin.metadata.id == "starter@bundled" && !plugin.enabled })); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn validates_plugin_source_before_install() { + let config_home = temp_dir("validate-home"); + let source_root = temp_dir("validate-source"); + write_external_plugin(&source_root, "validator", "1.0.0"); + let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let manifest = manager + .validate_plugin_source(source_root.to_str().expect("utf8 path")) + .expect("manifest should validate"); + assert_eq!(manifest.name, "validator"); + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn plugin_registry_tracks_enabled_state_and_lookup() { + let config_home = temp_dir("registry-home"); + let source_root = temp_dir("registry-source"); + write_external_plugin(&source_root, "registry-demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + manager + .disable("registry-demo@external") + .expect("disable should succeed"); + + let registry = manager.plugin_registry().expect("registry should build"); + let plugin = registry + .get("registry-demo@external") + .expect("installed plugin should be discoverable"); + assert_eq!(plugin.metadata().name, "registry-demo"); + assert!(!plugin.is_enabled()); + assert!(registry.contains("registry-demo@external")); + assert!(!registry.contains("missing@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn rejects_plugin_sources_with_missing_hook_paths() { + let config_home = temp_dir("broken-home"); + let source_root = temp_dir("broken-source"); + write_broken_plugin(&source_root, "broken"); + + let manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let error = manager + .validate_plugin_source(source_root.to_str().expect("utf8 path")) + .expect_err("missing hook file should fail validation"); + assert!(error.to_string().contains("does not exist")); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install_error = manager + .install(source_root.to_str().expect("utf8 path")) + .expect_err("install should reject invalid hook paths"); + assert!(install_error.to_string().contains("does not exist")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn plugin_registry_runs_initialize_and_shutdown_for_enabled_plugins() { + let config_home = temp_dir("lifecycle-home"); + let source_root = temp_dir("lifecycle-source"); + let _ = write_lifecycle_plugin(&source_root, "lifecycle-demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + let install = manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + let log_path = install.install_path.join("lifecycle.log"); + + let registry = manager.plugin_registry().expect("registry should build"); + registry.initialize().expect("init should succeed"); + registry.shutdown().expect("shutdown should succeed"); + + let log = fs::read_to_string(&log_path).expect("lifecycle log should exist"); + assert_eq!(log, "init\nshutdown\n"); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn aggregates_and_executes_plugin_tools() { + let config_home = temp_dir("tool-home"); + let source_root = temp_dir("tool-source"); + write_tool_plugin(&source_root, "tool-demo", "1.0.0"); + + let mut manager = PluginManager::new(PluginManagerConfig::new(&config_home)); + manager + .install(source_root.to_str().expect("utf8 path")) + .expect("install should succeed"); + + let tools = manager.aggregated_tools().expect("tools should aggregate"); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].definition().name, "plugin_echo"); + assert_eq!(tools[0].required_permission(), "workspace-write"); + + let output = tools[0] + .execute(&serde_json::json!({ "message": "hello" })) + .expect("plugin tool should execute"); + let payload: Value = serde_json::from_str(&output).expect("valid json"); + assert_eq!(payload["plugin"], "tool-demo@external"); + assert_eq!(payload["tool"], "plugin_echo"); + assert_eq!(payload["input"]["message"], "hello"); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(source_root); + } + + #[test] + fn list_installed_plugins_scans_install_root_without_registry_entries() { + let config_home = temp_dir("installed-scan-home"); + let bundled_root = temp_dir("installed-scan-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let installed_plugin_root = install_root.join("scan-demo"); + write_file( + installed_plugin_root.join(MANIFEST_FILE_NAME).as_path(), + r#"{ + "name": "scan-demo", + "version": "1.0.0", + "description": "Scanned from install root" +}"#, + ); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root); + let manager = PluginManager::new(config); + + let installed = manager + .list_installed_plugins() + .expect("installed plugins should scan directories"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "scan-demo@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } + + #[test] + fn list_installed_plugins_scans_packaged_manifests_in_install_root() { + let config_home = temp_dir("installed-packaged-scan-home"); + let bundled_root = temp_dir("installed-packaged-scan-bundled"); + let install_root = config_home.join("plugins").join("installed"); + let installed_plugin_root = install_root.join("scan-packaged"); + write_file( + installed_plugin_root.join(MANIFEST_RELATIVE_PATH).as_path(), + r#"{ + "name": "scan-packaged", + "version": "1.0.0", + "description": "Packaged manifest in install root" +}"#, + ); + + let mut config = PluginManagerConfig::new(&config_home); + config.bundled_root = Some(bundled_root.clone()); + config.install_root = Some(install_root); + let manager = PluginManager::new(config); + + let installed = manager + .list_installed_plugins() + .expect("installed plugins should scan packaged manifests"); + assert!(installed + .iter() + .any(|plugin| plugin.metadata.id == "scan-packaged@external")); + + let _ = fs::remove_dir_all(config_home); + let _ = fs::remove_dir_all(bundled_root); + } +} diff --git a/rust/crates/runtime/Cargo.toml b/rust/crates/runtime/Cargo.toml index 7ce7cd8..025cd03 100644 --- a/rust/crates/runtime/Cargo.toml +++ b/rust/crates/runtime/Cargo.toml @@ -8,9 +8,11 @@ publish.workspace = true [dependencies] sha2 = "0.10" glob = "0.3" +lsp = { path = "../lsp" } +plugins = { path = "../plugins" } regex = "1" serde = { version = "1", features = ["derive"] } -serde_json = "1" +serde_json.workspace = true tokio = { version = "1", features = ["io-util", "macros", "process", "rt", "rt-multi-thread", "time"] } walkdir = "2" diff --git a/rust/crates/runtime/src/bootstrap.rs b/rust/crates/runtime/src/bootstrap.rs index dfc99ab..760f27e 100644 --- a/rust/crates/runtime/src/bootstrap.rs +++ b/rust/crates/runtime/src/bootstrap.rs @@ -21,7 +21,7 @@ pub struct BootstrapPlan { impl BootstrapPlan { #[must_use] - pub fn claude_code_default() -> Self { + pub fn claw_default() -> Self { Self::from_phases(vec![ BootstrapPhase::CliEntry, BootstrapPhase::FastPathVersion, diff --git a/rust/crates/runtime/src/compact.rs b/rust/crates/runtime/src/compact.rs index e227019..a0792da 100644 --- a/rust/crates/runtime/src/compact.rs +++ b/rust/crates/runtime/src/compact.rs @@ -1,5 +1,10 @@ use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; +const COMPACT_CONTINUATION_PREAMBLE: &str = + "This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.\n\n"; +const COMPACT_RECENT_MESSAGES_NOTE: &str = "Recent messages are preserved verbatim."; +const COMPACT_DIRECT_RESUME_INSTRUCTION: &str = "Continue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text."; + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct CompactionConfig { pub preserve_recent_messages: usize, @@ -30,8 +35,15 @@ pub fn estimate_session_tokens(session: &Session) -> usize { #[must_use] pub fn should_compact(session: &Session, config: CompactionConfig) -> bool { - session.messages.len() > config.preserve_recent_messages - && estimate_session_tokens(session) >= config.max_estimated_tokens + let start = compacted_summary_prefix_len(session); + let compactable = &session.messages[start..]; + + compactable.len() > config.preserve_recent_messages + && compactable + .iter() + .map(estimate_message_tokens) + .sum::<usize>() + >= config.max_estimated_tokens } #[must_use] @@ -56,16 +68,18 @@ pub fn get_compact_continuation_message( recent_messages_preserved: bool, ) -> String { let mut base = format!( - "This session is being continued from a previous conversation that ran out of context. The summary below covers the earlier portion of the conversation.\n\n{}", + "{COMPACT_CONTINUATION_PREAMBLE}{}", format_compact_summary(summary) ); if recent_messages_preserved { - base.push_str("\n\nRecent messages are preserved verbatim."); + base.push_str("\n\n"); + base.push_str(COMPACT_RECENT_MESSAGES_NOTE); } if suppress_follow_up_questions { - base.push_str("\nContinue the conversation from where it left off without asking the user any further questions. Resume directly — do not acknowledge the summary, do not recap what was happening, and do not preface with continuation text."); + base.push('\n'); + base.push_str(COMPACT_DIRECT_RESUME_INSTRUCTION); } base @@ -82,13 +96,19 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio }; } + let existing_summary = session + .messages + .first() + .and_then(extract_existing_compacted_summary); + let compacted_prefix_len = usize::from(existing_summary.is_some()); let keep_from = session .messages .len() .saturating_sub(config.preserve_recent_messages); - let removed = &session.messages[..keep_from]; + let removed = &session.messages[compacted_prefix_len..keep_from]; let preserved = session.messages[keep_from..].to_vec(); - let summary = summarize_messages(removed); + let summary = + merge_compact_summaries(existing_summary.as_deref(), &summarize_messages(removed)); let formatted_summary = format_compact_summary(&summary); let continuation = get_compact_continuation_message(&summary, true, !preserved.is_empty()); @@ -110,6 +130,16 @@ pub fn compact_session(session: &Session, config: CompactionConfig) -> Compactio } } +fn compacted_summary_prefix_len(session: &Session) -> usize { + usize::from( + session + .messages + .first() + .and_then(extract_existing_compacted_summary) + .is_some(), + ) +} + fn summarize_messages(messages: &[ConversationMessage]) -> String { let user_messages = messages .iter() @@ -197,6 +227,41 @@ fn summarize_messages(messages: &[ConversationMessage]) -> String { lines.join("\n") } +fn merge_compact_summaries(existing_summary: Option<&str>, new_summary: &str) -> String { + let Some(existing_summary) = existing_summary else { + return new_summary.to_string(); + }; + + let previous_highlights = extract_summary_highlights(existing_summary); + let new_formatted_summary = format_compact_summary(new_summary); + let new_highlights = extract_summary_highlights(&new_formatted_summary); + let new_timeline = extract_summary_timeline(&new_formatted_summary); + + let mut lines = vec!["<summary>".to_string(), "Conversation summary:".to_string()]; + + if !previous_highlights.is_empty() { + lines.push("- Previously compacted context:".to_string()); + lines.extend( + previous_highlights + .into_iter() + .map(|line| format!(" {line}")), + ); + } + + if !new_highlights.is_empty() { + lines.push("- Newly compacted context:".to_string()); + lines.extend(new_highlights.into_iter().map(|line| format!(" {line}"))); + } + + if !new_timeline.is_empty() { + lines.push("- Key timeline:".to_string()); + lines.extend(new_timeline.into_iter().map(|line| format!(" {line}"))); + } + + lines.push("</summary>".to_string()); + lines.join("\n") +} + fn summarize_block(block: &ContentBlock) -> String { let raw = match block { ContentBlock::Text { text } => text.clone(), @@ -374,11 +439,71 @@ fn collapse_blank_lines(content: &str) -> String { result } +fn extract_existing_compacted_summary(message: &ConversationMessage) -> Option<String> { + if message.role != MessageRole::System { + return None; + } + + let text = first_text_block(message)?; + let summary = text.strip_prefix(COMPACT_CONTINUATION_PREAMBLE)?; + let summary = summary + .split_once(&format!("\n\n{COMPACT_RECENT_MESSAGES_NOTE}")) + .map_or(summary, |(value, _)| value); + let summary = summary + .split_once(&format!("\n{COMPACT_DIRECT_RESUME_INSTRUCTION}")) + .map_or(summary, |(value, _)| value); + Some(summary.trim().to_string()) +} + +fn extract_summary_highlights(summary: &str) -> Vec<String> { + let mut lines = Vec::new(); + let mut in_timeline = false; + + for line in format_compact_summary(summary).lines() { + let trimmed = line.trim_end(); + if trimmed.is_empty() || trimmed == "Summary:" || trimmed == "Conversation summary:" { + continue; + } + if trimmed == "- Key timeline:" { + in_timeline = true; + continue; + } + if in_timeline { + continue; + } + lines.push(trimmed.to_string()); + } + + lines +} + +fn extract_summary_timeline(summary: &str) -> Vec<String> { + let mut lines = Vec::new(); + let mut in_timeline = false; + + for line in format_compact_summary(summary).lines() { + let trimmed = line.trim_end(); + if trimmed == "- Key timeline:" { + in_timeline = true; + continue; + } + if !in_timeline { + continue; + } + if trimmed.is_empty() { + break; + } + lines.push(trimmed.to_string()); + } + + lines +} + #[cfg(test)] mod tests { use super::{ collect_key_files, compact_session, estimate_session_tokens, format_compact_summary, - infer_pending_work, should_compact, CompactionConfig, + get_compact_continuation_message, infer_pending_work, should_compact, CompactionConfig, }; use crate::session::{ContentBlock, ConversationMessage, MessageRole, Session}; @@ -453,6 +578,98 @@ mod tests { ); } + #[test] + fn keeps_previous_compacted_context_when_compacting_again() { + let initial_session = Session { + version: 1, + messages: vec![ + ConversationMessage::user_text("Investigate rust/crates/runtime/src/compact.rs"), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "I will inspect the compact flow.".to_string(), + }]), + ConversationMessage::user_text( + "Also update rust/crates/runtime/src/conversation.rs", + ), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "Next: preserve prior summary context during auto compact.".to_string(), + }]), + ], + }; + let config = CompactionConfig { + preserve_recent_messages: 2, + max_estimated_tokens: 1, + }; + + let first = compact_session(&initial_session, config); + let mut follow_up_messages = first.compacted_session.messages.clone(); + follow_up_messages.extend([ + ConversationMessage::user_text("Please add regression tests for compaction."), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "Working on regression coverage now.".to_string(), + }]), + ]); + + let second = compact_session( + &Session { + version: 1, + messages: follow_up_messages, + }, + config, + ); + + assert!(second + .formatted_summary + .contains("Previously compacted context:")); + assert!(second + .formatted_summary + .contains("Scope: 2 earlier messages compacted")); + assert!(second + .formatted_summary + .contains("Newly compacted context:")); + assert!(second + .formatted_summary + .contains("Also update rust/crates/runtime/src/conversation.rs")); + assert!(matches!( + &second.compacted_session.messages[0].blocks[0], + ContentBlock::Text { text } + if text.contains("Previously compacted context:") + && text.contains("Newly compacted context:") + )); + assert!(matches!( + &second.compacted_session.messages[1].blocks[0], + ContentBlock::Text { text } if text.contains("Please add regression tests for compaction.") + )); + } + + #[test] + fn ignores_existing_compacted_summary_when_deciding_to_recompact() { + let summary = "<summary>Conversation summary:\n- Scope: earlier work preserved.\n- Key timeline:\n - user: large preserved context\n</summary>"; + let session = Session { + version: 1, + messages: vec![ + ConversationMessage { + role: MessageRole::System, + blocks: vec![ContentBlock::Text { + text: get_compact_continuation_message(summary, true, true), + }], + usage: None, + }, + ConversationMessage::user_text("tiny"), + ConversationMessage::assistant(vec![ContentBlock::Text { + text: "recent".to_string(), + }]), + ], + }; + + assert!(!should_compact( + &session, + CompactionConfig { + preserve_recent_messages: 2, + max_estimated_tokens: 1, + } + )); + } + #[test] fn truncates_long_blocks_in_summary() { let summary = super::summarize_block(&ContentBlock::Text { @@ -465,10 +682,10 @@ mod tests { #[test] fn extracts_key_files_from_message_content() { let files = collect_key_files(&[ConversationMessage::user_text( - "Update rust/crates/runtime/src/compact.rs and rust/crates/rusty-claude-cli/src/main.rs next.", + "Update rust/crates/runtime/src/compact.rs and rust/crates/tools/src/lib.rs next.", )]); assert!(files.contains(&"rust/crates/runtime/src/compact.rs".to_string())); - assert!(files.contains(&"rust/crates/rusty-claude-cli/src/main.rs".to_string())); + assert!(files.contains(&"rust/crates/tools/src/lib.rs".to_string())); } #[test] diff --git a/rust/crates/runtime/src/config.rs b/rust/crates/runtime/src/config.rs index 60ef53f..11ec21d 100644 --- a/rust/crates/runtime/src/config.rs +++ b/rust/crates/runtime/src/config.rs @@ -6,7 +6,7 @@ use std::path::{Path, PathBuf}; use crate::json::JsonValue; use crate::sandbox::{FilesystemIsolationMode, SandboxConfig}; -pub const CLAUDE_CODE_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema"; +pub const CLAW_SETTINGS_SCHEMA_NAME: &str = "SettingsSchema"; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum ConfigSource { @@ -35,8 +35,19 @@ pub struct RuntimeConfig { feature_config: RuntimeFeatureConfig, } +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimePluginConfig { + enabled_plugins: BTreeMap<String, bool>, + external_directories: Vec<String>, + install_root: Option<String>, + registry_path: Option<String>, + bundled_root: Option<String>, +} + #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct RuntimeFeatureConfig { + hooks: RuntimeHookConfig, + plugins: RuntimePluginConfig, mcp: McpConfigCollection, oauth: Option<OAuthConfig>, model: Option<String>, @@ -44,6 +55,12 @@ pub struct RuntimeFeatureConfig { sandbox: SandboxConfig, } +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct RuntimeHookConfig { + pre_tool_use: Vec<String>, + post_tool_use: Vec<String>, +} + #[derive(Debug, Clone, PartialEq, Eq, Default)] pub struct McpConfigCollection { servers: BTreeMap<String, ScopedMcpServerConfig>, @@ -62,7 +79,7 @@ pub enum McpTransport { Http, Ws, Sdk, - ClaudeAiProxy, + ManagedProxy, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -72,7 +89,7 @@ pub enum McpServerConfig { Http(McpRemoteServerConfig), Ws(McpWebSocketServerConfig), Sdk(McpSdkServerConfig), - ClaudeAiProxy(McpClaudeAiProxyServerConfig), + ManagedProxy(McpManagedProxyServerConfig), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -103,7 +120,7 @@ pub struct McpSdkServerConfig { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct McpClaudeAiProxyServerConfig { +pub struct McpManagedProxyServerConfig { pub url: String, pub id: String, } @@ -167,25 +184,20 @@ impl ConfigLoader { #[must_use] pub fn default_for(cwd: impl Into<PathBuf>) -> Self { let cwd = cwd.into(); - let config_home = std::env::var_os("CLAUDE_CONFIG_HOME") - .map(PathBuf::from) - .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".claude"))) - .or_else(|| { - if cfg!(target_os = "windows") { - std::env::var_os("USERPROFILE").map(|home| PathBuf::from(home).join(".claude")) - } else { - None - } - }) - .unwrap_or_else(|| PathBuf::from(".claude")); + let config_home = default_config_home(); Self { cwd, config_home } } + #[must_use] + pub fn config_home(&self) -> &Path { + &self.config_home + } + #[must_use] pub fn discover(&self) -> Vec<ConfigEntry> { let user_legacy_path = self.config_home.parent().map_or_else( - || PathBuf::from(".claude.json"), - |parent| parent.join(".claude.json"), + || PathBuf::from(".claw.json"), + |parent| parent.join(".claw.json"), ); vec![ ConfigEntry { @@ -198,15 +210,15 @@ impl ConfigLoader { }, ConfigEntry { source: ConfigSource::Project, - path: self.cwd.join(".claude.json"), + path: self.cwd.join(".claw.json"), }, ConfigEntry { source: ConfigSource::Project, - path: self.cwd.join(".claude").join("settings.json"), + path: self.cwd.join(".claw").join("settings.json"), }, ConfigEntry { source: ConfigSource::Local, - path: self.cwd.join(".claude").join("settings.local.json"), + path: self.cwd.join(".claw").join("settings.local.json"), }, ] } @@ -228,6 +240,8 @@ impl ConfigLoader { let merged_value = JsonValue::Object(merged.clone()); let feature_config = RuntimeFeatureConfig { + hooks: parse_optional_hooks_config(&merged_value)?, + plugins: parse_optional_plugin_config(&merged_value)?, mcp: McpConfigCollection { servers: mcp_servers, }, @@ -285,6 +299,16 @@ impl RuntimeConfig { &self.feature_config.mcp } + #[must_use] + pub fn hooks(&self) -> &RuntimeHookConfig { + &self.feature_config.hooks + } + + #[must_use] + pub fn plugins(&self) -> &RuntimePluginConfig { + &self.feature_config.plugins + } + #[must_use] pub fn oauth(&self) -> Option<&OAuthConfig> { self.feature_config.oauth.as_ref() @@ -307,6 +331,28 @@ impl RuntimeConfig { } impl RuntimeFeatureConfig { + #[must_use] + pub fn with_hooks(mut self, hooks: RuntimeHookConfig) -> Self { + self.hooks = hooks; + self + } + + #[must_use] + pub fn with_plugins(mut self, plugins: RuntimePluginConfig) -> Self { + self.plugins = plugins; + self + } + + #[must_use] + pub fn hooks(&self) -> &RuntimeHookConfig { + &self.hooks + } + + #[must_use] + pub fn plugins(&self) -> &RuntimePluginConfig { + &self.plugins + } + #[must_use] pub fn mcp(&self) -> &McpConfigCollection { &self.mcp @@ -333,6 +379,85 @@ impl RuntimeFeatureConfig { } } +impl RuntimePluginConfig { + #[must_use] + pub fn enabled_plugins(&self) -> &BTreeMap<String, bool> { + &self.enabled_plugins + } + + #[must_use] + pub fn external_directories(&self) -> &[String] { + &self.external_directories + } + + #[must_use] + pub fn install_root(&self) -> Option<&str> { + self.install_root.as_deref() + } + + #[must_use] + pub fn registry_path(&self) -> Option<&str> { + self.registry_path.as_deref() + } + + #[must_use] + pub fn bundled_root(&self) -> Option<&str> { + self.bundled_root.as_deref() + } + + pub fn set_plugin_state(&mut self, plugin_id: String, enabled: bool) { + self.enabled_plugins.insert(plugin_id, enabled); + } + + #[must_use] + pub fn state_for(&self, plugin_id: &str, default_enabled: bool) -> bool { + self.enabled_plugins + .get(plugin_id) + .copied() + .unwrap_or(default_enabled) + } +} + +#[must_use] +pub fn default_config_home() -> PathBuf { + std::env::var_os("CLAW_CONFIG_HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("HOME").map(|home| PathBuf::from(home).join(".claw"))) + .unwrap_or_else(|| PathBuf::from(".claw")) +} + +impl RuntimeHookConfig { + #[must_use] + pub fn new(pre_tool_use: Vec<String>, post_tool_use: Vec<String>) -> Self { + Self { + pre_tool_use, + post_tool_use, + } + } + + #[must_use] + pub fn pre_tool_use(&self) -> &[String] { + &self.pre_tool_use + } + + #[must_use] + pub fn post_tool_use(&self) -> &[String] { + &self.post_tool_use + } + + #[must_use] + pub fn merged(&self, other: &Self) -> Self { + let mut merged = self.clone(); + merged.extend(other); + merged + } + + pub fn extend(&mut self, other: &Self) { + extend_unique(&mut self.pre_tool_use, other.pre_tool_use()); + extend_unique(&mut self.post_tool_use, other.post_tool_use()); + } +} + impl McpConfigCollection { #[must_use] pub fn servers(&self) -> &BTreeMap<String, ScopedMcpServerConfig> { @@ -361,7 +486,7 @@ impl McpServerConfig { Self::Http(_) => McpTransport::Http, Self::Ws(_) => McpTransport::Ws, Self::Sdk(_) => McpTransport::Sdk, - Self::ClaudeAiProxy(_) => McpTransport::ClaudeAiProxy, + Self::ManagedProxy(_) => McpTransport::ManagedProxy, } } } @@ -369,7 +494,7 @@ impl McpServerConfig { fn read_optional_json_object( path: &Path, ) -> Result<Option<BTreeMap<String, JsonValue>>, ConfigError> { - let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claude.json"); + let is_legacy_config = path.file_name().and_then(|name| name.to_str()) == Some(".claw.json"); let contents = match fs::read_to_string(path) { Ok(contents) => contents, Err(error) if error.kind() == std::io::ErrorKind::NotFound => return Ok(None), @@ -431,6 +556,52 @@ fn parse_optional_model(root: &JsonValue) -> Option<String> { .map(ToOwned::to_owned) } +fn parse_optional_hooks_config(root: &JsonValue) -> Result<RuntimeHookConfig, ConfigError> { + let Some(object) = root.as_object() else { + return Ok(RuntimeHookConfig::default()); + }; + let Some(hooks_value) = object.get("hooks") else { + return Ok(RuntimeHookConfig::default()); + }; + let hooks = expect_object(hooks_value, "merged settings.hooks")?; + Ok(RuntimeHookConfig { + pre_tool_use: optional_string_array(hooks, "PreToolUse", "merged settings.hooks")? + .unwrap_or_default(), + post_tool_use: optional_string_array(hooks, "PostToolUse", "merged settings.hooks")? + .unwrap_or_default(), + }) +} + +fn parse_optional_plugin_config(root: &JsonValue) -> Result<RuntimePluginConfig, ConfigError> { + let Some(object) = root.as_object() else { + return Ok(RuntimePluginConfig::default()); + }; + + let mut config = RuntimePluginConfig::default(); + if let Some(enabled_plugins) = object.get("enabledPlugins") { + config.enabled_plugins = parse_bool_map(enabled_plugins, "merged settings.enabledPlugins")?; + } + + let Some(plugins_value) = object.get("plugins") else { + return Ok(config); + }; + let plugins = expect_object(plugins_value, "merged settings.plugins")?; + + if let Some(enabled_value) = plugins.get("enabled") { + config.enabled_plugins = parse_bool_map(enabled_value, "merged settings.plugins.enabled")?; + } + config.external_directories = + optional_string_array(plugins, "externalDirectories", "merged settings.plugins")? + .unwrap_or_default(); + config.install_root = + optional_string(plugins, "installRoot", "merged settings.plugins")?.map(str::to_string); + config.registry_path = + optional_string(plugins, "registryPath", "merged settings.plugins")?.map(str::to_string); + config.bundled_root = + optional_string(plugins, "bundledRoot", "merged settings.plugins")?.map(str::to_string); + Ok(config) +} + fn parse_optional_permission_mode( root: &JsonValue, ) -> Result<Option<ResolvedPermissionMode>, ConfigError> { @@ -553,12 +724,10 @@ fn parse_mcp_server_config( "sdk" => Ok(McpServerConfig::Sdk(McpSdkServerConfig { name: expect_string(object, "name", context)?.to_string(), })), - "claudeai-proxy" => Ok(McpServerConfig::ClaudeAiProxy( - McpClaudeAiProxyServerConfig { - url: expect_string(object, "url", context)?.to_string(), - id: expect_string(object, "id", context)?.to_string(), - }, - )), + "claudeai-proxy" => Ok(McpServerConfig::ManagedProxy(McpManagedProxyServerConfig { + url: expect_string(object, "url", context)?.to_string(), + id: expect_string(object, "id", context)?.to_string(), + })), other => Err(ConfigError::Parse(format!( "{context}: unsupported MCP server type for {server_name}: {other}" ))), @@ -663,6 +832,24 @@ fn optional_u16( } } +fn parse_bool_map(value: &JsonValue, context: &str) -> Result<BTreeMap<String, bool>, ConfigError> { + let Some(map) = value.as_object() else { + return Err(ConfigError::Parse(format!( + "{context}: expected JSON object" + ))); + }; + map.iter() + .map(|(key, value)| { + value + .as_bool() + .map(|enabled| (key.clone(), enabled)) + .ok_or_else(|| { + ConfigError::Parse(format!("{context}: field {key} must be a boolean")) + }) + }) + .collect() +} + fn optional_string_array( object: &BTreeMap<String, JsonValue>, key: &str, @@ -737,11 +924,23 @@ fn deep_merge_objects( } } +fn extend_unique(target: &mut Vec<String>, values: &[String]) { + for value in values { + push_unique(target, value.clone()); + } +} + +fn push_unique(target: &mut Vec<String>, value: String) { + if !target.iter().any(|existing| existing == &value) { + target.push(value); + } +} + #[cfg(test)] mod tests { use super::{ ConfigLoader, ConfigSource, McpServerConfig, McpTransport, ResolvedPermissionMode, - CLAUDE_CODE_SETTINGS_SCHEMA_NAME, + CLAW_SETTINGS_SCHEMA_NAME, }; use crate::json::JsonValue; use crate::sandbox::FilesystemIsolationMode; @@ -760,7 +959,7 @@ mod tests { fn rejects_non_object_settings_files() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); + let home = root.join("home").join(".claw"); fs::create_dir_all(&home).expect("home config dir"); fs::create_dir_all(&cwd).expect("project dir"); fs::write(home.join("settings.json"), "[]").expect("write bad settings"); @@ -776,15 +975,15 @@ mod tests { } #[test] - fn loads_and_merges_claude_code_config_files_by_precedence() { + fn loads_and_merges_claw_code_config_files_by_precedence() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); - fs::create_dir_all(cwd.join(".claude")).expect("project config dir"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); fs::create_dir_all(&home).expect("home config dir"); fs::write( - home.parent().expect("home parent").join(".claude.json"), + home.parent().expect("home parent").join(".claw.json"), r#"{"model":"haiku","env":{"A":"1"},"mcpServers":{"home":{"command":"uvx","args":["home"]}}}"#, ) .expect("write user compat config"); @@ -794,17 +993,17 @@ mod tests { ) .expect("write user settings"); fs::write( - cwd.join(".claude.json"), + cwd.join(".claw.json"), r#"{"model":"project-compat","env":{"B":"2"}}"#, ) .expect("write project compat config"); fs::write( - cwd.join(".claude").join("settings.json"), + cwd.join(".claw").join("settings.json"), r#"{"env":{"C":"3"},"hooks":{"PostToolUse":["project"]},"mcpServers":{"project":{"command":"uvx","args":["project"]}}}"#, ) .expect("write project settings"); fs::write( - cwd.join(".claude").join("settings.local.json"), + cwd.join(".claw").join("settings.local.json"), r#"{"model":"opus","permissionMode":"acceptEdits"}"#, ) .expect("write local settings"); @@ -813,7 +1012,7 @@ mod tests { .load() .expect("config should load"); - assert_eq!(CLAUDE_CODE_SETTINGS_SCHEMA_NAME, "SettingsSchema"); + assert_eq!(CLAW_SETTINGS_SCHEMA_NAME, "SettingsSchema"); assert_eq!(loaded.loaded_entries().len(), 5); assert_eq!(loaded.loaded_entries()[0].source, ConfigSource::User); assert_eq!( @@ -843,6 +1042,8 @@ mod tests { .and_then(JsonValue::as_object) .expect("hooks object") .contains_key("PostToolUse")); + assert_eq!(loaded.hooks().pre_tool_use(), &["base".to_string()]); + assert_eq!(loaded.hooks().post_tool_use(), &["project".to_string()]); assert!(loaded.mcp().get("home").is_some()); assert!(loaded.mcp().get("project").is_some()); @@ -853,12 +1054,12 @@ mod tests { fn parses_sandbox_config() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); - fs::create_dir_all(cwd.join(".claude")).expect("project config dir"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); fs::create_dir_all(&home).expect("home config dir"); fs::write( - cwd.join(".claude").join("settings.local.json"), + cwd.join(".claw").join("settings.local.json"), r#"{ "sandbox": { "enabled": true, @@ -891,8 +1092,8 @@ mod tests { fn parses_typed_mcp_and_oauth_config() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); - fs::create_dir_all(cwd.join(".claude")).expect("project config dir"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); fs::create_dir_all(&home).expect("home config dir"); fs::write( @@ -929,7 +1130,7 @@ mod tests { ) .expect("write user settings"); fs::write( - cwd.join(".claude").join("settings.local.json"), + cwd.join(".claw").join("settings.local.json"), r#"{ "mcpServers": { "remote-server": { @@ -978,11 +1179,101 @@ mod tests { fs::remove_dir_all(root).expect("cleanup temp dir"); } + #[test] + fn parses_plugin_config_from_enabled_plugins() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + + fs::write( + home.join("settings.json"), + r#"{ + "enabledPlugins": { + "tool-guard@builtin": true, + "sample-plugin@external": false + } + }"#, + ) + .expect("write user settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + assert_eq!( + loaded.plugins().enabled_plugins().get("tool-guard@builtin"), + Some(&true) + ); + assert_eq!( + loaded + .plugins() + .enabled_plugins() + .get("sample-plugin@external"), + Some(&false) + ); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + + #[test] + fn parses_plugin_config() { + let root = temp_dir(); + let cwd = root.join("project"); + let home = root.join("home").join(".claw"); + fs::create_dir_all(cwd.join(".claw")).expect("project config dir"); + fs::create_dir_all(&home).expect("home config dir"); + + fs::write( + home.join("settings.json"), + r#"{ + "enabledPlugins": { + "core-helpers@builtin": true + }, + "plugins": { + "externalDirectories": ["./external-plugins"], + "installRoot": "plugin-cache/installed", + "registryPath": "plugin-cache/installed.json", + "bundledRoot": "./bundled-plugins" + } + }"#, + ) + .expect("write plugin settings"); + + let loaded = ConfigLoader::new(&cwd, &home) + .load() + .expect("config should load"); + + assert_eq!( + loaded + .plugins() + .enabled_plugins() + .get("core-helpers@builtin"), + Some(&true) + ); + assert_eq!( + loaded.plugins().external_directories(), + &["./external-plugins".to_string()] + ); + assert_eq!( + loaded.plugins().install_root(), + Some("plugin-cache/installed") + ); + assert_eq!( + loaded.plugins().registry_path(), + Some("plugin-cache/installed.json") + ); + assert_eq!(loaded.plugins().bundled_root(), Some("./bundled-plugins")); + + fs::remove_dir_all(root).expect("cleanup temp dir"); + } + #[test] fn rejects_invalid_mcp_server_shapes() { let root = temp_dir(); let cwd = root.join("project"); - let home = root.join("home").join(".claude"); + let home = root.join("home").join(".claw"); fs::create_dir_all(&home).expect("home config dir"); fs::create_dir_all(&cwd).expect("project dir"); fs::write( diff --git a/rust/crates/runtime/src/conversation.rs b/rust/crates/runtime/src/conversation.rs index 625fb25..8411b8d 100644 --- a/rust/crates/runtime/src/conversation.rs +++ b/rust/crates/runtime/src/conversation.rs @@ -4,6 +4,8 @@ use std::fmt::{Display, Formatter}; use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; +use crate::config::RuntimeFeatureConfig; +use crate::hooks::{HookRunResult, HookRunner}; use crate::permissions::{PermissionOutcome, PermissionPolicy, PermissionPrompter}; use crate::session::{ContentBlock, ConversationMessage, Session}; use crate::usage::{TokenUsage, UsageTracker}; @@ -94,6 +96,7 @@ pub struct ConversationRuntime<C, T> { system_prompt: Vec<String>, max_iterations: usize, usage_tracker: UsageTracker, + hook_runner: HookRunner, } impl<C, T> ConversationRuntime<C, T> @@ -108,6 +111,25 @@ where tool_executor: T, permission_policy: PermissionPolicy, system_prompt: Vec<String>, + ) -> Self { + Self::new_with_features( + session, + api_client, + tool_executor, + permission_policy, + system_prompt, + RuntimeFeatureConfig::default(), + ) + } + + #[must_use] + pub fn new_with_features( + session: Session, + api_client: C, + tool_executor: T, + permission_policy: PermissionPolicy, + system_prompt: Vec<String>, + feature_config: RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); Self { @@ -116,8 +138,9 @@ where tool_executor, permission_policy, system_prompt, - max_iterations: 16, + max_iterations: usize::MAX, usage_tracker, + hook_runner: HookRunner::from_feature_config(&feature_config), } } @@ -185,19 +208,41 @@ where let result_message = match permission_outcome { PermissionOutcome::Allow => { - match self.tool_executor.execute(&tool_name, &input) { - Ok(output) => ConversationMessage::tool_result( + let pre_hook_result = self.hook_runner.run_pre_tool_use(&tool_name, &input); + if pre_hook_result.is_denied() { + let deny_message = format!("PreToolUse hook denied tool `{tool_name}`"); + ConversationMessage::tool_result( + tool_use_id, + tool_name, + format_hook_message(&pre_hook_result, &deny_message), + true, + ) + } else { + let (mut output, mut is_error) = + match self.tool_executor.execute(&tool_name, &input) { + Ok(output) => (output, false), + Err(error) => (error.to_string(), true), + }; + output = merge_hook_feedback(pre_hook_result.messages(), output, false); + + let post_hook_result = self + .hook_runner + .run_post_tool_use(&tool_name, &input, &output, is_error); + if post_hook_result.is_denied() { + is_error = true; + } + output = merge_hook_feedback( + post_hook_result.messages(), + output, + post_hook_result.is_denied(), + ); + + ConversationMessage::tool_result( tool_use_id, tool_name, output, - false, - ), - Err(error) => ConversationMessage::tool_result( - tool_use_id, - tool_name, - error.to_string(), - true, - ), + is_error, + ) } } PermissionOutcome::Deny { reason } => { @@ -290,6 +335,32 @@ fn flush_text_block(text: &mut String, blocks: &mut Vec<ContentBlock>) { } } +fn format_hook_message(result: &HookRunResult, fallback: &str) -> String { + if result.messages().is_empty() { + fallback.to_string() + } else { + result.messages().join("\n") + } +} + +fn merge_hook_feedback(messages: &[String], output: String, denied: bool) -> String { + if messages.is_empty() { + return output; + } + + let mut sections = Vec::new(); + if !output.trim().is_empty() { + sections.push(output); + } + let label = if denied { + "Hook feedback (denied)" + } else { + "Hook feedback" + }; + sections.push(format!("{label}:\n{}", messages.join("\n"))); + sections.join("\n\n") +} + type ToolHandler = Box<dyn FnMut(&str) -> Result<String, ToolError>>; #[derive(Default)] @@ -329,6 +400,7 @@ mod tests { StaticToolExecutor, }; use crate::compact::CompactionConfig; + use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; use crate::permissions::{ PermissionMode, PermissionPolicy, PermissionPromptDecision, PermissionPrompter, PermissionRequest, @@ -503,6 +575,141 @@ mod tests { )); } + #[test] + fn denies_tool_use_when_pre_tool_hook_blocks() { + struct SingleCallApiClient; + impl ApiClient for SingleCallApiClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + if request + .messages + .iter() + .any(|message| message.role == MessageRole::Tool) + { + return Ok(vec![ + AssistantEvent::TextDelta("blocked".to_string()), + AssistantEvent::MessageStop, + ]); + } + Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "blocked".to_string(), + input: r#"{"path":"secret.txt"}"#.to_string(), + }, + AssistantEvent::MessageStop, + ]) + } + } + + let mut runtime = ConversationRuntime::new_with_features( + Session::new(), + SingleCallApiClient, + StaticToolExecutor::new().register("blocked", |_input| { + panic!("tool should not execute when hook denies") + }), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + vec![shell_snippet("printf 'blocked by hook'; exit 2")], + Vec::new(), + )), + ); + + let summary = runtime + .run_turn("use the tool", None) + .expect("conversation should continue after hook denial"); + + assert_eq!(summary.tool_results.len(), 1); + let ContentBlock::ToolResult { + is_error, output, .. + } = &summary.tool_results[0].blocks[0] + else { + panic!("expected tool result block"); + }; + assert!( + *is_error, + "hook denial should produce an error result: {output}" + ); + assert!( + output.contains("denied tool") || output.contains("blocked by hook"), + "unexpected hook denial output: {output:?}" + ); + } + + #[test] + fn appends_post_tool_hook_feedback_to_tool_result() { + struct TwoCallApiClient { + calls: usize, + } + + impl ApiClient for TwoCallApiClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + self.calls += 1; + match self.calls { + 1 => Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "add".to_string(), + input: r#"{"lhs":2,"rhs":2}"#.to_string(), + }, + AssistantEvent::MessageStop, + ]), + 2 => { + assert!(request + .messages + .iter() + .any(|message| message.role == MessageRole::Tool)); + Ok(vec![ + AssistantEvent::TextDelta("done".to_string()), + AssistantEvent::MessageStop, + ]) + } + _ => Err(RuntimeError::new("unexpected extra API call")), + } + } + } + + let mut runtime = ConversationRuntime::new_with_features( + Session::new(), + TwoCallApiClient { calls: 0 }, + StaticToolExecutor::new().register("add", |_input| Ok("4".to_string())), + PermissionPolicy::new(PermissionMode::DangerFullAccess), + vec!["system".to_string()], + RuntimeFeatureConfig::default().with_hooks(RuntimeHookConfig::new( + vec![shell_snippet("printf 'pre hook ran'")], + vec![shell_snippet("printf 'post hook ran'")], + )), + ); + + let summary = runtime + .run_turn("use add", None) + .expect("tool loop succeeds"); + + assert_eq!(summary.tool_results.len(), 1); + let ContentBlock::ToolResult { + is_error, output, .. + } = &summary.tool_results[0].blocks[0] + else { + panic!("expected tool result block"); + }; + assert!( + !*is_error, + "post hook should preserve non-error result: {output:?}" + ); + assert!( + output.contains('4'), + "tool output missing value: {output:?}" + ); + assert!( + output.contains("pre hook ran"), + "tool output missing pre hook feedback: {output:?}" + ); + assert!( + output.contains("post hook ran"), + "tool output missing post hook feedback: {output:?}" + ); + } + #[test] fn reconstructs_usage_tracker_from_restored_session() { struct SimpleApi; @@ -581,4 +788,14 @@ mod tests { MessageRole::System ); } + + #[cfg(windows)] + fn shell_snippet(script: &str) -> String { + script.replace('\'', "\"") + } + + #[cfg(not(windows))] + fn shell_snippet(script: &str) -> String { + script.to_string() + } } diff --git a/rust/crates/runtime/src/file_ops.rs b/rust/crates/runtime/src/file_ops.rs index a647b85..1faf9ab 100644 --- a/rust/crates/runtime/src/file_ops.rs +++ b/rust/crates/runtime/src/file_ops.rs @@ -488,7 +488,7 @@ mod tests { .duration_since(UNIX_EPOCH) .expect("time should move forward") .as_nanos(); - std::env::temp_dir().join(format!("clawd-native-{name}-{unique}")) + std::env::temp_dir().join(format!("claw-native-{name}-{unique}")) } #[test] diff --git a/rust/crates/runtime/src/hooks.rs b/rust/crates/runtime/src/hooks.rs new file mode 100644 index 0000000..63ef9ff --- /dev/null +++ b/rust/crates/runtime/src/hooks.rs @@ -0,0 +1,357 @@ +use std::ffi::OsStr; +use std::process::Command; + +use serde_json::json; + +use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum HookEvent { + PreToolUse, + PostToolUse, +} + +impl HookEvent { + fn as_str(self) -> &'static str { + match self { + Self::PreToolUse => "PreToolUse", + Self::PostToolUse => "PostToolUse", + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct HookRunResult { + denied: bool, + messages: Vec<String>, +} + +impl HookRunResult { + #[must_use] + pub fn allow(messages: Vec<String>) -> Self { + Self { + denied: false, + messages, + } + } + + #[must_use] + pub fn is_denied(&self) -> bool { + self.denied + } + + #[must_use] + pub fn messages(&self) -> &[String] { + &self.messages + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct HookRunner { + config: RuntimeHookConfig, +} + +#[derive(Debug, Clone, Copy)] +struct HookCommandRequest<'a> { + event: HookEvent, + tool_name: &'a str, + tool_input: &'a str, + tool_output: Option<&'a str>, + is_error: bool, + payload: &'a str, +} + +impl HookRunner { + #[must_use] + pub fn new(config: RuntimeHookConfig) -> Self { + Self { config } + } + + #[must_use] + pub fn from_feature_config(feature_config: &RuntimeFeatureConfig) -> Self { + Self::new(feature_config.hooks().clone()) + } + + #[must_use] + pub fn run_pre_tool_use(&self, tool_name: &str, tool_input: &str) -> HookRunResult { + self.run_commands( + HookEvent::PreToolUse, + self.config.pre_tool_use(), + tool_name, + tool_input, + None, + false, + ) + } + + #[must_use] + pub fn run_post_tool_use( + &self, + tool_name: &str, + tool_input: &str, + tool_output: &str, + is_error: bool, + ) -> HookRunResult { + self.run_commands( + HookEvent::PostToolUse, + self.config.post_tool_use(), + tool_name, + tool_input, + Some(tool_output), + is_error, + ) + } + + fn run_commands( + &self, + event: HookEvent, + commands: &[String], + tool_name: &str, + tool_input: &str, + tool_output: Option<&str>, + is_error: bool, + ) -> HookRunResult { + if commands.is_empty() { + return HookRunResult::allow(Vec::new()); + } + + let payload = json!({ + "hook_event_name": event.as_str(), + "tool_name": tool_name, + "tool_input": parse_tool_input(tool_input), + "tool_input_json": tool_input, + "tool_output": tool_output, + "tool_result_is_error": is_error, + }) + .to_string(); + + let mut messages = Vec::new(); + + for command in commands { + match Self::run_command( + command, + HookCommandRequest { + event, + tool_name, + tool_input, + tool_output, + is_error, + payload: &payload, + }, + ) { + HookCommandOutcome::Allow { message } => { + if let Some(message) = message { + messages.push(message); + } + } + HookCommandOutcome::Deny { message } => { + let message = message.unwrap_or_else(|| { + format!("{} hook denied tool `{tool_name}`", event.as_str()) + }); + messages.push(message); + return HookRunResult { + denied: true, + messages, + }; + } + HookCommandOutcome::Warn { message } => messages.push(message), + } + } + + HookRunResult::allow(messages) + } + + fn run_command(command: &str, request: HookCommandRequest<'_>) -> HookCommandOutcome { + let mut child = shell_command(command); + child.stdin(std::process::Stdio::piped()); + child.stdout(std::process::Stdio::piped()); + child.stderr(std::process::Stdio::piped()); + child.env("HOOK_EVENT", request.event.as_str()); + child.env("HOOK_TOOL_NAME", request.tool_name); + child.env("HOOK_TOOL_INPUT", request.tool_input); + child.env( + "HOOK_TOOL_IS_ERROR", + if request.is_error { "1" } else { "0" }, + ); + if let Some(tool_output) = request.tool_output { + child.env("HOOK_TOOL_OUTPUT", tool_output); + } + + match child.output_with_stdin(request.payload.as_bytes()) { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + let message = (!stdout.is_empty()).then_some(stdout); + match output.status.code() { + Some(0) => HookCommandOutcome::Allow { message }, + Some(2) => HookCommandOutcome::Deny { message }, + Some(code) => HookCommandOutcome::Warn { + message: format_hook_warning( + command, + code, + message.as_deref(), + stderr.as_str(), + ), + }, + None => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` terminated by signal while handling `{}`", + request.event.as_str(), + request.tool_name + ), + }, + } + } + Err(error) => HookCommandOutcome::Warn { + message: format!( + "{} hook `{command}` failed to start for `{}`: {error}", + request.event.as_str(), + request.tool_name + ), + }, + } + } +} + +enum HookCommandOutcome { + Allow { message: Option<String> }, + Deny { message: Option<String> }, + Warn { message: String }, +} + +fn parse_tool_input(tool_input: &str) -> serde_json::Value { + serde_json::from_str(tool_input).unwrap_or_else(|_| json!({ "raw": tool_input })) +} + +fn format_hook_warning(command: &str, code: i32, stdout: Option<&str>, stderr: &str) -> String { + let mut message = + format!("Hook `{command}` exited with status {code}; allowing tool execution to continue"); + if let Some(stdout) = stdout.filter(|stdout| !stdout.is_empty()) { + message.push_str(": "); + message.push_str(stdout); + } else if !stderr.is_empty() { + message.push_str(": "); + message.push_str(stderr); + } + message +} + +fn shell_command(command: &str) -> CommandWithStdin { + #[cfg(windows)] + let mut command_builder = { + let mut command_builder = Command::new("cmd"); + command_builder.arg("/C").arg(command); + CommandWithStdin::new(command_builder) + }; + + #[cfg(not(windows))] + let command_builder = { + let mut command_builder = Command::new("sh"); + command_builder.arg("-lc").arg(command); + CommandWithStdin::new(command_builder) + }; + + command_builder +} + +struct CommandWithStdin { + command: Command, +} + +impl CommandWithStdin { + fn new(command: Command) -> Self { + Self { command } + } + + fn stdin(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdin(cfg); + self + } + + fn stdout(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stdout(cfg); + self + } + + fn stderr(&mut self, cfg: std::process::Stdio) -> &mut Self { + self.command.stderr(cfg); + self + } + + fn env<K, V>(&mut self, key: K, value: V) -> &mut Self + where + K: AsRef<OsStr>, + V: AsRef<OsStr>, + { + self.command.env(key, value); + self + } + + fn output_with_stdin(&mut self, stdin: &[u8]) -> std::io::Result<std::process::Output> { + let mut child = self.command.spawn()?; + if let Some(mut child_stdin) = child.stdin.take() { + use std::io::Write; + child_stdin.write_all(stdin)?; + } + child.wait_with_output() + } +} + +#[cfg(test)] +mod tests { + use super::{HookRunResult, HookRunner}; + use crate::config::{RuntimeFeatureConfig, RuntimeHookConfig}; + + #[test] + fn allows_exit_code_zero_and_captures_stdout() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet("printf 'pre ok'")], + Vec::new(), + )); + + let result = runner.run_pre_tool_use("Read", r#"{"path":"README.md"}"#); + + assert_eq!(result, HookRunResult::allow(vec!["pre ok".to_string()])); + } + + #[test] + fn denies_exit_code_two() { + let runner = HookRunner::new(RuntimeHookConfig::new( + vec![shell_snippet("printf 'blocked by hook'; exit 2")], + Vec::new(), + )); + + let result = runner.run_pre_tool_use("Bash", r#"{"command":"pwd"}"#); + + assert!(result.is_denied()); + assert_eq!(result.messages(), &["blocked by hook".to_string()]); + } + + #[test] + fn warns_for_other_non_zero_statuses() { + let runner = HookRunner::from_feature_config(&RuntimeFeatureConfig::default().with_hooks( + RuntimeHookConfig::new( + vec![shell_snippet("printf 'warning hook'; exit 1")], + Vec::new(), + ), + )); + + let result = runner.run_pre_tool_use("Edit", r#"{"file":"src/lib.rs"}"#); + + assert!(!result.is_denied()); + assert!(result + .messages() + .iter() + .any(|message| message.contains("allowing tool execution to continue"))); + } + + #[cfg(windows)] + fn shell_snippet(script: &str) -> String { + script.replace('\'', "\"") + } + + #[cfg(not(windows))] + fn shell_snippet(script: &str) -> String { + script.to_string() + } +} diff --git a/rust/crates/runtime/src/lib.rs b/rust/crates/runtime/src/lib.rs index 2861d47..c714f95 100644 --- a/rust/crates/runtime/src/lib.rs +++ b/rust/crates/runtime/src/lib.rs @@ -4,6 +4,7 @@ mod compact; mod config; mod conversation; mod file_ops; +mod hooks; mod json; mod mcp; mod mcp_client; @@ -16,6 +17,10 @@ pub mod sandbox; mod session; mod usage; +pub use lsp::{ + FileDiagnostics, LspContextEnrichment, LspError, LspManager, LspServerConfig, + SymbolLocation, WorkspaceDiagnostics, +}; pub use bash::{execute_bash, BashCommandInput, BashCommandOutput}; pub use bootstrap::{BootstrapPhase, BootstrapPlan}; pub use compact::{ @@ -23,11 +28,11 @@ pub use compact::{ get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult, }; pub use config::{ - ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpClaudeAiProxyServerConfig, + ConfigEntry, ConfigError, ConfigLoader, ConfigSource, McpManagedProxyServerConfig, McpConfigCollection, McpOAuthConfig, McpRemoteServerConfig, McpSdkServerConfig, McpServerConfig, McpStdioServerConfig, McpTransport, McpWebSocketServerConfig, OAuthConfig, - ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, ScopedMcpServerConfig, - CLAUDE_CODE_SETTINGS_SCHEMA_NAME, + ResolvedPermissionMode, RuntimeConfig, RuntimeFeatureConfig, RuntimeHookConfig, + RuntimePluginConfig, ScopedMcpServerConfig, CLAW_SETTINGS_SCHEMA_NAME, }; pub use conversation::{ ApiClient, ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, StaticToolExecutor, @@ -38,12 +43,13 @@ pub use file_ops::{ GrepSearchInput, GrepSearchOutput, ReadFileOutput, StructuredPatchHunk, TextFilePayload, WriteFileOutput, }; +pub use hooks::{HookEvent, HookRunResult, HookRunner}; pub use mcp::{ mcp_server_signature, mcp_tool_name, mcp_tool_prefix, normalize_name_for_mcp, scoped_mcp_config_hash, unwrap_ccr_proxy_url, }; pub use mcp_client::{ - McpClaudeAiProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport, + McpManagedProxyTransport, McpClientAuth, McpClientBootstrap, McpClientTransport, McpRemoteTransport, McpSdkTransport, McpStdioTransport, }; pub use mcp_stdio::{ diff --git a/rust/crates/runtime/src/mcp.rs b/rust/crates/runtime/src/mcp.rs index 103fbe4..b37ea33 100644 --- a/rust/crates/runtime/src/mcp.rs +++ b/rust/crates/runtime/src/mcp.rs @@ -73,7 +73,7 @@ pub fn mcp_server_signature(config: &McpServerConfig) -> Option<String> { Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))) } McpServerConfig::Ws(config) => Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))), - McpServerConfig::ClaudeAiProxy(config) => { + McpServerConfig::ManagedProxy(config) => { Some(format!("url:{}", unwrap_ccr_proxy_url(&config.url))) } McpServerConfig::Sdk(_) => None, @@ -110,7 +110,7 @@ pub fn scoped_mcp_config_hash(config: &ScopedMcpServerConfig) -> String { ws.headers_helper.as_deref().unwrap_or("") ), McpServerConfig::Sdk(sdk) => format!("sdk|{}", sdk.name), - McpServerConfig::ClaudeAiProxy(proxy) => { + McpServerConfig::ManagedProxy(proxy) => { format!("claudeai-proxy|{}|{}", proxy.url, proxy.id) } }; diff --git a/rust/crates/runtime/src/mcp_client.rs b/rust/crates/runtime/src/mcp_client.rs index 23ccb95..e0e1f2c 100644 --- a/rust/crates/runtime/src/mcp_client.rs +++ b/rust/crates/runtime/src/mcp_client.rs @@ -10,7 +10,7 @@ pub enum McpClientTransport { Http(McpRemoteTransport), WebSocket(McpRemoteTransport), Sdk(McpSdkTransport), - ClaudeAiProxy(McpClaudeAiProxyTransport), + ManagedProxy(McpManagedProxyTransport), } #[derive(Debug, Clone, PartialEq, Eq)] @@ -34,7 +34,7 @@ pub struct McpSdkTransport { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct McpClaudeAiProxyTransport { +pub struct McpManagedProxyTransport { pub url: String, pub id: String, } @@ -97,12 +97,10 @@ impl McpClientTransport { McpServerConfig::Sdk(config) => Self::Sdk(McpSdkTransport { name: config.name.clone(), }), - McpServerConfig::ClaudeAiProxy(config) => { - Self::ClaudeAiProxy(McpClaudeAiProxyTransport { - url: config.url.clone(), - id: config.id.clone(), - }) - } + McpServerConfig::ManagedProxy(config) => Self::ManagedProxy(McpManagedProxyTransport { + url: config.url.clone(), + id: config.id.clone(), + }), } } } diff --git a/rust/crates/runtime/src/mcp_stdio.rs b/rust/crates/runtime/src/mcp_stdio.rs index 7e67d5d..27402d6 100644 --- a/rust/crates/runtime/src/mcp_stdio.rs +++ b/rust/crates/runtime/src/mcp_stdio.rs @@ -809,6 +809,7 @@ mod tests { use std::io::ErrorKind; use std::os::unix::fs::PermissionsExt; use std::path::{Path, PathBuf}; + use std::process::Command; use std::time::{SystemTime, UNIX_EPOCH}; use serde_json::json; @@ -1137,15 +1138,37 @@ mod tests { fn script_transport(script_path: &Path) -> crate::mcp_client::McpStdioTransport { crate::mcp_client::McpStdioTransport { - command: "python3".to_string(), + command: python_command(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::new(), } } + fn python_command() -> String { + for key in ["MCP_TEST_PYTHON", "PYTHON3", "PYTHON"] { + if let Ok(value) = std::env::var(key) { + if !value.trim().is_empty() { + return value; + } + } + } + + for candidate in ["python3", "python"] { + if Command::new(candidate).arg("--version").output().is_ok() { + return candidate.to_string(); + } + } + + panic!("expected a Python interpreter for MCP stdio tests") + } + fn cleanup_script(script_path: &Path) { - fs::remove_file(script_path).expect("cleanup script"); - fs::remove_dir_all(script_path.parent().expect("script parent")).expect("cleanup dir"); + if let Err(error) = fs::remove_file(script_path) { + assert_eq!(error.kind(), std::io::ErrorKind::NotFound, "cleanup script"); + } + if let Err(error) = fs::remove_dir_all(script_path.parent().expect("script parent")) { + assert_eq!(error.kind(), std::io::ErrorKind::NotFound, "cleanup dir"); + } } fn manager_server_config( @@ -1156,7 +1179,7 @@ mod tests { ScopedMcpServerConfig { scope: ConfigSource::Local, config: McpServerConfig::Stdio(McpStdioServerConfig { - command: "python3".to_string(), + command: python_command(), args: vec![script_path.to_string_lossy().into_owned()], env: BTreeMap::from([ ("MCP_SERVER_LABEL".to_string(), label.to_string()), diff --git a/rust/crates/runtime/src/oauth.rs b/rust/crates/runtime/src/oauth.rs index 837bdf2..e4756c1 100644 --- a/rust/crates/runtime/src/oauth.rs +++ b/rust/crates/runtime/src/oauth.rs @@ -324,15 +324,15 @@ fn generate_random_token(bytes: usize) -> io::Result<String> { } fn credentials_home_dir() -> io::Result<PathBuf> { - if let Some(path) = std::env::var_os("CLAUDE_CONFIG_HOME") { + if let Some(path) = std::env::var_os("CLAW_CONFIG_HOME") { return Ok(PathBuf::from(path)); } if let Some(path) = std::env::var_os("HOME") { - return Ok(PathBuf::from(path).join(".claude")); + return Ok(PathBuf::from(path).join(".claw")); } if cfg!(target_os = "windows") { if let Some(path) = std::env::var_os("USERPROFILE") { - return Ok(PathBuf::from(path).join(".claude")); + return Ok(PathBuf::from(path).join(".claw")); } } Err(io::Error::new(io::ErrorKind::NotFound, "HOME or USERPROFILE is not set")) @@ -547,7 +547,7 @@ mod tests { fn oauth_credentials_round_trip_and_clear_preserves_other_fields() { let _guard = env_lock(); let config_home = temp_config_home(); - std::env::set_var("CLAUDE_CONFIG_HOME", &config_home); + std::env::set_var("CLAW_CONFIG_HOME", &config_home); let path = credentials_path().expect("credentials path"); std::fs::create_dir_all(path.parent().expect("parent")).expect("create parent"); std::fs::write(&path, "{\"other\":\"value\"}\n").expect("seed credentials"); @@ -573,7 +573,7 @@ mod tests { assert!(cleared.contains("\"other\": \"value\"")); assert!(!cleared.contains("\"oauth\"")); - std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::env::remove_var("CLAW_CONFIG_HOME"); std::fs::remove_dir_all(config_home).expect("cleanup temp dir"); } diff --git a/rust/crates/runtime/src/prompt.rs b/rust/crates/runtime/src/prompt.rs index 7192412..d3b09e3 100644 --- a/rust/crates/runtime/src/prompt.rs +++ b/rust/crates/runtime/src/prompt.rs @@ -4,6 +4,7 @@ use std::path::{Path, PathBuf}; use std::process::Command; use crate::config::{ConfigError, ConfigLoader, RuntimeConfig}; +use lsp::LspContextEnrichment; #[derive(Debug)] pub enum PromptBuildError { @@ -35,7 +36,7 @@ impl From<ConfigError> for PromptBuildError { } pub const SYSTEM_PROMPT_DYNAMIC_BOUNDARY: &str = "__SYSTEM_PROMPT_DYNAMIC_BOUNDARY__"; -pub const FRONTIER_MODEL_NAME: &str = "Claude Opus 4.6"; +pub const FRONTIER_MODEL_NAME: &str = "Opus 4.6"; const MAX_INSTRUCTION_FILE_CHARS: usize = 4_000; const MAX_TOTAL_INSTRUCTION_CHARS: usize = 12_000; @@ -130,6 +131,15 @@ impl SystemPromptBuilder { self } + #[must_use] + pub fn with_lsp_context(mut self, enrichment: &LspContextEnrichment) -> Self { + if !enrichment.is_empty() { + self.append_sections + .push(enrichment.render_prompt_section()); + } + self + } + #[must_use] pub fn build(&self) -> Vec<String> { let mut sections = Vec::new(); @@ -201,10 +211,10 @@ fn discover_instruction_files(cwd: &Path) -> std::io::Result<Vec<ContextFile>> { let mut files = Vec::new(); for dir in directories { for candidate in [ - dir.join("CLAUDE.md"), - dir.join("CLAUDE.local.md"), - dir.join(".claude").join("CLAUDE.md"), - dir.join(".claude").join("instructions.md"), + dir.join("CLAW.md"), + dir.join("CLAW.local.md"), + dir.join(".claw").join("CLAW.md"), + dir.join(".claw").join("instructions.md"), ] { push_context_file(&mut files, candidate)?; } @@ -282,7 +292,7 @@ fn render_project_context(project_context: &ProjectContext) -> String { ]; if !project_context.instruction_files.is_empty() { bullets.push(format!( - "Claude instruction files discovered: {}.", + "Claw instruction files discovered: {}.", project_context.instruction_files.len() )); } @@ -301,7 +311,7 @@ fn render_project_context(project_context: &ProjectContext) -> String { } fn render_instruction_files(files: &[ContextFile]) -> String { - let mut sections = vec!["# Claude instructions".to_string()]; + let mut sections = vec!["# Claw instructions".to_string()]; let mut remaining_chars = MAX_TOTAL_INSTRUCTION_CHARS; for file in files { if remaining_chars == 0 { @@ -421,7 +431,7 @@ fn render_config_section(config: &RuntimeConfig) -> String { let mut lines = vec!["# Runtime config".to_string()]; if config.loaded_entries().is_empty() { lines.extend(prepend_bullets(vec![ - "No Claude Code settings files loaded.".to_string(), + "No Claw Code settings files loaded.".to_string() ])); return lines.join("\n"); } @@ -517,23 +527,23 @@ mod tests { fn discovers_instruction_files_from_ancestor_chain() { let root = temp_dir(); let nested = root.join("apps").join("api"); - fs::create_dir_all(nested.join(".claude")).expect("nested claude dir"); - fs::write(root.join("CLAUDE.md"), "root instructions").expect("write root instructions"); - fs::write(root.join("CLAUDE.local.md"), "local instructions") + fs::create_dir_all(nested.join(".claw")).expect("nested claw dir"); + fs::write(root.join("CLAW.md"), "root instructions").expect("write root instructions"); + fs::write(root.join("CLAW.local.md"), "local instructions") .expect("write local instructions"); fs::create_dir_all(root.join("apps")).expect("apps dir"); - fs::create_dir_all(root.join("apps").join(".claude")).expect("apps claude dir"); - fs::write(root.join("apps").join("CLAUDE.md"), "apps instructions") + fs::create_dir_all(root.join("apps").join(".claw")).expect("apps claw dir"); + fs::write(root.join("apps").join("CLAW.md"), "apps instructions") .expect("write apps instructions"); fs::write( - root.join("apps").join(".claude").join("instructions.md"), - "apps dot claude instructions", + root.join("apps").join(".claw").join("instructions.md"), + "apps dot claw instructions", ) - .expect("write apps dot claude instructions"); - fs::write(nested.join(".claude").join("CLAUDE.md"), "nested rules") + .expect("write apps dot claw instructions"); + fs::write(nested.join(".claw").join("CLAW.md"), "nested rules") .expect("write nested rules"); fs::write( - nested.join(".claude").join("instructions.md"), + nested.join(".claw").join("instructions.md"), "nested instructions", ) .expect("write nested instructions"); @@ -551,7 +561,7 @@ mod tests { "root instructions", "local instructions", "apps instructions", - "apps dot claude instructions", + "apps dot claw instructions", "nested rules", "nested instructions" ] @@ -564,8 +574,8 @@ mod tests { let root = temp_dir(); let nested = root.join("apps").join("api"); fs::create_dir_all(&nested).expect("nested dir"); - fs::write(root.join("CLAUDE.md"), "same rules\n\n").expect("write root"); - fs::write(nested.join("CLAUDE.md"), "same rules\n").expect("write nested"); + fs::write(root.join("CLAW.md"), "same rules\n\n").expect("write root"); + fs::write(nested.join("CLAW.md"), "same rules\n").expect("write nested"); let context = ProjectContext::discover(&nested, "2026-03-31").expect("context should load"); assert_eq!(context.instruction_files.len(), 1); @@ -593,13 +603,14 @@ mod tests { #[test] fn displays_context_paths_compactly() { assert_eq!( - display_context_path(Path::new("/tmp/project/.claude/CLAUDE.md")), - "CLAUDE.md" + display_context_path(Path::new("/tmp/project/.claw/CLAW.md")), + "CLAW.md" ); } #[test] fn discover_with_git_includes_status_snapshot() { + let _guard = env_lock(); let root = temp_dir(); fs::create_dir_all(&root).expect("root dir"); std::process::Command::new("git") @@ -607,7 +618,7 @@ mod tests { .current_dir(&root) .status() .expect("git init should run"); - fs::write(root.join("CLAUDE.md"), "rules").expect("write instructions"); + fs::write(root.join("CLAW.md"), "rules").expect("write instructions"); fs::write(root.join("tracked.txt"), "hello").expect("write tracked file"); let context = @@ -615,7 +626,7 @@ mod tests { let status = context.git_status.expect("git status should be present"); assert!(status.contains("## No commits yet on") || status.contains("## ")); - assert!(status.contains("?? CLAUDE.md")); + assert!(status.contains("?? CLAW.md")); assert!(status.contains("?? tracked.txt")); assert!(context.git_diff.is_none()); @@ -624,6 +635,7 @@ mod tests { #[test] fn discover_with_git_includes_diff_snapshot_for_tracked_changes() { + let _guard = env_lock(); let root = temp_dir(); fs::create_dir_all(&root).expect("root dir"); std::process::Command::new("git") @@ -665,12 +677,12 @@ mod tests { } #[test] - fn load_system_prompt_reads_claude_files_and_config() { + fn load_system_prompt_reads_claw_files_and_config() { let root = temp_dir(); - fs::create_dir_all(root.join(".claude")).expect("claude dir"); - fs::write(root.join("CLAUDE.md"), "Project rules").expect("write instructions"); + fs::create_dir_all(root.join(".claw")).expect("claw dir"); + fs::write(root.join("CLAW.md"), "Project rules").expect("write instructions"); fs::write( - root.join(".claude").join("settings.json"), + root.join(".claw").join("settings.json"), r#"{"permissionMode":"acceptEdits"}"#, ) .expect("write settings"); @@ -678,9 +690,9 @@ mod tests { let _guard = env_lock(); let previous = std::env::current_dir().expect("cwd"); let original_home = std::env::var("HOME").ok(); - let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok(); + let original_claw_home = std::env::var("CLAW_CONFIG_HOME").ok(); std::env::set_var("HOME", &root); - std::env::set_var("CLAUDE_CONFIG_HOME", root.join("missing-home")); + std::env::set_var("CLAW_CONFIG_HOME", root.join("missing-home")); std::env::set_current_dir(&root).expect("change cwd"); let prompt = super::load_system_prompt(&root, "2026-03-31", "linux", "6.8") .expect("system prompt should load") @@ -695,10 +707,10 @@ mod tests { } else { std::env::remove_var("HOME"); } - if let Some(value) = original_claude_home { - std::env::set_var("CLAUDE_CONFIG_HOME", value); + if let Some(value) = original_claw_home { + std::env::set_var("CLAW_CONFIG_HOME", value); } else { - std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::env::remove_var("CLAW_CONFIG_HOME"); } assert!(prompt.contains("Project rules")); @@ -707,12 +719,12 @@ mod tests { } #[test] - fn renders_claude_code_style_sections_with_project_context() { + fn renders_claw_code_style_sections_with_project_context() { let root = temp_dir(); - fs::create_dir_all(root.join(".claude")).expect("claude dir"); - fs::write(root.join("CLAUDE.md"), "Project rules").expect("write CLAUDE.md"); + fs::create_dir_all(root.join(".claw")).expect("claw dir"); + fs::write(root.join("CLAW.md"), "Project rules").expect("write CLAW.md"); fs::write( - root.join(".claude").join("settings.json"), + root.join(".claw").join("settings.json"), r#"{"permissionMode":"acceptEdits"}"#, ) .expect("write settings"); @@ -731,7 +743,7 @@ mod tests { assert!(prompt.contains("# System")); assert!(prompt.contains("# Project context")); - assert!(prompt.contains("# Claude instructions")); + assert!(prompt.contains("# Claw instructions")); assert!(prompt.contains("Project rules")); assert!(prompt.contains("permissionMode")); assert!(prompt.contains(SYSTEM_PROMPT_DYNAMIC_BOUNDARY)); @@ -748,12 +760,12 @@ mod tests { } #[test] - fn discovers_dot_claude_instructions_markdown() { + fn discovers_dot_claw_instructions_markdown() { let root = temp_dir(); let nested = root.join("apps").join("api"); - fs::create_dir_all(nested.join(".claude")).expect("nested claude dir"); + fs::create_dir_all(nested.join(".claw")).expect("nested claw dir"); fs::write( - nested.join(".claude").join("instructions.md"), + nested.join(".claw").join("instructions.md"), "instruction markdown", ) .expect("write instructions.md"); @@ -762,7 +774,7 @@ mod tests { assert!(context .instruction_files .iter() - .any(|file| file.path.ends_with(".claude/instructions.md"))); + .any(|file| file.path.ends_with(".claw/instructions.md"))); assert!( render_instruction_files(&context.instruction_files).contains("instruction markdown") ); @@ -773,10 +785,10 @@ mod tests { #[test] fn renders_instruction_file_metadata() { let rendered = render_instruction_files(&[ContextFile { - path: PathBuf::from("/tmp/project/CLAUDE.md"), + path: PathBuf::from("/tmp/project/CLAW.md"), content: "Project rules".to_string(), }]); - assert!(rendered.contains("# Claude instructions")); + assert!(rendered.contains("# Claw instructions")); assert!(rendered.contains("scope: /tmp/project")); assert!(rendered.contains("Project rules")); } diff --git a/rust/crates/runtime/src/remote.rs b/rust/crates/runtime/src/remote.rs index 24ee780..5fe59a0 100644 --- a/rust/crates/runtime/src/remote.rs +++ b/rust/crates/runtime/src/remote.rs @@ -72,9 +72,9 @@ impl RemoteSessionContext { #[must_use] pub fn from_env_map(env_map: &BTreeMap<String, String>) -> Self { Self { - enabled: env_truthy(env_map.get("CLAUDE_CODE_REMOTE")), + enabled: env_truthy(env_map.get("CLAW_CODE_REMOTE")), session_id: env_map - .get("CLAUDE_CODE_REMOTE_SESSION_ID") + .get("CLAW_CODE_REMOTE_SESSION_ID") .filter(|value| !value.is_empty()) .cloned(), base_url: env_map @@ -272,9 +272,9 @@ mod tests { #[test] fn remote_context_reads_env_state() { let env = BTreeMap::from([ - ("CLAUDE_CODE_REMOTE".to_string(), "true".to_string()), + ("CLAW_CODE_REMOTE".to_string(), "true".to_string()), ( - "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), + "CLAW_CODE_REMOTE_SESSION_ID".to_string(), "session-123".to_string(), ), ( @@ -291,7 +291,7 @@ mod tests { #[test] fn bootstrap_fails_open_when_token_or_session_is_missing() { let env = BTreeMap::from([ - ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), + ("CLAW_CODE_REMOTE".to_string(), "1".to_string()), ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), ]); let bootstrap = UpstreamProxyBootstrap::from_env_map(&env); @@ -307,10 +307,10 @@ mod tests { fs::write(&token_path, "secret-token\n").expect("write token"); let env = BTreeMap::from([ - ("CLAUDE_CODE_REMOTE".to_string(), "1".to_string()), + ("CLAW_CODE_REMOTE".to_string(), "1".to_string()), ("CCR_UPSTREAM_PROXY_ENABLED".to_string(), "true".to_string()), ( - "CLAUDE_CODE_REMOTE_SESSION_ID".to_string(), + "CLAW_CODE_REMOTE_SESSION_ID".to_string(), "session-123".to_string(), ), ( diff --git a/rust/crates/runtime/src/session.rs b/rust/crates/runtime/src/session.rs index beaa435..ec37070 100644 --- a/rust/crates/runtime/src/session.rs +++ b/rust/crates/runtime/src/session.rs @@ -3,10 +3,13 @@ use std::fmt::{Display, Formatter}; use std::fs; use std::path::Path; +use serde::{Deserialize, Serialize}; + use crate::json::{JsonError, JsonValue}; use crate::usage::TokenUsage; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] pub enum MessageRole { System, User, @@ -14,7 +17,8 @@ pub enum MessageRole { Tool, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] pub enum ContentBlock { Text { text: String, @@ -32,14 +36,14 @@ pub enum ContentBlock { }, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct ConversationMessage { pub role: MessageRole, pub blocks: Vec<ContentBlock>, pub usage: Option<TokenUsage>, } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Session { pub version: u32, pub messages: Vec<ConversationMessage>, diff --git a/rust/crates/runtime/src/usage.rs b/rust/crates/runtime/src/usage.rs index 04e28df..0570bc1 100644 --- a/rust/crates/runtime/src/usage.rs +++ b/rust/crates/runtime/src/usage.rs @@ -1,4 +1,5 @@ use crate::session::Session; +use serde::{Deserialize, Serialize}; const DEFAULT_INPUT_COST_PER_MILLION: f64 = 15.0; const DEFAULT_OUTPUT_COST_PER_MILLION: f64 = 75.0; @@ -25,7 +26,7 @@ impl ModelPricing { } } -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)] pub struct TokenUsage { pub input_tokens: u32, pub output_tokens: u32, @@ -249,9 +250,9 @@ mod tests { let cost = usage.estimate_cost_usd(); assert_eq!(format_usd(cost.input_cost_usd), "$15.0000"); assert_eq!(format_usd(cost.output_cost_usd), "$37.5000"); - let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-20250514")); + let lines = usage.summary_lines_for_model("usage", Some("claude-sonnet-4-6")); assert!(lines[0].contains("estimated_cost=$54.6750")); - assert!(lines[0].contains("model=claude-sonnet-4-20250514")); + assert!(lines[0].contains("model=claude-sonnet-4-6")); assert!(lines[1].contains("cache_read=$0.3000")); } @@ -264,7 +265,7 @@ mod tests { cache_read_input_tokens: 0, }; - let haiku = pricing_for_model("claude-haiku-4-5-20251001").expect("haiku pricing"); + let haiku = pricing_for_model("claude-haiku-4-5-20251213").expect("haiku pricing"); let opus = pricing_for_model("claude-opus-4-6").expect("opus pricing"); let haiku_cost = usage.estimate_cost_usd_with_pricing(haiku); let opus_cost = usage.estimate_cost_usd_with_pricing(opus); diff --git a/rust/crates/server/Cargo.toml b/rust/crates/server/Cargo.toml new file mode 100644 index 0000000..9151aef --- /dev/null +++ b/rust/crates/server/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "server" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true + +[dependencies] +async-stream = "0.3" +axum = "0.8" +runtime = { path = "../runtime" } +serde = { version = "1", features = ["derive"] } +serde_json.workspace = true +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync", "net", "time"] } + +[dev-dependencies] +reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "stream"] } + +[lints] +workspace = true diff --git a/rust/crates/server/src/lib.rs b/rust/crates/server/src/lib.rs new file mode 100644 index 0000000..b3386ea --- /dev/null +++ b/rust/crates/server/src/lib.rs @@ -0,0 +1,442 @@ +use std::collections::HashMap; +use std::convert::Infallible; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use async_stream::stream; +use axum::extract::{Path, State}; +use axum::http::StatusCode; +use axum::response::sse::{Event, KeepAlive, Sse}; +use axum::response::IntoResponse; +use axum::routing::{get, post}; +use axum::{Json, Router}; +use runtime::{ConversationMessage, Session as RuntimeSession}; +use serde::{Deserialize, Serialize}; +use tokio::sync::{broadcast, RwLock}; + +pub type SessionId = String; +pub type SessionStore = Arc<RwLock<HashMap<SessionId, Session>>>; + +const BROADCAST_CAPACITY: usize = 64; + +#[derive(Clone)] +pub struct AppState { + sessions: SessionStore, + next_session_id: Arc<AtomicU64>, +} + +impl AppState { + #[must_use] + pub fn new() -> Self { + Self { + sessions: Arc::new(RwLock::new(HashMap::new())), + next_session_id: Arc::new(AtomicU64::new(1)), + } + } + + fn allocate_session_id(&self) -> SessionId { + let id = self.next_session_id.fetch_add(1, Ordering::Relaxed); + format!("session-{id}") + } +} + +impl Default for AppState { + fn default() -> Self { + Self::new() + } +} + +#[derive(Clone)] +pub struct Session { + pub id: SessionId, + pub created_at: u64, + pub conversation: RuntimeSession, + events: broadcast::Sender<SessionEvent>, +} + +impl Session { + fn new(id: SessionId) -> Self { + let (events, _) = broadcast::channel(BROADCAST_CAPACITY); + Self { + id, + created_at: unix_timestamp_millis(), + conversation: RuntimeSession::new(), + events, + } + } + + fn subscribe(&self) -> broadcast::Receiver<SessionEvent> { + self.events.subscribe() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +enum SessionEvent { + Snapshot { + session_id: SessionId, + session: RuntimeSession, + }, + Message { + session_id: SessionId, + message: ConversationMessage, + }, +} + +impl SessionEvent { + fn event_name(&self) -> &'static str { + match self { + Self::Snapshot { .. } => "snapshot", + Self::Message { .. } => "message", + } + } + + fn to_sse_event(&self) -> Result<Event, serde_json::Error> { + Ok(Event::default() + .event(self.event_name()) + .data(serde_json::to_string(self)?)) + } +} + +#[derive(Debug, Serialize)] +struct ErrorResponse { + error: String, +} + +type ApiError = (StatusCode, Json<ErrorResponse>); +type ApiResult<T> = Result<T, ApiError>; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct CreateSessionResponse { + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionSummary { + pub id: SessionId, + pub created_at: u64, + pub message_count: usize, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ListSessionsResponse { + pub sessions: Vec<SessionSummary>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionDetailsResponse { + pub id: SessionId, + pub created_at: u64, + pub session: RuntimeSession, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SendMessageRequest { + pub message: String, +} + +#[must_use] +pub fn app(state: AppState) -> Router { + Router::new() + .route("/sessions", post(create_session).get(list_sessions)) + .route("/sessions/{id}", get(get_session)) + .route("/sessions/{id}/events", get(stream_session_events)) + .route("/sessions/{id}/message", post(send_message)) + .with_state(state) +} + +async fn create_session( + State(state): State<AppState>, +) -> (StatusCode, Json<CreateSessionResponse>) { + let session_id = state.allocate_session_id(); + let session = Session::new(session_id.clone()); + + state + .sessions + .write() + .await + .insert(session_id.clone(), session); + + ( + StatusCode::CREATED, + Json(CreateSessionResponse { session_id }), + ) +} + +async fn list_sessions(State(state): State<AppState>) -> Json<ListSessionsResponse> { + let sessions = state.sessions.read().await; + let mut summaries = sessions + .values() + .map(|session| SessionSummary { + id: session.id.clone(), + created_at: session.created_at, + message_count: session.conversation.messages.len(), + }) + .collect::<Vec<_>>(); + summaries.sort_by(|left, right| left.id.cmp(&right.id)); + + Json(ListSessionsResponse { + sessions: summaries, + }) +} + +async fn get_session( + State(state): State<AppState>, + Path(id): Path<SessionId>, +) -> ApiResult<Json<SessionDetailsResponse>> { + let sessions = state.sessions.read().await; + let session = sessions + .get(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + + Ok(Json(SessionDetailsResponse { + id: session.id.clone(), + created_at: session.created_at, + session: session.conversation.clone(), + })) +} + +async fn send_message( + State(state): State<AppState>, + Path(id): Path<SessionId>, + Json(payload): Json<SendMessageRequest>, +) -> ApiResult<StatusCode> { + let message = ConversationMessage::user_text(payload.message); + let broadcaster = { + let mut sessions = state.sessions.write().await; + let session = sessions + .get_mut(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + session.conversation.messages.push(message.clone()); + session.events.clone() + }; + + let _ = broadcaster.send(SessionEvent::Message { + session_id: id, + message, + }); + + Ok(StatusCode::NO_CONTENT) +} + +async fn stream_session_events( + State(state): State<AppState>, + Path(id): Path<SessionId>, +) -> ApiResult<impl IntoResponse> { + let (snapshot, mut receiver) = { + let sessions = state.sessions.read().await; + let session = sessions + .get(&id) + .ok_or_else(|| not_found(format!("session `{id}` not found")))?; + ( + SessionEvent::Snapshot { + session_id: session.id.clone(), + session: session.conversation.clone(), + }, + session.subscribe(), + ) + }; + + let stream = stream! { + if let Ok(event) = snapshot.to_sse_event() { + yield Ok::<Event, Infallible>(event); + } + + loop { + match receiver.recv().await { + Ok(event) => { + if let Ok(sse_event) = event.to_sse_event() { + yield Ok::<Event, Infallible>(sse_event); + } + } + Err(broadcast::error::RecvError::Lagged(_)) => continue, + Err(broadcast::error::RecvError::Closed) => break, + } + } + }; + + Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15)))) +} + +fn unix_timestamp_millis() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after epoch") + .as_millis() as u64 +} + +fn not_found(message: String) -> ApiError { + ( + StatusCode::NOT_FOUND, + Json(ErrorResponse { error: message }), + ) +} + +#[cfg(test)] +mod tests { + use super::{ + app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse, + }; + use reqwest::Client; + use std::net::SocketAddr; + use std::time::Duration; + use tokio::net::TcpListener; + use tokio::task::JoinHandle; + use tokio::time::timeout; + + struct TestServer { + address: SocketAddr, + handle: JoinHandle<()>, + } + + impl TestServer { + async fn spawn() -> Self { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("test listener should bind"); + let address = listener + .local_addr() + .expect("listener should report local address"); + let handle = tokio::spawn(async move { + axum::serve(listener, app(AppState::default())) + .await + .expect("server should run"); + }); + + Self { address, handle } + } + + fn url(&self, path: &str) -> String { + format!("http://{}{}", self.address, path) + } + } + + impl Drop for TestServer { + fn drop(&mut self) { + self.handle.abort(); + } + } + + async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse { + client + .post(server.url("/sessions")) + .send() + .await + .expect("create request should succeed") + .error_for_status() + .expect("create request should return success") + .json::<CreateSessionResponse>() + .await + .expect("create response should parse") + } + + async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut String) -> String { + loop { + if let Some(index) = buffer.find("\n\n") { + let frame = buffer[..index].to_string(); + let remainder = buffer[index + 2..].to_string(); + *buffer = remainder; + return frame; + } + + let next_chunk = timeout(Duration::from_secs(5), response.chunk()) + .await + .expect("SSE stream should yield within timeout") + .expect("SSE stream should remain readable") + .expect("SSE stream should stay open"); + buffer.push_str(&String::from_utf8_lossy(&next_chunk)); + } + } + + #[tokio::test] + async fn creates_and_lists_sessions() { + let server = TestServer::spawn().await; + let client = Client::new(); + + // given + let created = create_session(&client, &server).await; + + // when + let sessions = client + .get(server.url("/sessions")) + .send() + .await + .expect("list request should succeed") + .error_for_status() + .expect("list request should return success") + .json::<ListSessionsResponse>() + .await + .expect("list response should parse"); + let details = client + .get(server.url(&format!("/sessions/{}", created.session_id))) + .send() + .await + .expect("details request should succeed") + .error_for_status() + .expect("details request should return success") + .json::<SessionDetailsResponse>() + .await + .expect("details response should parse"); + + // then + assert_eq!(created.session_id, "session-1"); + assert_eq!(sessions.sessions.len(), 1); + assert_eq!(sessions.sessions[0].id, created.session_id); + assert_eq!(sessions.sessions[0].message_count, 0); + assert_eq!(details.id, "session-1"); + assert!(details.session.messages.is_empty()); + } + + #[tokio::test] + async fn streams_message_events_and_persists_message_flow() { + let server = TestServer::spawn().await; + let client = Client::new(); + + // given + let created = create_session(&client, &server).await; + let mut response = client + .get(server.url(&format!("/sessions/{}/events", created.session_id))) + .send() + .await + .expect("events request should succeed") + .error_for_status() + .expect("events request should return success"); + let mut buffer = String::new(); + let snapshot_frame = next_sse_frame(&mut response, &mut buffer).await; + + // when + let send_status = client + .post(server.url(&format!("/sessions/{}/message", created.session_id))) + .json(&super::SendMessageRequest { + message: "hello from test".to_string(), + }) + .send() + .await + .expect("message request should succeed") + .status(); + let message_frame = next_sse_frame(&mut response, &mut buffer).await; + let details = client + .get(server.url(&format!("/sessions/{}", created.session_id))) + .send() + .await + .expect("details request should succeed") + .error_for_status() + .expect("details request should return success") + .json::<SessionDetailsResponse>() + .await + .expect("details response should parse"); + + // then + assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT); + assert!(snapshot_frame.contains("event: snapshot")); + assert!(snapshot_frame.contains("\"session_id\":\"session-1\"")); + assert!(message_frame.contains("event: message")); + assert!(message_frame.contains("hello from test")); + assert_eq!(details.session.messages.len(), 1); + assert_eq!( + details.session.messages[0], + runtime::ConversationMessage::user_text("hello from test") + ); + } +} diff --git a/rust/crates/tools/Cargo.toml b/rust/crates/tools/Cargo.toml index 64768f4..04d738b 100644 --- a/rust/crates/tools/Cargo.toml +++ b/rust/crates/tools/Cargo.toml @@ -6,10 +6,13 @@ license.workspace = true publish.workspace = true [dependencies] +api = { path = "../api" } +plugins = { path = "../plugins" } runtime = { path = "../runtime" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] } serde = { version = "1", features = ["derive"] } -serde_json = "1" +serde_json.workspace = true +tokio = { version = "1", features = ["rt-multi-thread"] } [lints] workspace = true diff --git a/rust/crates/tools/src/lib.rs b/rust/crates/tools/src/lib.rs index 091b256..4b42572 100644 --- a/rust/crates/tools/src/lib.rs +++ b/rust/crates/tools/src/lib.rs @@ -3,10 +3,18 @@ use std::path::{Path, PathBuf}; use std::process::Command; use std::time::{Duration, Instant}; +use api::{ + max_tokens_for_model, resolve_model_alias, ContentBlockDelta, InputContentBlock, InputMessage, + MessageRequest, MessageResponse, OutputContentBlock, ProviderClient, + StreamEvent as ApiStreamEvent, ToolChoice, ToolDefinition, ToolResultContentBlock, +}; +use plugins::PluginTool; use reqwest::blocking::Client; use runtime::{ - edit_file, execute_bash, glob_search, grep_search, read_file, write_file, BashCommandInput, - GrepSearchInput, PermissionMode, + edit_file, execute_bash, glob_search, grep_search, load_system_prompt, read_file, write_file, + ApiClient, ApiRequest, AssistantEvent, BashCommandInput, ContentBlock, ConversationMessage, + ConversationRuntime, GrepSearchInput, MessageRole, PermissionMode, PermissionPolicy, + RuntimeError, Session, TokenUsage, ToolError, ToolExecutor, }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; @@ -48,6 +56,161 @@ pub struct ToolSpec { pub required_permission: PermissionMode, } +#[derive(Debug, Clone, PartialEq)] +pub struct GlobalToolRegistry { + plugin_tools: Vec<PluginTool>, +} + +impl GlobalToolRegistry { + #[must_use] + pub fn builtin() -> Self { + Self { + plugin_tools: Vec::new(), + } + } + + pub fn with_plugin_tools(plugin_tools: Vec<PluginTool>) -> Result<Self, String> { + let builtin_names = mvp_tool_specs() + .into_iter() + .map(|spec| spec.name.to_string()) + .collect::<BTreeSet<_>>(); + let mut seen_plugin_names = BTreeSet::new(); + + for tool in &plugin_tools { + let name = tool.definition().name.clone(); + if builtin_names.contains(&name) { + return Err(format!( + "plugin tool `{name}` conflicts with a built-in tool name" + )); + } + if !seen_plugin_names.insert(name.clone()) { + return Err(format!("duplicate plugin tool name `{name}`")); + } + } + + Ok(Self { plugin_tools }) + } + + pub fn normalize_allowed_tools(&self, values: &[String]) -> Result<Option<BTreeSet<String>>, String> { + if values.is_empty() { + return Ok(None); + } + + let builtin_specs = mvp_tool_specs(); + let canonical_names = builtin_specs + .iter() + .map(|spec| spec.name.to_string()) + .chain(self.plugin_tools.iter().map(|tool| tool.definition().name.clone())) + .collect::<Vec<_>>(); + let mut name_map = canonical_names + .iter() + .map(|name| (normalize_tool_name(name), name.clone())) + .collect::<BTreeMap<_, _>>(); + + for (alias, canonical) in [ + ("read", "read_file"), + ("write", "write_file"), + ("edit", "edit_file"), + ("glob", "glob_search"), + ("grep", "grep_search"), + ] { + name_map.insert(alias.to_string(), canonical.to_string()); + } + + let mut allowed = BTreeSet::new(); + for value in values { + for token in value + .split(|ch: char| ch == ',' || ch.is_whitespace()) + .filter(|token| !token.is_empty()) + { + let normalized = normalize_tool_name(token); + let canonical = name_map.get(&normalized).ok_or_else(|| { + format!( + "unsupported tool in --allowedTools: {token} (expected one of: {})", + canonical_names.join(", ") + ) + })?; + allowed.insert(canonical.clone()); + } + } + + Ok(Some(allowed)) + } + + #[must_use] + pub fn definitions(&self, allowed_tools: Option<&BTreeSet<String>>) -> Vec<ToolDefinition> { + let builtin = mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .map(|spec| ToolDefinition { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: spec.input_schema, + }); + let plugin = self + .plugin_tools + .iter() + .filter(|tool| { + allowed_tools.is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) + }) + .map(|tool| ToolDefinition { + name: tool.definition().name.clone(), + description: tool.definition().description.clone(), + input_schema: tool.definition().input_schema.clone(), + }); + builtin.chain(plugin).collect() + } + + #[must_use] + pub fn permission_specs( + &self, + allowed_tools: Option<&BTreeSet<String>>, + ) -> Vec<(String, PermissionMode)> { + let builtin = mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .map(|spec| (spec.name.to_string(), spec.required_permission)); + let plugin = self + .plugin_tools + .iter() + .filter(|tool| { + allowed_tools.is_none_or(|allowed| allowed.contains(tool.definition().name.as_str())) + }) + .map(|tool| { + ( + tool.definition().name.clone(), + permission_mode_from_plugin(tool.required_permission()), + ) + }); + builtin.chain(plugin).collect() + } + + pub fn execute(&self, name: &str, input: &Value) -> Result<String, String> { + if mvp_tool_specs().iter().any(|spec| spec.name == name) { + return execute_tool(name, input); + } + self.plugin_tools + .iter() + .find(|tool| tool.definition().name == name) + .ok_or_else(|| format!("unsupported tool: {name}"))? + .execute(input) + .map_err(|error| error.to_string()) + } +} + +fn normalize_tool_name(value: &str) -> String { + value.trim().replace('-', "_").to_ascii_lowercase() +} + +fn permission_mode_from_plugin(value: &str) -> PermissionMode { + match value { + "read-only" => PermissionMode::ReadOnly, + "workspace-write" => PermissionMode::WorkspaceWrite, + "danger-full-access" => PermissionMode::DangerFullAccess, + other => panic!("unsupported plugin permission: {other}"), + } +} + #[must_use] #[allow(clippy::too_many_lines)] pub fn mvp_tool_specs() -> Vec<ToolSpec> { @@ -316,7 +479,7 @@ pub fn mvp_tool_specs() -> Vec<ToolSpec> { }, ToolSpec { name: "Config", - description: "Get or set Claude Code settings.", + description: "Get or set Claw Code settings.", input_schema: json!({ "type": "object", "properties": { @@ -702,7 +865,7 @@ struct SkillOutput { prompt: String, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] struct AgentOutput { #[serde(rename = "agentId")] agent_id: String, @@ -718,6 +881,20 @@ struct AgentOutput { manifest_file: String, #[serde(rename = "createdAt")] created_at: String, + #[serde(rename = "startedAt", skip_serializing_if = "Option::is_none")] + started_at: Option<String>, + #[serde(rename = "completedAt", skip_serializing_if = "Option::is_none")] + completed_at: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option<String>, +} + +#[derive(Debug, Clone)] +struct AgentJob { + manifest: AgentOutput, + prompt: String, + system_prompt: Vec<String>, + allowed_tools: BTreeSet<String>, } #[derive(Debug, Serialize)] @@ -904,7 +1081,7 @@ fn build_http_client() -> Result<Client, String> { Client::builder() .timeout(Duration::from_secs(20)) .redirect(reqwest::redirect::Policy::limited(10)) - .user_agent("clawd-rust-tools/0.1") + .user_agent("claw-rust-tools/0.1") .build() .map_err(|error| error.to_string()) } @@ -925,7 +1102,7 @@ fn normalize_fetch_url(url: &str) -> Result<String, String> { } fn build_search_url(query: &str) -> Result<reqwest::Url, String> { - if let Ok(base) = std::env::var("CLAWD_WEB_SEARCH_BASE_URL") { + if let Ok(base) = std::env::var("CLAW_WEB_SEARCH_BASE_URL") { let mut url = reqwest::Url::parse(&base).map_err(|error| error.to_string())?; url.query_pairs_mut().append_pair("q", query); return Ok(url); @@ -1259,15 +1436,7 @@ fn validate_todos(todos: &[TodoItem]) -> Result<(), String> { if todos.is_empty() { return Err(String::from("todos must not be empty")); } - let in_progress = todos - .iter() - .filter(|todo| matches!(todo.status, TodoStatus::InProgress)) - .count(); - if in_progress > 1 { - return Err(String::from( - "exactly zero or one todo items may be in_progress", - )); - } + // Allow multiple in_progress items for parallel workflows if todos.iter().any(|todo| todo.content.trim().is_empty()) { return Err(String::from("todo content must not be empty")); } @@ -1278,11 +1447,11 @@ fn validate_todos(todos: &[TodoItem]) -> Result<(), String> { } fn todo_store_path() -> Result<std::path::PathBuf, String> { - if let Ok(path) = std::env::var("CLAWD_TODO_STORE") { + if let Ok(path) = std::env::var("CLAW_TODO_STORE") { return Ok(std::path::PathBuf::from(path)); } let cwd = std::env::current_dir().map_err(|error| error.to_string())?; - Ok(cwd.join(".clawd-todos.json")) + Ok(cwd.join(".claw-todos.json")) } fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> { @@ -1295,6 +1464,12 @@ fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> { if let Ok(codex_home) = std::env::var("CODEX_HOME") { candidates.push(std::path::PathBuf::from(codex_home).join("skills")); } + if let Ok(home) = std::env::var("HOME") { + let home = std::path::PathBuf::from(home); + candidates.push(home.join(".agents").join("skills")); + candidates.push(home.join(".config").join("opencode").join("skills")); + candidates.push(home.join(".codex").join("skills")); + } candidates.push(std::path::PathBuf::from("/home/bellman/.codex/skills")); for root in candidates { @@ -1323,7 +1498,18 @@ fn resolve_skill_path(skill: &str) -> Result<std::path::PathBuf, String> { Err(format!("unknown skill: {requested}")) } +const DEFAULT_AGENT_MODEL: &str = "claude-opus-4-6"; +const DEFAULT_AGENT_SYSTEM_DATE: &str = "2026-03-31"; +const DEFAULT_AGENT_MAX_ITERATIONS: usize = 32; + fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { + execute_agent_with_spawn(input, spawn_agent_job) +} + +fn execute_agent_with_spawn<F>(input: AgentInput, spawn_fn: F) -> Result<AgentOutput, String> +where + F: FnOnce(AgentJob) -> Result<(), String>, +{ if input.description.trim().is_empty() { return Err(String::from("description must not be empty")); } @@ -1337,6 +1523,7 @@ fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { let output_file = output_dir.join(format!("{agent_id}.md")); let manifest_file = output_dir.join(format!("{agent_id}.json")); let normalized_subagent_type = normalize_subagent_type(input.subagent_type.as_deref()); + let model = resolve_agent_model(input.model.as_deref()); let agent_name = input .name .as_deref() @@ -1344,6 +1531,8 @@ fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { .filter(|name| !name.is_empty()) .unwrap_or_else(|| slugify_agent_name(&input.description)); let created_at = iso8601_now(); + let system_prompt = build_agent_system_prompt(&normalized_subagent_type)?; + let allowed_tools = allowed_tools_for_subagent(&normalized_subagent_type); let output_contents = format!( "# Agent Task @@ -1367,21 +1556,519 @@ fn execute_agent(input: AgentInput) -> Result<AgentOutput, String> { name: agent_name, description: input.description, subagent_type: Some(normalized_subagent_type), - model: input.model, - status: String::from("queued"), + model: Some(model), + status: String::from("running"), output_file: output_file.display().to_string(), manifest_file: manifest_file.display().to_string(), - created_at, + created_at: created_at.clone(), + started_at: Some(created_at), + completed_at: None, + error: None, }; - std::fs::write( - &manifest_file, - serde_json::to_string_pretty(&manifest).map_err(|error| error.to_string())?, - ) - .map_err(|error| error.to_string())?; + write_agent_manifest(&manifest)?; + + let manifest_for_spawn = manifest.clone(); + let job = AgentJob { + manifest: manifest_for_spawn, + prompt: input.prompt, + system_prompt, + allowed_tools, + }; + if let Err(error) = spawn_fn(job) { + let error = format!("failed to spawn sub-agent: {error}"); + persist_agent_terminal_state(&manifest, "failed", None, Some(error.clone()))?; + return Err(error); + } Ok(manifest) } +fn spawn_agent_job(job: AgentJob) -> Result<(), String> { + let thread_name = format!("claw-agent-{}", job.manifest.agent_id); + std::thread::Builder::new() + .name(thread_name) + .spawn(move || { + let result = + std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| run_agent_job(&job))); + match result { + Ok(Ok(())) => {} + Ok(Err(error)) => { + let _ = + persist_agent_terminal_state(&job.manifest, "failed", None, Some(error)); + } + Err(_) => { + let _ = persist_agent_terminal_state( + &job.manifest, + "failed", + None, + Some(String::from("sub-agent thread panicked")), + ); + } + } + }) + .map(|_| ()) + .map_err(|error| error.to_string()) +} + +fn run_agent_job(job: &AgentJob) -> Result<(), String> { + let mut runtime = build_agent_runtime(job)?.with_max_iterations(DEFAULT_AGENT_MAX_ITERATIONS); + let summary = runtime + .run_turn(job.prompt.clone(), None) + .map_err(|error| error.to_string())?; + let final_text = final_assistant_text(&summary); + persist_agent_terminal_state(&job.manifest, "completed", Some(final_text.as_str()), None) +} + +fn build_agent_runtime( + job: &AgentJob, +) -> Result<ConversationRuntime<ProviderRuntimeClient, SubagentToolExecutor>, String> { + let model = job + .manifest + .model + .clone() + .unwrap_or_else(|| DEFAULT_AGENT_MODEL.to_string()); + let allowed_tools = job.allowed_tools.clone(); + let api_client = ProviderRuntimeClient::new(model, allowed_tools.clone())?; + let tool_executor = SubagentToolExecutor::new(allowed_tools); + Ok(ConversationRuntime::new( + Session::new(), + api_client, + tool_executor, + agent_permission_policy(), + job.system_prompt.clone(), + )) +} + +fn build_agent_system_prompt(subagent_type: &str) -> Result<Vec<String>, String> { + let cwd = std::env::current_dir().map_err(|error| error.to_string())?; + let mut prompt = load_system_prompt( + cwd, + DEFAULT_AGENT_SYSTEM_DATE.to_string(), + std::env::consts::OS, + "unknown", + ) + .map_err(|error| error.to_string())?; + prompt.push(format!( + "You are a background sub-agent of type `{subagent_type}`. Work only on the delegated task, use only the tools available to you, do not ask the user questions, and finish with a concise result." + )); + Ok(prompt) +} + +fn resolve_agent_model(model: Option<&str>) -> String { + model + .map(str::trim) + .filter(|model| !model.is_empty()) + .unwrap_or(DEFAULT_AGENT_MODEL) + .to_string() +} + +fn allowed_tools_for_subagent(subagent_type: &str) -> BTreeSet<String> { + let tools = match subagent_type { + "Explore" => vec![ + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "Skill", + "StructuredOutput", + ], + "Plan" => vec![ + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "Skill", + "TodoWrite", + "StructuredOutput", + "SendUserMessage", + ], + "Verification" => vec![ + "bash", + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "TodoWrite", + "StructuredOutput", + "SendUserMessage", + "PowerShell", + ], + "claw-guide" => vec![ + "read_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "ToolSearch", + "Skill", + "StructuredOutput", + "SendUserMessage", + ], + "statusline-setup" => vec![ + "bash", + "read_file", + "write_file", + "edit_file", + "glob_search", + "grep_search", + "ToolSearch", + ], + _ => vec![ + "bash", + "read_file", + "write_file", + "edit_file", + "glob_search", + "grep_search", + "WebFetch", + "WebSearch", + "TodoWrite", + "Skill", + "ToolSearch", + "NotebookEdit", + "Sleep", + "SendUserMessage", + "Config", + "StructuredOutput", + "REPL", + "PowerShell", + ], + }; + tools.into_iter().map(str::to_string).collect() +} + +fn agent_permission_policy() -> PermissionPolicy { + mvp_tool_specs().into_iter().fold( + PermissionPolicy::new(PermissionMode::DangerFullAccess), + |policy, spec| policy.with_tool_requirement(spec.name, spec.required_permission), + ) +} + +fn write_agent_manifest(manifest: &AgentOutput) -> Result<(), String> { + std::fs::write( + &manifest.manifest_file, + serde_json::to_string_pretty(manifest).map_err(|error| error.to_string())?, + ) + .map_err(|error| error.to_string()) +} + +fn persist_agent_terminal_state( + manifest: &AgentOutput, + status: &str, + result: Option<&str>, + error: Option<String>, +) -> Result<(), String> { + append_agent_output( + &manifest.output_file, + &format_agent_terminal_output(status, result, error.as_deref()), + )?; + let mut next_manifest = manifest.clone(); + next_manifest.status = status.to_string(); + next_manifest.completed_at = Some(iso8601_now()); + next_manifest.error = error; + write_agent_manifest(&next_manifest) +} + +fn append_agent_output(path: &str, suffix: &str) -> Result<(), String> { + use std::io::Write as _; + + let mut file = std::fs::OpenOptions::new() + .append(true) + .open(path) + .map_err(|error| error.to_string())?; + file.write_all(suffix.as_bytes()) + .map_err(|error| error.to_string()) +} + +fn format_agent_terminal_output(status: &str, result: Option<&str>, error: Option<&str>) -> String { + let mut sections = vec![format!("\n## Result\n\n- status: {status}\n")]; + if let Some(result) = result.filter(|value| !value.trim().is_empty()) { + sections.push(format!("\n### Final response\n\n{}\n", result.trim())); + } + if let Some(error) = error.filter(|value| !value.trim().is_empty()) { + sections.push(format!("\n### Error\n\n{}\n", error.trim())); + } + sections.join("") +} + +struct ProviderRuntimeClient { + runtime: tokio::runtime::Runtime, + client: ProviderClient, + model: String, + allowed_tools: BTreeSet<String>, +} + +impl ProviderRuntimeClient { + fn new(model: String, allowed_tools: BTreeSet<String>) -> Result<Self, String> { + let model = resolve_model_alias(&model).to_string(); + let client = ProviderClient::from_model(&model).map_err(|error| error.to_string())?; + Ok(Self { + runtime: tokio::runtime::Runtime::new().map_err(|error| error.to_string())?, + client, + model, + allowed_tools, + }) + } +} + +impl ApiClient for ProviderRuntimeClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + let tools = tool_specs_for_allowed_tools(Some(&self.allowed_tools)) + .into_iter() + .map(|spec| ToolDefinition { + name: spec.name.to_string(), + description: Some(spec.description.to_string()), + input_schema: spec.input_schema, + }) + .collect::<Vec<_>>(); + let message_request = MessageRequest { + model: self.model.clone(), + max_tokens: max_tokens_for_model(&self.model), + messages: convert_messages(&request.messages), + system: (!request.system_prompt.is_empty()).then(|| request.system_prompt.join("\n\n")), + tools: (!tools.is_empty()).then_some(tools), + tool_choice: (!self.allowed_tools.is_empty()).then_some(ToolChoice::Auto), + stream: true, + }; + + self.runtime.block_on(async { + let mut stream = self + .client + .stream_message(&message_request) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + let mut events = Vec::new(); + let mut pending_tools: BTreeMap<u32, (String, String, String)> = BTreeMap::new(); + let mut saw_stop = false; + + while let Some(event) = stream + .next_event() + .await + .map_err(|error| RuntimeError::new(error.to_string()))? + { + match event { + ApiStreamEvent::MessageStart(start) => { + for block in start.message.content { + push_output_block(block, 0, &mut events, &mut pending_tools, true); + } + } + ApiStreamEvent::ContentBlockStart(start) => { + push_output_block( + start.content_block, + start.index, + &mut events, + &mut pending_tools, + true, + ); + } + ApiStreamEvent::ContentBlockDelta(delta) => match delta.delta { + ContentBlockDelta::TextDelta { text } => { + if !text.is_empty() { + events.push(AssistantEvent::TextDelta(text)); + } + } + ContentBlockDelta::InputJsonDelta { partial_json } => { + if let Some((_, _, input)) = pending_tools.get_mut(&delta.index) { + input.push_str(&partial_json); + } + } + ContentBlockDelta::ThinkingDelta { .. } + | ContentBlockDelta::SignatureDelta { .. } => {} + }, + ApiStreamEvent::ContentBlockStop(stop) => { + if let Some((id, name, input)) = pending_tools.remove(&stop.index) { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + ApiStreamEvent::MessageDelta(delta) => { + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: delta.usage.input_tokens, + output_tokens: delta.usage.output_tokens, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + })); + } + ApiStreamEvent::MessageStop(_) => { + saw_stop = true; + events.push(AssistantEvent::MessageStop); + } + } + } + + if !saw_stop + && events.iter().any(|event| { + matches!(event, AssistantEvent::TextDelta(text) if !text.is_empty()) + || matches!(event, AssistantEvent::ToolUse { .. }) + }) + { + events.push(AssistantEvent::MessageStop); + } + + if events + .iter() + .any(|event| matches!(event, AssistantEvent::MessageStop)) + { + return Ok(events); + } + + let response = self + .client + .send_message(&MessageRequest { + stream: false, + ..message_request.clone() + }) + .await + .map_err(|error| RuntimeError::new(error.to_string()))?; + Ok(response_to_events(response)) + }) + } +} + +struct SubagentToolExecutor { + allowed_tools: BTreeSet<String>, +} + +impl SubagentToolExecutor { + fn new(allowed_tools: BTreeSet<String>) -> Self { + Self { allowed_tools } + } +} + +impl ToolExecutor for SubagentToolExecutor { + fn execute(&mut self, tool_name: &str, input: &str) -> Result<String, ToolError> { + if !self.allowed_tools.contains(tool_name) { + return Err(ToolError::new(format!( + "tool `{tool_name}` is not enabled for this sub-agent" + ))); + } + let value = serde_json::from_str(input) + .map_err(|error| ToolError::new(format!("invalid tool input JSON: {error}")))?; + execute_tool(tool_name, &value).map_err(ToolError::new) + } +} + +fn tool_specs_for_allowed_tools(allowed_tools: Option<&BTreeSet<String>>) -> Vec<ToolSpec> { + mvp_tool_specs() + .into_iter() + .filter(|spec| allowed_tools.is_none_or(|allowed| allowed.contains(spec.name))) + .collect() +} + +fn convert_messages(messages: &[ConversationMessage]) -> Vec<InputMessage> { + messages + .iter() + .filter_map(|message| { + let role = match message.role { + MessageRole::System | MessageRole::User | MessageRole::Tool => "user", + MessageRole::Assistant => "assistant", + }; + let content = message + .blocks + .iter() + .map(|block| match block { + ContentBlock::Text { text } => InputContentBlock::Text { text: text.clone() }, + ContentBlock::ToolUse { id, name, input } => InputContentBlock::ToolUse { + id: id.clone(), + name: name.clone(), + input: serde_json::from_str(input) + .unwrap_or_else(|_| serde_json::json!({ "raw": input })), + }, + ContentBlock::ToolResult { + tool_use_id, + output, + is_error, + .. + } => InputContentBlock::ToolResult { + tool_use_id: tool_use_id.clone(), + content: vec![ToolResultContentBlock::Text { + text: output.clone(), + }], + is_error: *is_error, + }, + }) + .collect::<Vec<_>>(); + (!content.is_empty()).then(|| InputMessage { + role: role.to_string(), + content, + }) + }) + .collect() +} + +fn push_output_block( + block: OutputContentBlock, + block_index: u32, + events: &mut Vec<AssistantEvent>, + pending_tools: &mut BTreeMap<u32, (String, String, String)>, + streaming_tool_input: bool, +) { + match block { + OutputContentBlock::Text { text } => { + if !text.is_empty() { + events.push(AssistantEvent::TextDelta(text)); + } + } + OutputContentBlock::ToolUse { id, name, input } => { + let initial_input = if streaming_tool_input + && input.is_object() + && input.as_object().is_some_and(serde_json::Map::is_empty) + { + String::new() + } else { + input.to_string() + }; + pending_tools.insert(block_index, (id, name, initial_input)); + } + OutputContentBlock::Thinking { .. } | OutputContentBlock::RedactedThinking { .. } => {} + } +} + +fn response_to_events(response: MessageResponse) -> Vec<AssistantEvent> { + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + for (index, block) in response.content.into_iter().enumerate() { + let index = u32::try_from(index).expect("response block index overflow"); + push_output_block(block, index, &mut events, &mut pending_tools, false); + if let Some((id, name, input)) = pending_tools.remove(&index) { + events.push(AssistantEvent::ToolUse { id, name, input }); + } + } + + events.push(AssistantEvent::Usage(TokenUsage { + input_tokens: response.usage.input_tokens, + output_tokens: response.usage.output_tokens, + cache_creation_input_tokens: response.usage.cache_creation_input_tokens, + cache_read_input_tokens: response.usage.cache_read_input_tokens, + })); + events.push(AssistantEvent::MessageStop); + events +} + +fn final_assistant_text(summary: &runtime::TurnSummary) -> String { + summary + .assistant_messages + .last() + .map(|message| { + message + .blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::<Vec<_>>() + .join("") + }) + .unwrap_or_default() +} + #[allow(clippy::needless_pass_by_value)] fn execute_tool_search(input: ToolSearchInput) -> ToolSearchOutput { let deferred = deferred_tool_specs(); @@ -1519,14 +2206,14 @@ fn canonical_tool_token(value: &str) -> String { } fn agent_store_dir() -> Result<std::path::PathBuf, String> { - if let Ok(path) = std::env::var("CLAWD_AGENT_STORE") { + if let Ok(path) = std::env::var("CLAW_AGENT_STORE") { return Ok(std::path::PathBuf::from(path)); } let cwd = std::env::current_dir().map_err(|error| error.to_string())?; if let Some(workspace_root) = cwd.ancestors().nth(2) { - return Ok(workspace_root.join(".clawd-agents")); + return Ok(workspace_root.join(".claw-agents")); } - Ok(cwd.join(".clawd-agents")) + Ok(cwd.join(".claw-agents")) } fn make_agent_id() -> String { @@ -1567,7 +2254,7 @@ fn normalize_subagent_type(subagent_type: Option<&str>) -> String { "verification" | "verificationagent" | "verify" | "verifier" => { String::from("Verification") } - "claudecodeguide" | "claudecodeguideagent" | "guide" => String::from("claude-code-guide"), + "clawguide" | "clawguideagent" | "guide" => String::from("claw-guide"), "statusline" | "statuslinesetup" => String::from("statusline-setup"), _ => trimmed.to_string(), } @@ -2067,16 +2754,16 @@ fn config_file_for_scope(scope: ConfigScope) -> Result<PathBuf, String> { let cwd = std::env::current_dir().map_err(|error| error.to_string())?; Ok(match scope { ConfigScope::Global => config_home_dir()?.join("settings.json"), - ConfigScope::Settings => cwd.join(".claude").join("settings.local.json"), + ConfigScope::Settings => cwd.join(".claw").join("settings.local.json"), }) } fn config_home_dir() -> Result<PathBuf, String> { - if let Ok(path) = std::env::var("CLAUDE_CONFIG_HOME") { + if let Ok(path) = std::env::var("CLAW_CONFIG_HOME") { return Ok(PathBuf::from(path)); } let home = std::env::var("HOME").map_err(|_| String::from("HOME is not set"))?; - Ok(PathBuf::from(home).join(".claude")) + Ok(PathBuf::from(home).join(".claw")) } fn read_json_object(path: &Path) -> Result<serde_json::Map<String, Value>, String> { @@ -2215,7 +2902,7 @@ fn execute_shell_command( persisted_output_path: None, persisted_output_size: None, sandbox_status: None, -}); + }); } let mut process = std::process::Command::new(shell); @@ -2284,7 +2971,7 @@ Command exceeded timeout of {timeout_ms} ms", persisted_output_path: None, persisted_output_size: None, sandbox_status: None, -}); + }); } std::thread::sleep(Duration::from_millis(10)); } @@ -2373,6 +3060,8 @@ fn parse_skill_description(contents: &str) -> Option<String> { #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use std::collections::BTreeSet; use std::fs; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener}; @@ -2381,7 +3070,13 @@ mod tests { use std::thread; use std::time::Duration; - use super::{execute_tool, mvp_tool_specs}; + use super::{ + agent_permission_policy, allowed_tools_for_subagent, execute_agent_with_spawn, + execute_tool, final_assistant_text, mvp_tool_specs, persist_agent_terminal_state, + push_output_block, AgentInput, AgentJob, SubagentToolExecutor, + }; + use api::OutputContentBlock; + use runtime::{ApiRequest, AssistantEvent, ConversationRuntime, RuntimeError, Session}; use serde_json::json; fn env_lock() -> &'static Mutex<()> { @@ -2394,7 +3089,7 @@ mod tests { .duration_since(std::time::UNIX_EPOCH) .expect("time") .as_nanos(); - std::env::temp_dir().join(format!("clawd-tools-{unique}-{name}")) + std::env::temp_dir().join(format!("claw-tools-{unique}-{name}")) } #[test] @@ -2517,7 +3212,7 @@ mod tests { })); std::env::set_var( - "CLAWD_WEB_SEARCH_BASE_URL", + "CLAW_WEB_SEARCH_BASE_URL", format!("http://{}/search", server.addr()), ); let result = execute_tool( @@ -2529,7 +3224,7 @@ mod tests { }), ) .expect("WebSearch should succeed"); - std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); assert_eq!(output["query"], "rust web search"); @@ -2565,7 +3260,7 @@ mod tests { })); std::env::set_var( - "CLAWD_WEB_SEARCH_BASE_URL", + "CLAW_WEB_SEARCH_BASE_URL", format!("http://{}/fallback", server.addr()), ); let result = execute_tool( @@ -2575,7 +3270,7 @@ mod tests { }), ) .expect("WebSearch fallback parsing should succeed"); - std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); let results = output["results"].as_array().expect("results array"); @@ -2588,20 +3283,77 @@ mod tests { assert_eq!(content[0]["url"], "https://example.com/one"); assert_eq!(content[1]["url"], "https://docs.rs/tokio"); - std::env::set_var("CLAWD_WEB_SEARCH_BASE_URL", "://bad-base-url"); + std::env::set_var("CLAW_WEB_SEARCH_BASE_URL", "://bad-base-url"); let error = execute_tool("WebSearch", &json!({ "query": "generic links" })) .expect_err("invalid base URL should fail"); - std::env::remove_var("CLAWD_WEB_SEARCH_BASE_URL"); + std::env::remove_var("CLAW_WEB_SEARCH_BASE_URL"); assert!(error.contains("relative URL without a base") || error.contains("empty host")); } + #[test] + fn pending_tools_preserve_multiple_streaming_tool_calls_by_index() { + let mut events = Vec::new(); + let mut pending_tools = BTreeMap::new(); + + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({}), + }, + 1, + &mut events, + &mut pending_tools, + true, + ); + push_output_block( + OutputContentBlock::ToolUse { + id: "tool-2".to_string(), + name: "grep_search".to_string(), + input: json!({}), + }, + 2, + &mut events, + &mut pending_tools, + true, + ); + + pending_tools + .get_mut(&1) + .expect("first tool pending") + .2 + .push_str("{\"path\":\"src/main.rs\"}"); + pending_tools + .get_mut(&2) + .expect("second tool pending") + .2 + .push_str("{\"pattern\":\"TODO\"}"); + + assert_eq!( + pending_tools.remove(&1), + Some(( + "tool-1".to_string(), + "read_file".to_string(), + "{\"path\":\"src/main.rs\"}".to_string(), + )) + ); + assert_eq!( + pending_tools.remove(&2), + Some(( + "tool-2".to_string(), + "grep_search".to_string(), + "{\"pattern\":\"TODO\"}".to_string(), + )) + ); + } + #[test] fn todo_write_persists_and_returns_previous_state() { let _guard = env_lock() .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let path = temp_path("todos.json"); - std::env::set_var("CLAWD_TODO_STORE", &path); + std::env::set_var("CLAW_TODO_STORE", &path); let first = execute_tool( "TodoWrite", @@ -2627,7 +3379,7 @@ mod tests { }), ) .expect("TodoWrite should succeed"); - std::env::remove_var("CLAWD_TODO_STORE"); + std::env::remove_var("CLAW_TODO_STORE"); let _ = std::fs::remove_file(path); let second_output: serde_json::Value = serde_json::from_str(&second).expect("valid json"); @@ -2648,13 +3400,14 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let path = temp_path("todos-errors.json"); - std::env::set_var("CLAWD_TODO_STORE", &path); + std::env::set_var("CLAW_TODO_STORE", &path); let empty = execute_tool("TodoWrite", &json!({ "todos": [] })) .expect_err("empty todos should fail"); assert!(empty.contains("todos must not be empty")); - let too_many_active = execute_tool( + // Multiple in_progress items are now allowed for parallel workflows + let _multi_active = execute_tool( "TodoWrite", &json!({ "todos": [ @@ -2663,8 +3416,7 @@ mod tests { ] }), ) - .expect_err("multiple in-progress todos should fail"); - assert!(too_many_active.contains("zero or one todo items may be in_progress")); + .expect("multiple in-progress todos should succeed"); let blank_content = execute_tool( "TodoWrite", @@ -2688,7 +3440,7 @@ mod tests { }), ) .expect("completed todos should succeed"); - std::env::remove_var("CLAWD_TODO_STORE"); + std::env::remove_var("CLAW_TODO_STORE"); let _ = fs::remove_file(path); let output: serde_json::Value = serde_json::from_str(&nudge).expect("valid json"); @@ -2697,6 +3449,9 @@ mod tests { #[test] fn skill_loads_local_skill_prompt() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); let result = execute_tool( "Skill", &json!({ @@ -2772,33 +3527,49 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let dir = temp_path("agent-store"); - std::env::set_var("CLAWD_AGENT_STORE", &dir); + std::env::set_var("CLAW_AGENT_STORE", &dir); + let captured = Arc::new(Mutex::new(None::<AgentJob>)); + let captured_for_spawn = Arc::clone(&captured); - let result = execute_tool( - "Agent", - &json!({ - "description": "Audit the branch", - "prompt": "Check tests and outstanding work.", - "subagent_type": "Explore", - "name": "ship-audit" - }), + let manifest = execute_agent_with_spawn( + AgentInput { + description: "Audit the branch".to_string(), + prompt: "Check tests and outstanding work.".to_string(), + subagent_type: Some("Explore".to_string()), + name: Some("ship-audit".to_string()), + model: None, + }, + move |job| { + *captured_for_spawn + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(job); + Ok(()) + }, ) .expect("Agent should succeed"); - std::env::remove_var("CLAWD_AGENT_STORE"); + std::env::remove_var("CLAW_AGENT_STORE"); - let output: serde_json::Value = serde_json::from_str(&result).expect("valid json"); - assert_eq!(output["name"], "ship-audit"); - assert_eq!(output["subagentType"], "Explore"); - assert_eq!(output["status"], "queued"); - assert!(output["createdAt"].as_str().is_some()); - let manifest_file = output["manifestFile"].as_str().expect("manifest file"); - let output_file = output["outputFile"].as_str().expect("output file"); - let contents = std::fs::read_to_string(output_file).expect("agent file exists"); + assert_eq!(manifest.name, "ship-audit"); + assert_eq!(manifest.subagent_type.as_deref(), Some("Explore")); + assert_eq!(manifest.status, "running"); + assert!(!manifest.created_at.is_empty()); + assert!(manifest.started_at.is_some()); + assert!(manifest.completed_at.is_none()); + let contents = std::fs::read_to_string(&manifest.output_file).expect("agent file exists"); let manifest_contents = - std::fs::read_to_string(manifest_file).expect("manifest file exists"); + std::fs::read_to_string(&manifest.manifest_file).expect("manifest file exists"); assert!(contents.contains("Audit the branch")); assert!(contents.contains("Check tests and outstanding work.")); assert!(manifest_contents.contains("\"subagentType\": \"Explore\"")); + assert!(manifest_contents.contains("\"status\": \"running\"")); + let captured_job = captured + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + .expect("spawn job should be captured"); + assert_eq!(captured_job.prompt, "Check tests and outstanding work."); + assert!(captured_job.allowed_tools.contains("read_file")); + assert!(!captured_job.allowed_tools.contains("Agent")); let normalized = execute_tool( "Agent", @@ -2827,6 +3598,195 @@ mod tests { let _ = std::fs::remove_dir_all(dir); } + #[test] + fn agent_fake_runner_can_persist_completion_and_failure() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let dir = temp_path("agent-runner"); + std::env::set_var("CLAW_AGENT_STORE", &dir); + + let completed = execute_agent_with_spawn( + AgentInput { + description: "Complete the task".to_string(), + prompt: "Do the work".to_string(), + subagent_type: Some("Explore".to_string()), + name: Some("complete-task".to_string()), + model: Some("claude-sonnet-4-6".to_string()), + }, + |job| { + persist_agent_terminal_state( + &job.manifest, + "completed", + Some("Finished successfully"), + None, + ) + }, + ) + .expect("completed agent should succeed"); + + let completed_manifest = std::fs::read_to_string(&completed.manifest_file) + .expect("completed manifest should exist"); + let completed_output = + std::fs::read_to_string(&completed.output_file).expect("completed output should exist"); + assert!(completed_manifest.contains("\"status\": \"completed\"")); + assert!(completed_output.contains("Finished successfully")); + + let failed = execute_agent_with_spawn( + AgentInput { + description: "Fail the task".to_string(), + prompt: "Do the failing work".to_string(), + subagent_type: Some("Verification".to_string()), + name: Some("fail-task".to_string()), + model: None, + }, + |job| { + persist_agent_terminal_state( + &job.manifest, + "failed", + None, + Some(String::from("simulated failure")), + ) + }, + ) + .expect("failed agent should still spawn"); + + let failed_manifest = + std::fs::read_to_string(&failed.manifest_file).expect("failed manifest should exist"); + let failed_output = + std::fs::read_to_string(&failed.output_file).expect("failed output should exist"); + assert!(failed_manifest.contains("\"status\": \"failed\"")); + assert!(failed_manifest.contains("simulated failure")); + assert!(failed_output.contains("simulated failure")); + + let spawn_error = execute_agent_with_spawn( + AgentInput { + description: "Spawn error task".to_string(), + prompt: "Never starts".to_string(), + subagent_type: None, + name: Some("spawn-error".to_string()), + model: None, + }, + |_| Err(String::from("thread creation failed")), + ) + .expect_err("spawn errors should surface"); + assert!(spawn_error.contains("failed to spawn sub-agent")); + let spawn_error_manifest = std::fs::read_dir(&dir) + .expect("agent dir should exist") + .filter_map(Result::ok) + .map(|entry| entry.path()) + .filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json")) + .find_map(|path| { + let contents = std::fs::read_to_string(&path).ok()?; + contents + .contains("\"name\": \"spawn-error\"") + .then_some(contents) + }) + .expect("failed manifest should still be written"); + assert!(spawn_error_manifest.contains("\"status\": \"failed\"")); + assert!(spawn_error_manifest.contains("thread creation failed")); + + std::env::remove_var("CLAW_AGENT_STORE"); + let _ = std::fs::remove_dir_all(dir); + } + + #[test] + fn agent_tool_subset_mapping_is_expected() { + let general = allowed_tools_for_subagent("general-purpose"); + assert!(general.contains("bash")); + assert!(general.contains("write_file")); + assert!(!general.contains("Agent")); + + let explore = allowed_tools_for_subagent("Explore"); + assert!(explore.contains("read_file")); + assert!(explore.contains("grep_search")); + assert!(!explore.contains("bash")); + + let plan = allowed_tools_for_subagent("Plan"); + assert!(plan.contains("TodoWrite")); + assert!(plan.contains("StructuredOutput")); + assert!(!plan.contains("Agent")); + + let verification = allowed_tools_for_subagent("Verification"); + assert!(verification.contains("bash")); + assert!(verification.contains("PowerShell")); + assert!(!verification.contains("write_file")); + } + + #[derive(Debug)] + struct MockSubagentApiClient { + calls: usize, + input_path: String, + } + + impl runtime::ApiClient for MockSubagentApiClient { + fn stream(&mut self, request: ApiRequest) -> Result<Vec<AssistantEvent>, RuntimeError> { + self.calls += 1; + match self.calls { + 1 => { + assert_eq!(request.messages.len(), 1); + Ok(vec![ + AssistantEvent::ToolUse { + id: "tool-1".to_string(), + name: "read_file".to_string(), + input: json!({ "path": self.input_path }).to_string(), + }, + AssistantEvent::MessageStop, + ]) + } + 2 => { + assert!(request.messages.len() >= 3); + Ok(vec![ + AssistantEvent::TextDelta("Scope: completed mock review".to_string()), + AssistantEvent::MessageStop, + ]) + } + _ => panic!("unexpected mock stream call"), + } + } + } + + #[test] + fn subagent_runtime_executes_tool_loop_with_isolated_session() { + let _guard = env_lock() + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let path = temp_path("subagent-input.txt"); + std::fs::write(&path, "hello from child").expect("write input file"); + + let mut runtime = ConversationRuntime::new( + Session::new(), + MockSubagentApiClient { + calls: 0, + input_path: path.display().to_string(), + }, + SubagentToolExecutor::new(BTreeSet::from([String::from("read_file")])), + agent_permission_policy(), + vec![String::from("system prompt")], + ); + + let summary = runtime + .run_turn("Inspect the delegated file", None) + .expect("subagent loop should succeed"); + + assert_eq!( + final_assistant_text(&summary), + "Scope: completed mock review" + ); + assert!(runtime + .session() + .messages + .iter() + .flat_map(|message| message.blocks.iter()) + .any(|block| matches!( + block, + runtime::ContentBlock::ToolResult { output, .. } + if output.contains("hello from child") + ))); + + let _ = std::fs::remove_file(path); + } + #[test] fn agent_rejects_blank_required_fields() { let missing_description = execute_tool( @@ -3212,7 +4172,7 @@ mod tests { #[test] fn brief_returns_sent_message_and_attachment_metadata() { let attachment = std::env::temp_dir().join(format!( - "clawd-brief-{}.png", + "claw-brief-{}.png", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -3243,7 +4203,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let root = std::env::temp_dir().join(format!( - "clawd-config-{}", + "claw-config-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -3251,19 +4211,19 @@ mod tests { )); let home = root.join("home"); let cwd = root.join("cwd"); - std::fs::create_dir_all(home.join(".claude")).expect("home dir"); - std::fs::create_dir_all(cwd.join(".claude")).expect("cwd dir"); + std::fs::create_dir_all(home.join(".claw")).expect("home dir"); + std::fs::create_dir_all(cwd.join(".claw")).expect("cwd dir"); std::fs::write( - home.join(".claude").join("settings.json"), + home.join(".claw").join("settings.json"), r#"{"verbose":false}"#, ) .expect("write global settings"); let original_home = std::env::var("HOME").ok(); - let original_claude_home = std::env::var("CLAUDE_CONFIG_HOME").ok(); + let original_config_home = std::env::var("CLAW_CONFIG_HOME").ok(); let original_dir = std::env::current_dir().expect("cwd"); std::env::set_var("HOME", &home); - std::env::remove_var("CLAUDE_CONFIG_HOME"); + std::env::remove_var("CLAW_CONFIG_HOME"); std::env::set_current_dir(&cwd).expect("set cwd"); let get = execute_tool("Config", &json!({"setting": "verbose"})).expect("get config"); @@ -3296,9 +4256,9 @@ mod tests { Some(value) => std::env::set_var("HOME", value), None => std::env::remove_var("HOME"), } - match original_claude_home { - Some(value) => std::env::set_var("CLAUDE_CONFIG_HOME", value), - None => std::env::remove_var("CLAUDE_CONFIG_HOME"), + match original_config_home { + Some(value) => std::env::set_var("CLAW_CONFIG_HOME", value), + None => std::env::remove_var("CLAW_CONFIG_HOME"), } let _ = std::fs::remove_dir_all(root); } @@ -3332,7 +4292,7 @@ mod tests { .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); let dir = std::env::temp_dir().join(format!( - "clawd-pwsh-bin-{}", + "claw-pwsh-bin-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") @@ -3389,7 +4349,7 @@ printf 'pwsh:%s' "$1" .unwrap_or_else(std::sync::PoisonError::into_inner); let original_path = std::env::var("PATH").unwrap_or_default(); let empty_dir = std::env::temp_dir().join(format!( - "clawd-empty-bin-{}", + "claw-empty-bin-{}", std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .expect("time") diff --git a/rust/docs/releases/0.1.0.md b/rust/docs/releases/0.1.0.md new file mode 100644 index 0000000..5254475 --- /dev/null +++ b/rust/docs/releases/0.1.0.md @@ -0,0 +1,51 @@ +# Claw Code 0.1.0 发行说明(草案) + +## 摘要 + +Claw Code `0.1.0` 是当前 Rust 实现的第一个公开发布准备里程碑。Claw Code 的灵感来自 Claude Code,并作为一个净室(clean-room)Rust 实现构建;它不是直接的移植或复制。此版本专注于可用的本地 CLI 体验:交互式会话、非交互式提示词、工作区工具、配置加载、会话、插件以及本地代理/技能发现。 + +## 亮点 + +- Claw Code 的首个公开 `0.1.0` 发行候选版本 +- 作为当前主要产品界面的安全 Rust 实现 +- 用于交互式和单次编码代理工作流的 `claw` CLI +- 内置工作区工具:用于 shell、文件操作、搜索、网页获取/搜索、待办事项跟踪和笔记本更新 +- 斜杠命令界面:用于状态、压缩、配置检查、会话、差异/导出以及版本信息 +- 本地插件、代理和技能的发现/管理界面 +- OAuth 登录/注销以及模型/提供商选择 + +## 安装与运行 + +此版本目前旨在通过源码构建: + +```bash +cargo install --path crates/claw-cli --locked +# 或者 +cargo build --release -p claw-cli +``` + +运行: + +```bash +claw +claw prompt "总结此仓库" +``` + +## 已知限制 + +- 仅限源码构建分发;尚未发布打包好的发行构件 +- CI 目前覆盖 Ubuntu 和 macOS 的发布构建、检查和测试 +- Windows 的发布就绪性尚未建立 +- 部分集成覆盖是可选的,因为需要实时提供商凭据和网络访问 +- 公开接口可能会在 `0.x` 版本系列期间继续演进 + +## 推荐的发行定位 + +将 `0.1.0` 定位为 Claw Code 当前 Rust 实现的首个公开发布版本,面向习惯于从源码构建的早期采用者。功能表面已足够广泛以支持实际使用,而打包和发布自动化可以在后续版本中继续改进。 + +## 用于此草案的验证 + +- 通过 `Cargo.toml` 验证了工作区版本 +- 通过 `cargo metadata` 验证了 `claw` 二进制文件/包路径 +- 通过 `cargo run --quiet --bin claw -- --help` 验证了 CLI 命令表面 +- 通过 `.github/workflows/ci.yml` 验证了 CI 覆盖范围 diff --git a/src/__init__.py b/src/__init__.py index 2dc0c05..1360a1c 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,4 +1,4 @@ -"""Python porting workspace for the Claude Code rewrite effort.""" +"""Python porting workspace for the Claw Code rewrite effort.""" from .commands import PORTED_COMMANDS, build_command_backlog from .parity_audit import ParityAuditResult, run_parity_audit diff --git a/src/context.py b/src/context.py index 4cc59d7..e208fcd 100644 --- a/src/context.py +++ b/src/context.py @@ -21,7 +21,7 @@ def build_port_context(base: Path | None = None) -> PortContext: source_root = root / 'src' tests_root = root / 'tests' assets_root = root / 'assets' - archive_root = root / 'archive' / 'claude_code_ts_snapshot' / 'src' + archive_root = root / 'archive' / 'claw_code_ts_snapshot' / 'src' return PortContext( source_root=source_root, tests_root=tests_root, diff --git a/src/main.py b/src/main.py index e1fa9ed..9d74335 100644 --- a/src/main.py +++ b/src/main.py @@ -19,7 +19,7 @@ from .tools import execute_tool, get_tool, get_tools, render_tool_index def build_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description='Python porting workspace for the Claude Code rewrite effort') + parser = argparse.ArgumentParser(description='Python porting workspace for the Claw Code rewrite effort') subparsers = parser.add_subparsers(dest='command', required=True) subparsers.add_parser('summary', help='render a Markdown summary of the Python porting workspace') subparsers.add_parser('manifest', help='print the current Python workspace manifest') diff --git a/src/parity_audit.py b/src/parity_audit.py index 37b134c..39230d9 100644 --- a/src/parity_audit.py +++ b/src/parity_audit.py @@ -4,7 +4,7 @@ import json from dataclasses import dataclass from pathlib import Path -ARCHIVE_ROOT = Path(__file__).resolve().parent.parent / 'archive' / 'claude_code_ts_snapshot' / 'src' +ARCHIVE_ROOT = Path(__file__).resolve().parent.parent / 'archive' / 'claw_code_ts_snapshot' / 'src' CURRENT_ROOT = Path(__file__).resolve().parent REFERENCE_SURFACE_PATH = CURRENT_ROOT / 'reference_data' / 'archive_surface_snapshot.json' COMMAND_SNAPSHOT_PATH = CURRENT_ROOT / 'reference_data' / 'commands_snapshot.json' diff --git a/src/reference_data/archive_surface_snapshot.json b/src/reference_data/archive_surface_snapshot.json index aec56d6..0167acd 100644 --- a/src/reference_data/archive_surface_snapshot.json +++ b/src/reference_data/archive_surface_snapshot.json @@ -1,5 +1,5 @@ { - "archive_root": "archive/claude_code_ts_snapshot/src", + "archive_root": "archive/claw_code_ts_snapshot/src", "root_files": [ "QueryEngine.ts", "Task.ts", diff --git a/src/reference_data/commands_snapshot.json b/src/reference_data/commands_snapshot.json index 7177b2c..eb85fd5 100644 --- a/src/reference_data/commands_snapshot.json +++ b/src/reference_data/commands_snapshot.json @@ -330,9 +330,9 @@ "responsibility": "Command module mirrored from archived TypeScript path commands/files/index.ts" }, { - "name": "good-claude", - "source_hint": "commands/good-claude/index.js", - "responsibility": "Command module mirrored from archived TypeScript path commands/good-claude/index.js" + "name": "good-claw", + "source_hint": "commands/good-claw/index.js", + "responsibility": "Command module mirrored from archived TypeScript path commands/good-claw/index.js" }, { "name": "heapdump", diff --git a/src/reference_data/subsystems/components.json b/src/reference_data/subsystems/components.json index 329e882..510971a 100644 --- a/src/reference_data/subsystems/components.json +++ b/src/reference_data/subsystems/components.json @@ -15,9 +15,9 @@ "components/BridgeDialog.tsx", "components/BypassPermissionsModeDialog.tsx", "components/ChannelDowngradeDialog.tsx", - "components/ClaudeCodeHint/PluginHintMenu.tsx", - "components/ClaudeInChromeOnboarding.tsx", - "components/ClaudeMdExternalIncludesDialog.tsx", + "components/ClawCodeHint/PluginHintMenu.tsx", + "components/ClawInChromeOnboarding.tsx", + "components/ClawMdExternalIncludesDialog.tsx", "components/ClickableImageRef.tsx", "components/CompactSummary.tsx", "components/ConfigurableShortcutHint.tsx", diff --git a/src/reference_data/subsystems/services.json b/src/reference_data/subsystems/services.json index 554beb4..9f506ee 100644 --- a/src/reference_data/subsystems/services.json +++ b/src/reference_data/subsystems/services.json @@ -22,7 +22,7 @@ "services/analytics/sinkKillswitch.ts", "services/api/adminRequests.ts", "services/api/bootstrap.ts", - "services/api/claude.ts", + "services/api/claw.ts", "services/api/client.ts", "services/api/dumpPrompts.ts", "services/api/emptyUsage.ts", diff --git a/src/reference_data/subsystems/skills.json b/src/reference_data/subsystems/skills.json index 70ab672..5b323b1 100644 --- a/src/reference_data/subsystems/skills.json +++ b/src/reference_data/subsystems/skills.json @@ -4,9 +4,9 @@ "module_count": 20, "sample_files": [ "skills/bundled/batch.ts", - "skills/bundled/claudeApi.ts", - "skills/bundled/claudeApiContent.ts", - "skills/bundled/claudeInChrome.ts", + "skills/bundled/clawApi.ts", + "skills/bundled/clawApiContent.ts", + "skills/bundled/clawInChrome.ts", "skills/bundled/debug.ts", "skills/bundled/index.ts", "skills/bundled/keybindings.ts", diff --git a/src/reference_data/subsystems/types.json b/src/reference_data/subsystems/types.json index 0e35390..31d2e40 100644 --- a/src/reference_data/subsystems/types.json +++ b/src/reference_data/subsystems/types.json @@ -4,7 +4,7 @@ "module_count": 11, "sample_files": [ "types/command.ts", - "types/generated/events_mono/claude_code/v1/claude_code_internal_event.ts", + "types/generated/events_mono/claw_code/v1/claw_code_internal_event.ts", "types/generated/events_mono/common/v1/auth.ts", "types/generated/events_mono/growthbook/v1/growthbook_experiment_event.ts", "types/generated/google/protobuf/timestamp.ts", diff --git a/src/reference_data/tools_snapshot.json b/src/reference_data/tools_snapshot.json index cb3a293..3d4ac5f 100644 --- a/src/reference_data/tools_snapshot.json +++ b/src/reference_data/tools_snapshot.json @@ -35,9 +35,9 @@ "responsibility": "Tool module mirrored from archived TypeScript path tools/AgentTool/agentToolUtils.ts" }, { - "name": "claudeCodeGuideAgent", - "source_hint": "tools/AgentTool/built-in/claudeCodeGuideAgent.ts", - "responsibility": "Tool module mirrored from archived TypeScript path tools/AgentTool/built-in/claudeCodeGuideAgent.ts" + "name": "clawCodeGuideAgent", + "source_hint": "tools/AgentTool/built-in/clawCodeGuideAgent.ts", + "responsibility": "Tool module mirrored from archived TypeScript path tools/AgentTool/built-in/clawCodeGuideAgent.ts" }, { "name": "exploreAgent",