#!/usr/bin/env python3
"""
|
||
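Deep-check prompt generator.

Given a module name, automatically build the file list and the review prompt
needed for the Claude Phase 2 deep check.

Usage:
    python3 deep_check_prompt.py ODFHYS           # print the list of files to check
    python3 deep_check_prompt.py ODFHYS --prompt  # print the full check prompt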
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import sys
|
||
import argparse
|
||
from typing import List, Dict, Optional
|
||
|
||
# Path configuration: absolute locations of the extracted Fortran sources,
# the Rust source tree, and the Fortran COMMON include files.
EXTRACTED_DIR = "/home/fmq/program/tlusty/tl208-s54/rust/tlusty/extracted"
RUST_BASE_DIR = "/home/fmq/.zeroclaw/workspace/SpectraRust/src"
FORTRAN_COMMON_DIR = "/home/fmq/program/tlusty/tl208-s54/tlusty"

# Import support for common_db: put this script's own directory at the front
# of sys.path (only if it is not already listed).
script_dir = os.path.dirname(os.path.abspath(__file__))
if script_dir not in sys.path:
    sys.path.insert(0, script_dir)
|
||
|
||
from common_db import (
|
||
get_includes_for_module,
|
||
get_commons_for_module,
|
||
get_vars_for_module,
|
||
get_rust_structs_for_module,
|
||
get_mapping,
|
||
get_structs,
|
||
get_blocks,
|
||
)
|
||
|
||
|
||
def find_rust_file(module_name: str) -> Optional[str]:
    """Locate the Rust source file for a module.

    The lower-cased module name plus ``.rs`` is probed under
    ``RUST_BASE_DIR/tlusty`` in a fixed priority order: ``io/``, ``math/``,
    every known ``math/<subdir>/``, and finally ``state/``.

    Returns:
        The first candidate path that exists, or ``None`` if none exists.
    """
    file_name = f"{module_name.lower()}.rs"

    math_subdirs = [
        'ali', 'atomic', 'continuum', 'convection', 'eos', 'hydrogen',
        'interpolation', 'io', 'odf', 'opacity', 'partition', 'population',
        'radiative', 'rates', 'solvers', 'special', 'temperature', 'utils'
    ]

    # Directory components (relative to RUST_BASE_DIR) in lookup order.
    search_order = [('tlusty', 'io'), ('tlusty', 'math')]
    search_order.extend(('tlusty', 'math', sub) for sub in math_subdirs)
    search_order.append(('tlusty', 'state'))

    for parts in search_order:
        candidate = os.path.join(RUST_BASE_DIR, *parts, file_name)
        if os.path.exists(candidate):
            return candidate

    return None
|
||
|
||
|
||
def find_rust_use_imports(rust_file: str) -> List[str]:
    """Collect the state-module files that a Rust source file imports.

    The file is scanned for ``use`` declarations rooted at ``super::``,
    ``crate::tlusty::state::`` or ``super::super::state::``. Both the plain
    form (``use super::foo``) and the brace-grouped form
    (``use crate::tlusty::state::{foo::A, bar::B}``) are recognised. The
    first path segment of each import is treated as a module name and is
    kept only if ``RUST_BASE_DIR/tlusty/state/<name>.rs`` exists.

    Args:
        rust_file: Path of the Rust source file to scan.

    Returns:
        Sorted list of existing state-file paths, without duplicates. An
        empty list is returned when ``rust_file`` does not exist.
    """
    try:
        # Undecodable bytes are dropped rather than aborting the scan.
        with open(rust_file, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
    except FileNotFoundError:
        return []

    prefixes = (
        r'super::',
        r'crate::tlusty::state::',
        r'super::super::state::',
    )

    state_files = set()
    for prefix in prefixes:
        # After the prefix comes either a module name or the opening brace
        # of a grouped import.
        for m in re.finditer(r'use\s+' + prefix + r'(\w+|\{)', content):
            token = m.group(1)
            if token == '{':
                mod_names = _leading_segments_of_use_group(content, m.end())
            else:
                mod_names = [token]

            for mod_name in mod_names:
                # Only names backed by an actual state/<name>.rs file count.
                state_file = os.path.join(
                    RUST_BASE_DIR, 'tlusty', 'state', f"{mod_name}.rs"
                )
                if os.path.exists(state_file):
                    state_files.add(state_file)

    return sorted(state_files)


def _leading_segments_of_use_group(content: str, start: int) -> List[str]:
    """Return the first path segment of each top-level item in a use-group.

    ``start`` is the index just past the opening ``{``. Nested groups are
    skipped by tracking brace depth. If the group is never closed, only the
    items terminated by a top-level comma are returned.
    """
    items = []
    depth = 1
    item_start = start
    for pos in range(start, len(content)):
        ch = content[pos]
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                items.append(content[item_start:pos])
                break
        elif ch == ',' and depth == 1:
            items.append(content[item_start:pos])
            item_start = pos + 1

    names = []
    for item in items:
        head = re.match(r'\s*(\w+)', item)
        if head:
            names.append(head.group(1))
    return names
|
||
|
||
|
||
def generate_file_list(module_name: str) -> Dict[str, str]:
    """Build the mapping of files needed for a module's deep check.

    Keys, in insertion order:
      * ``fortran_source`` - ``<module>.f`` under ``EXTRACTED_DIR``
      * ``rust_source``    - result of ``find_rust_file``
      * ``common_<inc>``   - one ``<INC>.FOR`` file per include reported by
        ``get_includes_for_module``
      * ``rust_state_<i>`` - state files found by ``find_rust_use_imports``
        (only when a Rust source was found)

    A value is either a path or a parenthesised "not found" marker string.
    """

    def _existing_or_marker(path: str) -> str:
        # Keep the path when it exists, otherwise emit the not-found marker.
        return path if os.path.exists(path) else f"(未找到: {path})"

    files: Dict[str, str] = {}

    # 1. Fortran source file
    files['fortran_source'] = _existing_or_marker(
        os.path.join(EXTRACTED_DIR, f"{module_name.lower()}.f")
    )

    # 2. Rust source file
    rust_file = find_rust_file(module_name)
    files['rust_source'] = rust_file if rust_file else "(未找到)"

    # 3. COMMON definition files pulled in via INCLUDE
    for inc in get_includes_for_module(module_name.upper()):
        files[f"common_{inc.lower()}"] = _existing_or_marker(
            os.path.join(FORTRAN_COMMON_DIR, f"{inc}.FOR")
        )

    # 4. Rust state struct files (imported through `use`)
    if rust_file:
        for index, state_path in enumerate(find_rust_use_imports(rust_file)):
            files[f"rust_state_{index}"] = state_path

    return files
|
||
|
||
|
||
def generate_prompt(module_name: str) -> str:
    """Render the full Phase 2 check prompt for a module as Markdown text.

    The prompt has four sections: the files to read, the COMMON variable
    mapping (grouped by COMMON block inside a fenced code block), a fixed
    checklist, and instructions for handling findings.

    Returns:
        The prompt as one newline-joined string.
    """
    files = generate_file_list(module_name)
    var_map = get_mapping()
    # NOTE(review): the struct table is fetched but never used below. The
    # call is retained in case it has a loading side effect - TODO confirm
    # whether it can be removed.
    structs = get_structs()

    # COMMON variables used by this module
    module_vars = get_vars_for_module(module_name.upper(), var_map)

    out = [
        f"# Phase 2 深度语义检查: {module_name.upper()}",
        "",
        "## 需要读取的文件",
        "",
    ]

    # Real paths are wrapped in backticks; not-found markers are left bare.
    out.extend(
        f"- {path}" if path.startswith("(未找到") else f"- `{path}`"
        for path in files.values()
    )

    out += ["", "## COMMON 变量映射", "", "```"]

    # Group variables by their COMMON block
    grouped: Dict[str, List] = {}
    for var in module_vars.values():
        grouped.setdefault(var.common_block, []).append(var)

    for block_name, members in sorted(grouped.items()):
        out.append(f"COMMON /{block_name}/")
        for var in sorted(members, key=lambda v: v.name):
            if var.dims:
                dims_str = f"({', '.join(var.dims)})"
            else:
                dims_str = ""
            if var.rust_field:
                rust_str = f"{var.rust_struct}.{var.rust_field}"
            else:
                rust_str = "(未映射)"
            out.append(f" {var.name:20s} {dims_str:20s} → {rust_str}")
        # Blank separator line after each block
        out.append("")

    out += [
        "```",
        "",
        "## 检查清单",
        "",
        "逐项检查以下内容:",
        "",
    ]

    out += [
        "[ ] COMMON 变量 → 正确的 Rust struct 字段",
        "[ ] 2D 数组下标顺序(Fortran 列主序 → Rust 行主序)",
        "[ ] 1-based → 0-based 索引一致性",
        "[ ] 循环边界转换(DO I=1,N → for i in 0..n)",
        "[ ] IF 条件完整保留(<= vs <, >= vs >)",
        "[ ] 所有赋值目标存在(无遗漏)",
        "[ ] CALL 顺序和数量一致",
        "[ ] 类型转换正确(INTEGER→i32, REAL*8→f64, LOGICAL→bool)",
    ]

    out += [
        "",
        "## 发现问题处理",
        "",
        "发现 bug → 立即修复 → cargo build 验证 → 继续检查",
        "无 bug → 输出 '深度检查通过'",
    ]

    return "\n".join(out)
|
||
|
||
|
||
def main():
    """Command-line entry point.

    Parses a required module name and the optional ``--prompt`` / ``--files``
    flags. ``--prompt`` prints the full check prompt; ``--files`` prints a
    bare key/path listing; with neither flag a decorated file listing, the
    COMMON-variable mapping count and a usage hint are printed.
    ``--prompt`` takes precedence when both flags are given.
    """
    parser = argparse.ArgumentParser(description='Phase 2 深度检查提示生成器')
    parser.add_argument('module', help='模块名')
    parser.add_argument('--prompt', action='store_true', help='生成完整检查提示')
    parser.add_argument('--files', action='store_true', help='只列出文件')
    args = parser.parse_args()

    if args.prompt:
        print(generate_prompt(args.module))
        return

    files = generate_file_list(args.module)

    if args.files:
        for key, path in files.items():
            print(f" {key:20s} {path}")
        return

    # Default mode: decorated file listing
    print(f"模块 {args.module.upper()} 深度检查文件列表:")
    print()
    for key, path in files.items():
        icon = "❓" if path.startswith("(未找到") else "📄"
        print(f" {icon} {key:20s} {path}")

    # Also report how many COMMON variables have a Rust mapping
    var_map = get_mapping()
    module_vars = get_vars_for_module(args.module.upper(), var_map)
    mapped = len([v for v in module_vars.values() if v.rust_field])
    print(f"\n COMMON 变量: {mapped}/{len(module_vars)} 已映射")

    # Point the user at --prompt for the full check prompt
    print(f"\n 生成完整检查提示: python3 deep_check_prompt.py {args.module} --prompt")


if __name__ == "__main__":
    main()
|