SpectraRust/.claude/skills/f2r-check/scripts/next_module.py
2026-04-01 16:35:36 +08:00

570 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
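f2r_next - pick the next module that needs checking or fixing
Recommends the module to check next, based on the dependency graph and each module's current status.
Strategy:
1. Fix base modules that many other modules depend on first
2. Trace downward from top-level modules (e.g. TLUSTY, START)
3. Skip modules that already match completely
Usage:
python3 next_module.py # recommend the next module
python3 next_module.py --path START # trace starting from START
python3 next_module.py --chain TLUSTY # show the full call chain
python3 next_module.py --priority # show the fix-priority list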
"""
import os
import re
import sys
import argparse
import glob
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import List, Dict, Set, Optional, Tuple
# Import the status-detection function from f2r_check. If the plain import
# fails, put this script's own directory on sys.path and try once more.
# USE_F2R_CHECK records whether check_module is available.
try:
    from f2r_check import check_module
except ImportError:
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if script_dir not in sys.path:
        sys.path.insert(0, script_dir)
    try:
        from f2r_check import check_module
    except ImportError:
        # Neither attempt worked: callers fall back to the simplified
        # status detection.
        USE_F2R_CHECK = False
        print("警告: 无法导入 f2r_check将使用简化状态检测", file=sys.stderr)
    else:
        USE_F2R_CHECK = True
else:
    USE_F2R_CHECK = True
# ============================================================================
# Path configuration
# ============================================================================
# Directory that build_dependency_graph() scans for "*.f" Fortran sources.
# NOTE(review): hard-coded absolute path; must exist on the machine running
# this script.
EXTRACTED_DIR = "/home/fmq/program/tlusty/tl208-s54/rust/tlusty/extracted"
# Root of the Rust source tree that find_rust_module() searches; also the
# prefix print_next_module() shortens to "src" when printing.
RUST_BASE_DIR = "/home/fmq/.zeroclaw/workspace/SpectraRust/src"
# ============================================================================
# 数据结构
# ============================================================================
@dataclass
class ModuleInfo:
    """Record describing one module in the Fortran-to-Rust dependency graph.

    Attributes:
        name: Module (routine) name used as the graph key.
        fortran_file: Base name of the Fortran source file.
        rust_file: Path of the matching Rust file, or "" when none was found.
        status: One of "match", "partial", "mismatch" or "missing".
        calls: Names of the routines this module calls.
        called_by: Names of the modules that call this one.
        depth: Dependency depth of the module.
        trans_pending: Count of transitive dependencies not yet implemented.
        is_stub: True when the Rust side looks like a simplified/stub version.
    """

    name: str
    fortran_file: str = ""
    rust_file: str = ""
    status: str = "missing"
    # Mutable defaults go through default_factory so instances never share
    # the same list object.
    calls: List[str] = field(default_factory=list)
    called_by: List[str] = field(default_factory=list)
    depth: int = 0
    trans_pending: int = 0
    is_stub: bool = False
# ============================================================================
# Fortran 解析
# ============================================================================
# Upper-case names that extract_calls() discards from its result: Fortran
# intrinsic functions together with statement keywords, so that none of them
# is reported as a called module.
FORTRAN_INTRINSICS = {
    'SIN', 'COS', 'TAN', 'ASIN', 'ACOS', 'ATAN', 'ATAN2',
    'SINH', 'COSH', 'TANH', 'EXP', 'LOG', 'LOG10', 'LOG2',
    'SQRT', 'ABS', 'MOD', 'SIGN', 'MAX', 'MIN', 'MAX0', 'MIN0',
    'INT', 'IFIX', 'IDINT', 'FLOAT', 'SNGL', 'DBLE', 'CMPLX',
    'REAL', 'AIMAG', 'CONJG', 'ICHAR', 'CHAR', 'INDEX', 'LEN',
    'IF', 'THEN', 'ELSE', 'ENDIF', 'END', 'DO', 'CONTINUE',
    'RETURN', 'STOP', 'PAUSE', 'GOTO', 'CALL', 'SUBROUTINE',
    'FUNCTION', 'PROGRAM', 'MODULE', 'USE', 'IMPLICIT',
    'PARAMETER', 'DATA', 'DIMENSION', 'COMMON', 'SAVE',
    'EXTERNAL', 'INTRINSIC', 'READ', 'WRITE', 'OPEN', 'CLOSE',
    'FORMAT', 'PRINT', 'ERF', 'ERFC', 'GAMMA',
}
def strip_comments(content: str) -> str:
    """Return *content* without empty lines and full-line Fortran comments.

    A line is treated as a comment when its first character is ``C``/``c``,
    ``!`` or ``*``. Trailing (inline) comments are left untouched, and lines
    containing only whitespace are kept.
    """
    comment_markers = ('C', '!', '*')
    kept = [
        text
        for text in content.split('\n')
        if text and text[0].upper() not in comment_markers
    ]
    return '\n'.join(kept)
def extract_calls(content: str) -> List[str]:
    """Return the names of the routines invoked with CALL in *content*.

    Full-line comments are removed first (see ``strip_comments``). A CALL is
    recognised when the routine name is followed by ``(``, by the end of the
    line, or by a trailing ``!`` comment.

    Args:
        content: Fortran source text.

    Returns:
        Upper-cased, de-duplicated routine names in sorted order, excluding
        anything listed in ``FORTRAN_INTRINSICS``. Sorting keeps the result
        (and everything derived from it) stable between runs, which a plain
        ``list(set(...))`` does not guarantee for strings.
    """
    code_content = strip_comments(content)
    # re.MULTILINE lets `$` match at every line end, which covers both the
    # "end of text" and "newline" alternatives; `!` additionally accepts a
    # trailing inline comment such as `CALL FOO ! note`.
    names = re.findall(
        r'(?i)CALL\s+(\w+)\s*(?:\(|!|$)', code_content, re.MULTILINE
    )
    return sorted({name.upper() for name in names} - FORTRAN_INTRINSICS)
def extract_subroutine_name(content: str) -> Optional[str]:
    """Return the upper-cased name of the program unit found in *content*.

    The text is searched for a SUBROUTINE statement first, then for a PROGRAM
    statement, then for a BLOCK DATA statement. All three searches are
    case-insensitive and accept any amount of leading indentation.

    Args:
        content: Fortran source text.

    Returns:
        The unit name in upper case, ``"_UNNAMED_BLOCK_DATA_"`` for a BLOCK
        DATA statement without a name, or ``None`` when none of the three
        statements is present.
    """
    for keyword in ('SUBROUTINE', 'PROGRAM'):
        match = re.search(rf'(?i)^\s*{keyword}\s+(\w+)', content, re.MULTILINE)
        if match:
            return match.group(1).upper()
    # BLOCK DATA: only spaces/tabs are allowed around the optional name so the
    # match cannot spill onto the following line and pick up a one-word
    # statement (e.g. SAVE) as the block name. `\r?` tolerates CRLF files.
    match = re.search(
        r'(?i)^[ \t]*BLOCK[ \t]+DATA[ \t]*([A-Za-z0-9_]*)[ \t]*\r?$',
        content,
        re.MULTILINE,
    )
    if match:
        block_name = match.group(1).strip()
        return block_name.upper() if block_name else "_UNNAMED_BLOCK_DATA_"
    return None
# ============================================================================
# Rust 检查
# ============================================================================
# Rust module file stem -> Fortran routine names that find_rust_module()
# should also look for in that Rust file (in addition to a file named after
# the routine itself). Matching against the routine names is case-insensitive.
SPECIAL_MAPPINGS = {
    'gfree': ['gfree0', 'gfreed', 'gfree1'],
    'interpolate': ['yint', 'lagran'],
    'sgmer': ['sgmer0', 'sgmer1', 'sgmerd'],
    'ctdata': ['hction', 'hctrecom'],
    'cross': ['cross', 'crossd'],
    'expint': ['eint', 'expinx'],
    'erfcx': ['erfcx', 'erfcin'],
    'lineqs': ['lineqs', 'lineqs_nr'],
    'gamsp': ['gamsp'],
    'bhe': ['bhe', 'bhed', 'bhez'],
    'comset': ['comset'],
    'ghydop': ['ghydop'],
    'levgrp': ['levgrp'],
    'profil': ['profil'],
    'linspl': ['linspl'],
    'convec': ['convec', 'convc1'],
}
def find_rust_module(fortran_name: str) -> Tuple[str, bool]:
    """Locate the Rust file that corresponds to a Fortran routine.

    Candidate paths are tried in a fixed order: the main-program files (only
    for TLUSTY), ``tlusty/io``, ``tlusty/math`` and its sub-directories,
    ``tlusty/state``, the files named in ``SPECIAL_MAPPINGS``, and finally
    ``tlusty/data.rs`` for an unnamed BLOCK DATA unit. The first path that
    exists wins.

    Args:
        fortran_name: Routine name; matching is case-insensitive.

    Returns:
        ``(path, is_stub)`` for the first existing candidate, where
        ``is_stub`` is the result of ``check_main_function_stub`` on that
        file, or ``("", False)`` when no candidate exists.
    """
    rust_name = fortran_name.lower()
    upper_name = fortran_name.upper()
    math_subdirs = (
        'ali', 'atomic', 'continuum', 'convection', 'eos', 'hydrogen',
        'interpolation', 'io', 'odf', 'opacity', 'partition', 'population',
        'radiative', 'rates', 'solvers', 'special', 'temperature', 'utils',
    )
    tlusty_root = os.path.join(RUST_BASE_DIR, 'tlusty')
    math_root = os.path.join(tlusty_root, 'math')

    def math_candidates(stem: str) -> List[str]:
        # tlusty/math/<stem>.rs followed by tlusty/math/<subdir>/<stem>.rs
        file_name = f"{stem}.rs"
        nested = [os.path.join(math_root, sub, file_name) for sub in math_subdirs]
        return [os.path.join(math_root, file_name)] + nested

    search_paths: List[str] = []
    # Main program
    if upper_name == 'TLUSTY':
        search_paths += [
            os.path.join(RUST_BASE_DIR, 'bin', 'tlusty.rs'),
            os.path.join(tlusty_root, 'main.rs'),
        ]
    # tlusty/io/, then tlusty/math/ (+ sub-directories), then tlusty/state/
    search_paths.append(os.path.join(tlusty_root, 'io', f"{rust_name}.rs"))
    search_paths += math_candidates(rust_name)
    search_paths.append(os.path.join(tlusty_root, 'state', f"{rust_name}.rs"))
    # Special mappings: several Fortran routines living in one Rust module
    for rust_mod, fortran_funcs in SPECIAL_MAPPINGS.items():
        if any(rust_name == func.lower() for func in fortran_funcs):
            search_paths += math_candidates(rust_mod)
    # Unnamed BLOCK DATA is looked up in data.rs
    if upper_name == '_UNNAMED_BLOCK_DATA_':
        search_paths.append(os.path.join(tlusty_root, 'data.rs'))

    found = next((p for p in search_paths if os.path.exists(p)), None)
    if found is None:
        return "", False
    with open(found, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()
    # Only the main function body is inspected for stub markers, not the
    # whole file (unless the function cannot be located).
    return found, check_main_function_stub(content, rust_name)
def check_main_function_stub(content: str, func_name: str) -> bool:
    """Tell whether the main Rust function looks like a stub.

    The function is located by trying, in order, ``pub fn <name>(``,
    ``pub fn <name>_pure(`` and ``fn <name>(`` (case-insensitive, optional
    generic parameters). Only the brace-delimited body of the first signature
    that matches is scanned for stub markers; helper functions elsewhere in
    the file are ignored. When no signature matches, or no balanced body can
    be extracted, the whole of *content* is scanned instead.

    Args:
        content: Rust source text.
        func_name: Function name, interpolated into the search patterns.

    Returns:
        True when a stub marker is found in the scanned text.
    """
    signature_patterns = (
        rf'pub\s+fn\s+{func_name}\s*(?:<[^>]+>)?\s*\(',
        rf'pub\s+fn\s+{func_name}_pure\s*(?:<[^>]+>)?\s*\(',
        rf'fn\s+{func_name}\s*(?:<[^>]+>)?\s*\(',
    )
    search_flags = re.IGNORECASE | re.DOTALL
    hit = None
    for signature in signature_patterns:
        hit = re.search(signature, content, search_flags)
        if hit:
            break

    body = ""
    if hit is not None:
        # Walk forward from the end of the signature match, tracking brace
        # nesting; the body runs from the outermost `{` to its matching `}`.
        nesting = 0
        body_open = hit.end()
        for pos in range(hit.end(), len(content)):
            ch = content[pos]
            if ch == '{':
                if nesting == 0:
                    body_open = pos
                nesting += 1
            elif ch == '}':
                nesting -= 1
                if nesting == 0:
                    body = content[body_open:pos + 1]
                    break

    # Fall back to the whole file when the main function was not isolated.
    scanned = body if body else content
    stub_markers = (
        r'//\s*简化实现',
        r'//\s*TODO:',
        r'//\s*待实现',
        r'框架就绪',
        r'unimplemented!',
        r'todo!',
    )
    return any(re.search(marker, scanned, re.IGNORECASE) for marker in stub_markers)
# ============================================================================
# 依赖分析
# ============================================================================
def build_dependency_graph() -> Dict[str, ModuleInfo]:
    """Build the module dependency graph from the extracted Fortran sources.

    Every ``*.f`` file in ``EXTRACTED_DIR`` becomes one ``ModuleInfo``. Files
    are processed in sorted order: ``glob.glob`` itself returns them in
    arbitrary order, which would make ``called_by`` ordering and
    tie-breaking in the recommendation lists differ between runs.

    Returns:
        Mapping of module name to ``ModuleInfo`` with ``calls``,
        ``called_by``, ``status``, ``depth`` and ``trans_pending`` filled in.
    """
    modules: Dict[str, ModuleInfo] = {}

    # Pass 1: collect every module.
    for fpath in sorted(glob.glob(os.path.join(EXTRACTED_DIR, "*.f"))):
        with open(fpath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        name = extract_subroutine_name(content)
        if not name:
            # No SUBROUTINE/PROGRAM/BLOCK DATA statement found: fall back to
            # the file stem.
            name = os.path.splitext(os.path.basename(fpath))[0].upper()
        calls = extract_calls(content)
        rust_file, is_stub = find_rust_module(name)
        if USE_F2R_CHECK and rust_file:
            # Detailed status detection from f2r_check.
            result = check_module(name, verbose=False)
            status = result.status
            if result.issues:
                # The reported issues override the locally detected stub flag.
                is_stub = any(
                    '简化版本' in issue or '占位符' in issue
                    for issue in result.issues
                )
        elif not rust_file:
            # Simplified status detection from here on.
            status = "missing"
        elif is_stub:
            status = "partial"
        else:
            status = "match"
        modules[name] = ModuleInfo(
            name=name,
            fortran_file=os.path.basename(fpath),
            rust_file=rust_file,
            status=status,
            calls=calls,
            is_stub=is_stub,
        )

    # Pass 2: reverse edges (who calls whom).
    for name, info in modules.items():
        for call in info.calls:
            if call in modules:
                modules[call].called_by.append(name)

    def calc_depth(name: str, visited: Set[str]) -> int:
        """Length of the longest call path below *name*.

        *visited* holds the modules on the current path; each branch gets
        its own copy, so a module already on the path (a cycle) counts as 0
        without blocking other branches. Unknown modules and modules without
        calls also count as 0.
        """
        if name in visited:
            return 0
        if name not in modules:
            return 0
        visited.add(name)
        calls = modules[name].calls
        if not calls:
            return 0
        max_dep = 0
        for call in calls:
            if call != name:
                max_dep = max(max_dep, calc_depth(call, visited.copy()))
        return max_dep + 1

    for name in modules:
        modules[name].depth = calc_depth(name, set())

    def calc_trans_pending(name: str, visited: Set[str]) -> int:
        """Count callees of *name* that are unknown or not yet "match".

        Recursion continues only through callees whose status is not
        "match". Because every branch works on its own copy of *visited*, a
        pending module reachable along several paths is counted once per
        path.
        """
        if name in visited:
            return 0
        if name not in modules:
            return 1  # module without a Fortran file in the graph
        visited.add(name)
        count = 0
        for call in modules[name].calls:
            if call not in modules:
                count += 1
            elif modules[call].status != "match":
                count += 1 + calc_trans_pending(call, visited.copy())
        return count

    for name in modules:
        modules[name].trans_pending = calc_trans_pending(name, set())

    return modules
# ============================================================================
# 推荐逻辑
# ============================================================================
def find_next_module(modules: Dict[str, ModuleInfo], start_from: Optional[str] = None) -> List[Tuple[str, int, str, int]]:
    """Pick the modules that should be checked or fixed next.

    Args:
        modules: Dependency graph keyed by module name.
        start_from: Optional module name (case-insensitive). When it names a
            known module, only its dependencies are considered; otherwise a
            global ranking is produced.

    Returns:
        A list of ``(name, level, status, called_count)`` tuples.

        * Trace mode (``start_from`` known): breadth-first walk through the
          dependencies, descending only through modules that are neither
          "partial", "mismatch" nor "missing". Every other callee is reported
          once, at the level where it is first reached, with status "missing"
          for callees that are not in *modules*. At most 10 entries, ordered
          by level, then "missing" first, then most callers first.
        * Global mode: every module whose status is not "match", ordered by
          fewest transitive pending dependencies, then most callers, then
          smallest depth. At most 20 entries; ``level`` is always 0.
    """
    if start_from and start_from.upper() in modules:
        start = modules[start_from.upper()]
        queue = deque([(start.name, 0)])
        visited = set()
        # Names already emitted as candidates. Without it a pending module
        # called by several matched parents would be listed once per parent.
        reported = set()
        candidates = []
        while queue:
            name, level = queue.popleft()
            if name in visited:
                continue
            visited.add(name)
            if name not in modules:
                continue
            for call in modules[name].calls:
                if call in visited or call in reported:
                    continue
                if call not in modules:
                    # No Fortran file known for this callee.
                    reported.add(call)
                    candidates.append((call, level + 1, "missing", 0))
                    continue
                status = modules[call].status
                if status in ("partial", "mismatch"):
                    reported.add(call)
                    candidates.append(
                        (call, level + 1, status, len(modules[call].called_by))
                    )
                elif status == "missing":
                    reported.add(call)
                    candidates.append((call, level + 1, "missing", 0))
                else:
                    # Already matching: keep descending through it.
                    queue.append((call, level + 1))
        candidates.sort(key=lambda x: (x[1], 0 if x[2] == "missing" else 1, -x[3]))
        return candidates[:10]

    # Global recommendation: few transitive pending deps, then many callers,
    # then small depth.
    ranked = [
        (name, info.status, info.trans_pending, len(info.called_by), info.depth)
        for name, info in modules.items()
        if info.status != "match"
    ]
    ranked.sort(key=lambda x: (x[2], -x[3], x[4]))
    return [(c[0], 0, c[1], c[3]) for c in ranked[:20]]
def get_call_chain(modules: Dict[str, ModuleInfo], start: str, end: Optional[str] = None) -> List[str]:
    """Return one call chain beginning at *start*.

    Args:
        modules: Dependency graph keyed by (upper-case) module name.
        start: Name of the first module; upper-cased before use.
        end: Optional target module. It is upper-cased like *start*, since a
            target in a different case could never equal a visited name.

    Returns:
        Without *end*: the path obtained by repeatedly following the first
        callee that has not been visited yet, stopping at a module with no
        such callee or at a name missing from *modules*.
        With *end*: a depth-first path from *start* to *end*, or ``[]`` when
        *end* is not reachable.
    """
    target = end.upper() if end else None
    chain: List[str] = []
    visited = set()

    def dfs(name: str, prefix: List[str]) -> bool:
        if name in visited:
            return False
        visited.add(name)
        path = prefix + [name]
        if target and name == target:
            chain.extend(path)
            return True
        if name not in modules:
            # Unknown module: a dead end when searching for a target,
            # otherwise the end of the chain.
            if target:
                return False
            chain.extend(path)
            return True
        for call in modules[name].calls:
            if dfs(call, path):
                return True
        if target:
            return False
        # No unvisited callee left: the chain stops here.
        chain.extend(path)
        return True

    dfs(start.upper(), [])
    return chain
# ============================================================================
# 输出格式
# ============================================================================
def print_next_module(modules: Dict[str, ModuleInfo], candidates: List[Tuple]):
"""打印推荐的下一个模块"""
print("=" * 70)
print("📋 下一个需要检查的模块")
print("=" * 70)
if not candidates:
print("✅ 所有模块都已匹配!")
return
for i, (name, level, status, called_count) in enumerate(candidates[:10], 1):
if name in modules:
info = modules[name]
status_icon = {"match": "", "partial": "⚠️", "mismatch": "", "missing": ""}.get(status, "")
print(f"\n{i}. {status_icon} {name}")
print(f" 状态: {status}")
print(f" Fortran: {info.fortran_file}")
if info.rust_file:
rust_rel = info.rust_file.replace(RUST_BASE_DIR, "src")
print(f" Rust: {rust_rel}")
else:
print(f" Rust: 未实现")
print(f" 被调用: {called_count}")
if info.trans_pending > 0:
print(f" 传递未实现依赖: {info.trans_pending}")
# 显示被谁调用
if info.called_by:
callers = info.called_by[:5]
print(f" 调用者: {', '.join(callers)}")
if len(info.called_by) > 5:
print(f" ... 还有 {len(info.called_by) - 5}")
else:
# 模块未实现
print(f"\n{i}. ❓ {name}")
print(f" 状态: missing")
print(f" Fortran: {name.lower()}.f")
print(f" Rust: 未实现")
print("\n" + "-" * 70)
print("建议:")
print(" 1. 先检查模块的 Fortran 源码")
print(" 2. 运行: python3 f2r_check.py --diff <模块名>")
print(" 3. 按照 Fortran 逻辑修复 Rust 实现")
def print_call_chain(modules: Dict[str, ModuleInfo], start: str):
    """Print the call chain starting at *start* as an indented list.

    The chain comes from ``get_call_chain``. At most 50 entries are printed;
    the indentation grows by one step per entry and is capped at 5. Names
    that are not in *modules* are flagged as not implemented.
    """
    banner = "=" * 70
    print(banner)
    print(f"🔗 调用链: {start}")
    print(banner)

    chain = get_call_chain(modules, start)
    icons = {"match": "", "partial": "⚠️", "mismatch": "", "missing": ""}
    indent = 0
    for name in chain[:50]:
        if name in modules:
            status_icon = icons.get(modules[name].status, "")
            print(f"{' ' * indent}{status_icon} {name}")
        else:
            print(f"{' ' * indent}{name} (未实现)")
        indent = min(indent + 1, 5)

    hidden = len(chain) - 50
    if hidden > 0:
        print(f"{' ' * indent}... 还有 {hidden} 个模块")
def print_priority_list(modules: Dict[str, ModuleInfo]):
    """Print a table of the modules that still need work.

    Every module whose status is not "match" is listed, ordered by fewest
    transitive pending dependencies and then by most callers. At most 50
    rows are printed.
    """
    print("=" * 70)
    print("📊 修复优先级列表")
    print("=" * 70)
    print(f"{'排名':<4} {'模块':<15} {'状态':<10} {'被调用':<8} {'传递未实现':<10}")
    print("-" * 70)

    # (name, status, number of callers, transitive pending count)
    pending_rows = [
        (name, info.status, len(info.called_by), info.trans_pending)
        for name, info in modules.items()
        if info.status != "match"
    ]
    ordered = sorted(pending_rows, key=lambda row: (row[3], -row[2]))

    icons = {"match": "", "partial": "⚠️", "mismatch": "", "missing": ""}
    for rank, (name, status, called, pending) in enumerate(ordered[:50], 1):
        status_icon = icons.get(status, "")
        print(f"{rank:<4} {name:<15} {status_icon} {status:<8} {called:<8} {pending:<10}")
# ============================================================================
# 主函数
# ============================================================================
def main():
    """Command-line entry point.

    Builds the dependency graph, then runs exactly one mode: ``--chain``
    prints a call chain, ``--priority`` prints the fix-priority table, and
    otherwise the next modules are recommended (optionally traced from
    ``--path``). ``--chain`` takes precedence over ``--priority``.
    """
    parser = argparse.ArgumentParser(description='推荐下一个需要检查的模块')
    parser.add_argument('--path', metavar='MODULE', help='从指定模块开始追踪')
    parser.add_argument('--chain', metavar='MODULE', help='显示调用链')
    parser.add_argument('--priority', action='store_true', help='显示修复优先级列表')
    options = parser.parse_args()

    graph = build_dependency_graph()

    if options.chain:
        print_call_chain(graph, options.chain)
        return
    if options.priority:
        print_priority_list(graph)
        return
    # Default mode: recommend the next modules.
    print_next_module(graph, find_next_module(graph, options.path))


if __name__ == "__main__":
    main()