SpectraRust/.claude/skills/f2r-check/scripts/deep_check_prompt.py
2026-04-01 16:35:36 +08:00

242 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
深度检查提示生成器
根据模块名自动生成 Claude Phase 2 深度检查所需的文件列表和检查提示。
用法:
python3 deep_check_prompt.py ODFHYS # 生成检查文件列表
python3 deep_check_prompt.py ODFHYS --prompt # 生成完整检查提示
"""
import os
import re
import sys
import argparse
from typing import List, Dict, Optional
# Path configuration.
# Each default below is a machine-specific absolute path. Setting the matching
# environment variable (to a non-empty value) points the script at another
# checkout without editing this file; when it is unset the default is used.

# Directory holding the extracted per-module Fortran sources (<module>.f).
EXTRACTED_DIR = (
    os.environ.get("F2R_EXTRACTED_DIR")
    or "/home/fmq/program/tlusty/tl208-s54/rust/tlusty/extracted"
)
# Root of the Rust sources searched for ported modules (expects tlusty/...).
RUST_BASE_DIR = (
    os.environ.get("F2R_RUST_BASE_DIR")
    or "/home/fmq/.zeroclaw/workspace/SpectraRust/src"
)
# Directory holding the Fortran INCLUDE files (<NAME>.FOR).
FORTRAN_COMMON_DIR = (
    os.environ.get("F2R_FORTRAN_COMMON_DIR")
    or "/home/fmq/program/tlusty/tl208-s54/tlusty"
)

# Make the sibling ``common_db`` module importable no matter which directory
# the script is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.insert(0, script_dir)
from common_db import (
get_includes_for_module,
get_commons_for_module,
get_vars_for_module,
get_rust_structs_for_module,
get_mapping,
get_structs,
get_blocks,
)
def find_rust_file(module_name: str) -> Optional[str]:
    """Locate the Rust source file that ports *module_name*.

    The expected file name is the lower-cased module name plus ``.rs``.
    Directories under ``RUST_BASE_DIR/tlusty`` are probed in this fixed
    priority order and the first existing file wins:

    1. ``io/``
    2. ``math/``
    3. each ``math/<category>/`` in the order listed below
    4. ``state/``

    Returns:
        Path of the first match, or ``None`` when no candidate exists.
    """
    filename = f"{module_name.lower()}.rs"
    tlusty_root = os.path.join(RUST_BASE_DIR, 'tlusty')
    math_root = os.path.join(tlusty_root, 'math')
    math_categories = (
        'ali', 'atomic', 'continuum', 'convection', 'eos', 'hydrogen',
        'interpolation', 'io', 'odf', 'opacity', 'partition', 'population',
        'radiative', 'rates', 'solvers', 'special', 'temperature', 'utils',
    )

    search_dirs = [os.path.join(tlusty_root, 'io'), math_root]
    search_dirs.extend(os.path.join(math_root, category) for category in math_categories)
    search_dirs.append(os.path.join(tlusty_root, 'state'))

    for directory in search_dirs:
        candidate = os.path.join(directory, filename)
        if os.path.exists(candidate):
            return candidate
    return None
def find_rust_use_imports(rust_file: str) -> List[str]:
"""从 Rust 文件中提取 use 引用的 state 文件"""
state_files = set()
if not os.path.exists(rust_file):
return []
with open(rust_file, 'r', encoding='utf-8', errors='ignore') as f:
content = f.read()
# 匹配 use super::xxx 或 use crate::tlusty::state::xxx
patterns = [
r'use\s+super::(\w+)',
r'use\s+crate::tlusty::state::(\w+)',
r'use\s+super::super::state::(\w+)',
]
for pattern in patterns:
for m in re.finditer(pattern, content):
mod_name = m.group(1)
# 查找对应的 .rs 文件
state_file = os.path.join(RUST_BASE_DIR, 'tlusty', 'state', f"{mod_name}.rs")
if os.path.exists(state_file):
state_files.add(state_file)
return sorted(state_files)
def generate_file_list(module_name: str) -> Dict[str, str]:
    """Collect the files needed for a Phase 2 deep check of *module_name*.

    Entries of the returned dict, in insertion order:

    * ``fortran_source``: ``EXTRACTED_DIR/<module>.f`` (lower-cased name)
    * ``rust_source``: the result of ``find_rust_file``
    * ``common_<inc>``: ``FORTRAN_COMMON_DIR/<INC>.FOR`` for each include
      reported by ``get_includes_for_module``
    * ``rust_state_<i>``: state files found through ``use`` imports in the
      Rust source (only when a Rust source was found)

    A missing file is reported as a placeholder string that starts with
    ``"(未找到"`` instead of a path.
    """
    def located(path: str) -> str:
        # The path itself when it exists, otherwise the "not found" placeholder.
        return path if os.path.exists(path) else f"(未找到: {path})"

    lower_name = module_name.lower()
    result: Dict[str, str] = {
        'fortran_source': located(os.path.join(EXTRACTED_DIR, f"{lower_name}.f")),
    }

    rust_path = find_rust_file(module_name)
    result['rust_source'] = rust_path if rust_path else "(未找到)"

    for include in get_includes_for_module(module_name.upper()):
        result[f"common_{include.lower()}"] = located(
            os.path.join(FORTRAN_COMMON_DIR, f"{include}.FOR")
        )

    if rust_path:
        result.update(
            (f"rust_state_{idx}", state_path)
            for idx, state_path in enumerate(find_rust_use_imports(rust_path))
        )
    return result
def generate_prompt(module_name: str) -> str:
    """Build the full Markdown prompt for a Phase 2 deep semantic check.

    The prompt has four sections:

    1. Files to read, taken from ``generate_file_list``. Existing paths are
       wrapped in backticks; "not found" placeholders are emitted as-is.
    2. The module's COMMON variables, grouped by COMMON block. Blocks and
       variables are sorted by name. Each variable shows its mapped Rust
       ``struct.field``, or ``(未映射)`` when no Rust field is recorded.
    3. A fixed review checklist.
    4. Instructions for handling findings.

    Args:
        module_name: Fortran module name (case-insensitive).

    Returns:
        The prompt text, with lines joined by newlines.
    """
    name_upper = module_name.upper()
    files = generate_file_list(module_name)
    module_vars = get_vars_for_module(name_upper, get_mapping())

    lines: List[str] = [f"# Phase 2 深度语义检查: {name_upper}", "", "## 需要读取的文件", ""]
    for path in files.values():
        # Placeholders for missing files are not real paths, so no code span.
        lines.append(f"- {path}" if path.startswith("(未找到") else f"- `{path}`")
    lines += ["", "## COMMON 变量映射", "", "```"]

    # Group the variables by COMMON block.
    vars_by_block: Dict[str, List] = {}
    for var in module_vars.values():
        vars_by_block.setdefault(var.common_block, []).append(var)

    for block_name, block_vars in sorted(vars_by_block.items()):
        lines.append(f"COMMON /{block_name}/")
        for var in sorted(block_vars, key=lambda v: v.name):
            # str() guards against dimension entries that are not strings.
            dims_str = f"({', '.join(map(str, var.dims))})" if var.dims else ""
            rust_str = f"{var.rust_struct}.{var.rust_field}" if var.rust_field else "(未映射)"
            # The explicit space keeps the columns apart even when dims_str is
            # 20 characters or longer.
            lines.append(f" {var.name:20s} {dims_str:20s} {rust_str}")
        lines.append("")

    lines += ["```", "", "## 检查清单", "", "逐项检查以下内容:", ""]
    lines += [
        "[ ] COMMON 变量 → 正确的 Rust struct 字段",
        "[ ] 2D 数组下标顺序(Fortran 列主序 → Rust 行主序)",
        "[ ] 1-based → 0-based 索引一致性",
        "[ ] 循环边界转换(DO I=1,N → for i in 0..n)",
        "[ ] IF 条件完整保留(<= vs <, >= vs >)",
        "[ ] 所有赋值目标存在(无遗漏)",
        "[ ] CALL 顺序和数量一致",
        "[ ] 类型转换正确(INTEGER→i32, REAL*8→f64, LOGICAL→bool)",
    ]
    lines += [
        "",
        "## 发现问题处理",
        "",
        "发现 bug → 立即修复 → cargo build 验证 → 继续检查",
        "无 bug → 输出 '深度检查通过'",
    ]
    return "\n".join(lines)
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point.

    Modes (``--prompt`` takes precedence over ``--files``):

    * ``--prompt``: print the full Phase 2 check prompt.
    * ``--files``: print ``key path`` pairs from ``generate_file_list``.
    * Default: print the file list with a found/missing marker per entry,
      then the count of COMMON variables that have a Rust field mapping, and
      a hint about ``--prompt``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            keeps the standard command-line behaviour.
    """
    parser = argparse.ArgumentParser(description='Phase 2 深度检查提示生成器')
    parser.add_argument('module', help='模块名')
    parser.add_argument('--prompt', action='store_true', help='生成完整检查提示')
    parser.add_argument('--files', action='store_true', help='只列出文件')
    args = parser.parse_args(argv)
    module_upper = args.module.upper()

    if args.prompt:
        print(generate_prompt(args.module))
        return

    files = generate_file_list(args.module)
    if args.files:
        for key, path in files.items():
            print(f" {key:20s} {path}")
        return

    # Default: file list with markers.
    print(f"模块 {module_upper} 深度检查文件列表:")
    print()
    for key, path in files.items():
        # Missing entries get their own marker so they stand out and stay
        # column-aligned with the found ones. Previously they got an empty string.
        icon = "❌" if path.startswith("(未找到") else "📄"
        print(f" {icon} {key:20s} {path}")

    # Also report how many COMMON variables are mapped to a Rust field.
    module_vars = get_vars_for_module(module_upper, get_mapping())
    mapped = sum(1 for v in module_vars.values() if v.rust_field)
    print(f"\n COMMON 变量: {mapped}/{len(module_vars)} 已映射")

    # Point the user at --prompt for the full check prompt.
    print(f"\n 生成完整检查提示: python3 deep_check_prompt.py {args.module} --prompt")


if __name__ == "__main__":
    main()