SpectraRust/.claude/skills/f2r-check/scripts/common_db.py
2026-04-01 16:35:36 +08:00

492 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
COMMON 变量映射数据库
解析 Fortran COMMON 块定义和 Rust struct 字段,构建完整的变量映射关系。
核心功能:
- parse_all_commons() — 解析 Fortran COMMON 定义
- parse_rust_structs() — 解析 Rust struct 字段
- build_mapping() — 交叉引用生成完整映射
- get_vars_for_module(module_name) — 返回某模块用到的所有 COMMON 变量
"""
import os
import re
import sys
from typing import Dict, List, Optional, Tuple, Set
from dataclasses import dataclass, field
# ============================================================================
# 路径配置
# ============================================================================
# Directory that holds the Fortran COMMON definition files listed in COMMON_FILES.
FORTRAN_COMMON_DIR = "/home/fmq/program/tlusty/tl208-s54/tlusty"
# Directory scanned for the Rust ``*.rs`` state struct sources.
RUST_STATE_DIR = "/home/fmq/.zeroclaw/workspace/SpectraRust/src/tlusty/state"
# Directory that holds the per-module extracted Fortran sources (``<module>.f``).
EXTRACTED_DIR = "/home/fmq/program/tlusty/tl208-s54/rust/tlusty/extracted"
# Fortran COMMON definition files, parsed in this order.
COMMON_FILES = [
    "BASICS.FOR",
    "ATOMIC.FOR",
    "MODELQ.FOR",
    "ARRAY1.FOR",
    "ITERAT.FOR",
    "ALIPAR.FOR",
    "ODFPAR.FOR",
]
# ============================================================================
# 数据结构
# ============================================================================
@dataclass
class CommonVar:
    """A single variable declared in a Fortran COMMON block.

    The ``rust_*`` attributes start out as ``None`` and are filled in by
    ``build_mapping`` when a matching Rust struct / field is found.
    """
    # Fortran variable name (upper case)
    name: str
    # Name of the COMMON block the variable belongs to
    common_block: str
    # Dimension expressions, e.g. ['MTRANS']; empty for scalars
    dims: List[str] = field(default_factory=list)
    # Matching Rust field name, if any
    rust_field: Optional[str] = None
    # Name of the Rust struct mapped to the COMMON block, if any
    rust_struct: Optional[str] = None
    # Path of the Rust file that defines that struct, if any
    rust_file: Optional[str] = None
    # True when the variable has two or more dimensions
    is_2d: bool = False
    # Dimension text exactly as written in the source, e.g. "3,MHOD"
    fortran_dims_raw: str = ""
@dataclass
class CommonBlock:
    """A Fortran COMMON block together with the variables declared in it.

    ``rust_struct`` / ``rust_file`` stay ``None`` until ``build_mapping``
    finds a Rust struct associated with this block.
    """
    # COMMON block name
    name: str
    # File the block definition was read from
    file: str
    # Variables declared in the block, in declaration order
    variables: List[CommonVar] = field(default_factory=list)
    # Name of the Rust struct mapped to this block, if any
    rust_struct: Optional[str] = None
    # Path of the Rust file that defines that struct, if any
    rust_file: Optional[str] = None
@dataclass
class RustStruct:
    """Summary of one Rust struct extracted from a state source file."""
    # Struct identifier
    name: str
    # Path of the Rust file that defines the struct
    file: str
    # Name of the COMMON block the struct corresponds to, if known
    common_name: Optional[str] = None
    # Maps each field name to its type text
    fields: Dict[str, str] = field(default_factory=dict)
# ============================================================================
# Fortran COMMON 解析
# ============================================================================
def _strip_inline_comment(line: str) -> str:
    """Remove a trailing ``!`` comment from one source line.

    A ``!`` inside a quoted string is kept, and so is a ``!`` in column 6
    (index 5), where it acts as a continuation marker rather than a comment.
    The line is returned unchanged when it carries no inline comment;
    otherwise the text before the ``!`` is returned with trailing blanks
    removed.
    """
    quote = None
    for idx, ch in enumerate(line):
        if quote is not None:
            if ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == '!' and idx != 5:
            return line[:idx].rstrip()
    return line


def _join_continuation_lines(content: str) -> str:
    """Merge fixed-form Fortran continuation lines into single logical lines.

    Empty lines and comment lines (``C``, ``!`` or ``*`` in column 1) are
    dropped.  Trailing ``!`` comments are stripped before joining; without
    that, a comment on a continued statement would end up in the middle of
    the joined line and hide every continuation that follows it.  Lines that
    consist only of an indented ``!`` comment are dropped as well.

    Args:
        content: Full text of a Fortran source file.

    Returns:
        The remaining lines joined with newlines, each continuation appended
        to its statement separated by a single space.
    """
    joined: List[str] = []
    for line in content.split('\n'):
        if not line:
            continue
        # Skip comment lines flagged in column 1.
        if line[0].upper() in ('C', '!', '*'):
            continue
        code = _strip_inline_comment(line)
        if code != line and not code.strip():
            # Nothing but an indented comment was on this line.
            continue
        # A character other than blank or '0' in column 6 marks a continuation.
        if joined and len(code) >= 6 and code[5] not in (' ', '0', '\n'):
            # Drop the first six columns and append to the previous line.
            joined[-1] = joined[-1].rstrip() + ' ' + code[6:].strip()
        else:
            joined.append(code)
    return '\n'.join(joined)
def parse_common_block(content: str, filename: str) -> List[CommonBlock]:
    """Parse every COMMON block declared in one Fortran source text.

    Several COMMON statements may contribute to the same block; their
    variable lists are concatenated in order of appearance.

    Args:
        content: Full text of the Fortran file.
        filename: Name stored in each resulting ``CommonBlock.file``.

    Returns:
        One ``CommonBlock`` per distinct block name, in order of first
        appearance.
    """
    joined = _join_continuation_lines(content)
    # Continuation lines are already merged, so each COMMON statement sits on
    # one line and its variable list is simply the rest of that line.  No
    # assumption is made about which statement follows, so a COMMON followed
    # by e.g. DIMENSION or EQUIVALENCE is no longer skipped.
    pattern = r'COMMON\s*/\s*(\w+)\s*/\s*(.+)'
    # Variable-list fragments collected per block name.
    block_vars: Dict[str, List[str]] = {}
    for match in re.finditer(pattern, joined, re.IGNORECASE):
        block_name = match.group(1).upper()
        vars_str = match.group(2).strip()
        # Cut off a trailing Fortran comment.
        if '!' in vars_str:
            vars_str = vars_str[:vars_str.index('!')].strip()
        block_vars.setdefault(block_name, []).append(vars_str)
    return [
        CommonBlock(
            name=block_name,
            file=filename,
            variables=_parse_var_list(','.join(var_lists), block_name),
        )
        for block_name, var_lists in block_vars.items()
    ]
def _parse_var_list(vars_str: str, block_name: str) -> List[CommonVar]:
    """Convert a comma-separated COMMON variable list into ``CommonVar`` records.

    Entries of the form ``NAME(DIMS)`` become array variables; plain names
    become scalars.  Anything else is ignored.

    Args:
        vars_str: Variable list text taken from a COMMON statement.
        block_name: COMMON block name stored on every created variable.

    Returns:
        The variables in the order they appear in ``vars_str``.
    """
    result: List[CommonVar] = []
    # Commas inside parentheses separate dimensions, not variables.
    for raw in _split_respecting_parens(vars_str):
        token = raw.strip()
        if not token:
            continue
        array_match = re.match(r'^(\w+)\(([^)]+)\)$', token, re.IGNORECASE)
        if array_match is None:
            scalar_name = token.upper()
            # Discard tokens that are not plain identifiers.
            if re.match(r'^[A-Z]\w*$', scalar_name):
                result.append(CommonVar(
                    name=scalar_name,
                    common_block=block_name,
                ))
            continue
        raw_dims = array_match.group(2)
        dim_list = [piece.strip().upper() for piece in raw_dims.split(',')]
        result.append(CommonVar(
            name=array_match.group(1).upper(),
            common_block=block_name,
            dims=dim_list,
            is_2d=len(dim_list) >= 2,
            fortran_dims_raw=raw_dims,
        ))
    return result
def _split_respecting_parens(s: str) -> List[str]:
    """Split ``s`` on commas that are not enclosed in parentheses.

    Pieces are returned verbatim (no stripping).  An empty piece after the
    final comma is omitted; empty pieces elsewhere are kept.
    """
    pieces: List[str] = []
    nesting = 0
    piece_start = 0
    for pos, ch in enumerate(s):
        if ch == '(':
            nesting += 1
        elif ch == ')':
            nesting -= 1
        elif ch == ',' and nesting == 0:
            pieces.append(s[piece_start:pos])
            piece_start = pos + 1
    tail = s[piece_start:]
    if tail:
        pieces.append(tail)
    return pieces
def parse_all_commons() -> Dict[str, CommonBlock]:
    """Parse every file in ``COMMON_FILES`` under ``FORTRAN_COMMON_DIR``.

    Files that do not exist are skipped.  When a block name shows up in more
    than one file, the later variables are appended to the block created for
    the first occurrence.

    Returns:
        Mapping of COMMON block name to its ``CommonBlock``.
    """
    merged: Dict[str, CommonBlock] = {}
    for common_file in COMMON_FILES:
        source_path = os.path.join(FORTRAN_COMMON_DIR, common_file)
        if not os.path.exists(source_path):
            continue
        with open(source_path, 'r', encoding='utf-8', errors='ignore') as handle:
            text = handle.read()
        for parsed in parse_common_block(text, common_file):
            existing = merged.get(parsed.name)
            if existing is None:
                merged[parsed.name] = parsed
            else:
                # Same block seen before: extend it with the extra variables.
                existing.variables.extend(parsed.variables)
    return merged
# ============================================================================
# Rust Struct 解析
# ============================================================================
def parse_rust_structs() -> List[RustStruct]:
    """Collect the Rust structs in ``RUST_STATE_DIR`` that are tagged with a COMMON block.

    Only structs preceded by a ``/// 对应 COMMON /NAME/`` doc comment are
    picked up.  ``*.rs`` files are visited in sorted name order.

    Returns:
        One ``RustStruct`` per tagged struct; an empty list when the
        directory does not exist.
    """
    collected: List[RustStruct] = []
    if not os.path.isdir(RUST_STATE_DIR):
        return collected
    # Tag comment, then any attribute / comment / blank lines, then the
    # ``pub struct Name {`` header.
    struct_header = (
        r'///\s*对应\s*COMMON\s*/\s*(\w+)\s*/\s*\n'
        r'(?:(?:\s*#[^\n]*\n|\s*///?[^\n]*\n|\s*\n))*'
        r'\s*pub\s+struct\s+(\w+)\s*\{'
    )
    field_decl = r'pub\s+(\w+)\s*:\s*([^,\n]+)'
    rust_sources = sorted(
        entry for entry in os.listdir(RUST_STATE_DIR) if entry.endswith('.rs')
    )
    for source_name in rust_sources:
        source_path = os.path.join(RUST_STATE_DIR, source_name)
        with open(source_path, 'r', encoding='utf-8', errors='ignore') as handle:
            text = handle.read()
        for header in re.finditer(struct_header, text, re.IGNORECASE):
            # The header match ends just past '{'; take everything up to the
            # matching '}'.
            body = _extract_braced_body(text, header.end())
            collected.append(RustStruct(
                name=header.group(2),
                file=source_path,
                common_name=header.group(1).upper(),
                fields={
                    decl.group(1): decl.group(2).strip()
                    for decl in re.finditer(field_decl, body)
                },
            ))
    return collected
def _extract_braced_body(content: str, start: int) -> str:
    """Return the text between an opening brace and its matching closing brace.

    ``start`` is the index just after the opening ``{``.  Nested braces are
    balanced.  When no matching ``}`` exists, everything from ``start`` to
    the end of ``content`` is returned.
    """
    open_count = 1
    for offset in range(start, len(content)):
        ch = content[offset]
        if ch == '{':
            open_count += 1
        elif ch == '}':
            open_count -= 1
            if open_count == 0:
                return content[start:offset]
    return content[start:]
# ============================================================================
# 映射构建
# ============================================================================
def _fortran_to_rust_name(fortran_name: str) -> str:
    """Return the Rust field name for a Fortran variable: the name lower-cased."""
    rust_name = fortran_name.lower()
    return rust_name
def build_mapping(
    common_blocks: Dict[str, CommonBlock],
    rust_structs: List[RustStruct]
) -> Dict[str, CommonVar]:
    """Cross-reference Fortran COMMON blocks with Rust structs.

    Every variable of every block is annotated in place with the Rust struct
    mapped to its COMMON block (``rust_struct``, ``rust_file``) and, when the
    struct has a field named like the lower-cased variable, ``rust_field``.
    Each block gets ``rust_struct`` / ``rust_file`` set the same way.

    All variables reachable through ``block.variables`` are annotated, not
    only those that end up in the returned dict.  The dict is keyed by bare
    variable name, so same-named variables collide there and the last one
    wins; annotating from the blocks keeps per-block statistics correct for
    the shadowed ones.

    Args:
        common_blocks: Mapping of block name to ``CommonBlock``.
        rust_structs: Parsed Rust structs; those without ``common_name`` are
            ignored.

    Returns:
        {FORTRAN_VAR_NAME: CommonVar}, holding the same objects as the blocks.
    """
    # COMMON block name -> Rust struct (a later struct replaces an earlier one).
    struct_by_common: Dict[str, RustStruct] = {
        rs.common_name.upper(): rs for rs in rust_structs if rs.common_name
    }
    var_map: Dict[str, CommonVar] = {}
    for block_name, block in common_blocks.items():
        block_struct = struct_by_common.get(block_name)
        if block_struct is not None:
            block.rust_struct = block_struct.name
            block.rust_file = block_struct.file
        for var in block.variables:
            var_map[var.name] = var
            rs = struct_by_common.get(var.common_block)
            if rs is None:
                continue
            var.rust_struct = rs.name
            var.rust_file = rs.file
            rust_field_name = _fortran_to_rust_name(var.name)
            if rust_field_name in rs.fields:
                var.rust_field = rust_field_name
    return var_map
# ============================================================================
# 模块级查询
# ============================================================================
def get_includes_for_module(module_name: str) -> List[str]:
    """List the ``.FOR`` files INCLUDEd by a Fortran module.

    The module source is read from ``EXTRACTED_DIR/<module_name lower-cased>.f``.

    Returns:
        Upper-cased include names without the ``.FOR`` suffix, in source
        order, excluding ``IMPLIC``; an empty list if the source file is
        missing.
    """
    source_path = os.path.join(EXTRACTED_DIR, f"{module_name.lower()}.f")
    if not os.path.exists(source_path):
        return []
    with open(source_path, 'r', encoding='utf-8', errors='ignore') as handle:
        text = handle.read()
    result: List[str] = []
    for stem in re.findall(r"INCLUDE\s*'([^']+)\.FOR'", text, re.IGNORECASE):
        upper_stem = stem.upper()
        if upper_stem != 'IMPLIC':
            result.append(upper_stem)
    return result
def get_commons_for_module(module_name: str) -> List[str]:
    """Return the names of the COMMON blocks a Fortran module uses.

    The blocks are those declared in the files the module INCLUDEs (see
    ``get_includes_for_module``), read from ``FORTRAN_COMMON_DIR``.  Include
    files that do not exist are skipped.

    Returns:
        Sorted, upper-cased, de-duplicated block names.
    """
    commons: Set[str] = set()
    for inc in get_includes_for_module(module_name):
        fpath = os.path.join(FORTRAN_COMMON_DIR, f"{inc}.FOR")
        if not os.path.exists(fpath):
            continue
        with open(fpath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        # Allow blanks inside the slashes ("COMMON / NAME /") so the same
        # declarations are recognised here as in parse_common_block.
        block_names = re.findall(r'(?i)COMMON\s*/\s*(\w+)\s*/', content)
        commons.update(name.upper() for name in block_names)
    return sorted(commons)
def get_vars_for_module(
    module_name: str,
    var_map: Dict[str, CommonVar]
) -> Dict[str, CommonVar]:
    """Select the COMMON variables that a module has access to.

    Args:
        module_name: Fortran module name.
        var_map: Mapping returned by ``build_mapping()``.

    Returns:
        {VAR_NAME: CommonVar} restricted to variables whose COMMON block is
        among the blocks reported by ``get_commons_for_module``.
    """
    wanted_blocks = get_commons_for_module(module_name)
    return {
        name: entry
        for name, entry in var_map.items()
        if entry.common_block in wanted_blocks
    }
def get_rust_structs_for_module(
    module_name: str,
    rust_structs: List[RustStruct]
) -> List[str]:
    """Return the Rust struct file paths a module needs to ``use``.

    A struct qualifies when its ``common_name`` (upper-cased) is one of the
    COMMON blocks reported by ``get_commons_for_module``.

    Returns:
        Sorted, de-duplicated file paths.
    """
    wanted_blocks = get_commons_for_module(module_name)
    paths = {
        rs.file
        for rs in rust_structs
        if rs.common_name and rs.common_name.upper() in wanted_blocks
    }
    return sorted(paths)
# ============================================================================
# 缓存单例
# ============================================================================
# Module-level caches, populated together on the first get_mapping() call.
_cached_blocks = None   # result of parse_all_commons()
_cached_structs = None  # result of parse_rust_structs()
_cached_mapping = None  # result of build_mapping()
def get_mapping():
    """Return the cached variable mapping, building all three caches on first use."""
    global _cached_mapping, _cached_structs, _cached_blocks
    if _cached_mapping is not None:
        return _cached_mapping
    _cached_blocks = parse_all_commons()
    _cached_structs = parse_rust_structs()
    _cached_mapping = build_mapping(_cached_blocks, _cached_structs)
    return _cached_mapping
def get_structs():
    """Return the cached Rust struct list, triggering get_mapping() if it is unset."""
    global _cached_structs
    needs_load = _cached_structs is None
    if needs_load:
        # get_mapping() fills the struct cache as part of building the mapping.
        get_mapping()
    return _cached_structs
def get_blocks():
    """Return the cached COMMON blocks, triggering get_mapping() if they are unset."""
    global _cached_blocks
    needs_load = _cached_blocks is None
    if needs_load:
        # get_mapping() fills the block cache as part of building the mapping.
        get_mapping()
    return _cached_blocks
# ============================================================================
# CLI
# ============================================================================
def main():
    """Command-line entry point.

    Exactly one report is printed, chosen by the first option present in
    this order: ``--module``, ``--block``, ``--unmapped``, ``--mapping``.
    With no option, summary counts are printed.
    """
    import argparse

    cli = argparse.ArgumentParser(description='COMMON 变量映射数据库')
    cli.add_argument('--module', help='显示某模块使用的 COMMON 变量')
    cli.add_argument('--block', help='显示某 COMMON 块的变量')
    cli.add_argument('--mapping', action='store_true', help='显示完整映射')
    cli.add_argument('--unmapped', action='store_true', help='显示未映射的变量')
    opts = cli.parse_args()

    var_map = get_mapping()
    blocks = get_blocks()
    structs = get_structs()

    def dims_label(entry):
        # "(D1, D2)" for arrays, empty text for scalars.
        return f"({', '.join(entry.dims)})" if entry.dims else ""

    if opts.module:
        module_key = opts.module.upper()
        module_vars = get_vars_for_module(module_key, var_map)
        print(f"模块 {module_key} 使用的 COMMON 变量:")
        print(f" 总计: {len(module_vars)} 个变量")
        for vname in sorted(module_vars):
            entry = module_vars[vname]
            if entry.rust_field:
                target = f"{entry.rust_struct}.{entry.rust_field}"
            else:
                target = "→ (未映射)"
            print(f" {vname:20s} {dims_label(entry):20s} {target}")
    elif opts.block:
        chosen = blocks.get(opts.block.upper())
        if chosen is None:
            print(f"COMMON 块 {opts.block} 未找到")
            return
        print(f"COMMON /{chosen.name}/ (文件: {chosen.file})")
        for entry in chosen.variables:
            target = f"{entry.rust_field}" if entry.rust_field else "→ (未映射)"
            print(f" {entry.name:20s} {dims_label(entry):20s} {target}")
    elif opts.unmapped:
        unmapped = {
            name: entry for name, entry in var_map.items() if not entry.rust_field
        }
        print(f"未映射的 COMMON 变量: {len(unmapped)} / {len(var_map)}")
        for vname in sorted(unmapped):
            entry = unmapped[vname]
            print(f" /{entry.common_block}/ {vname:20s} {dims_label(entry)}")
    elif opts.mapping:
        print("COMMON 变量映射统计:")
        mapped_total = len([v for v in var_map.values() if v.rust_field])
        print(f" 总变量: {len(var_map)}")
        print(f" 已映射: {mapped_total}")
        print(f" 未映射: {len(var_map) - mapped_total}")
        print()
        print("COMMON 块:")
        for bname in sorted(blocks):
            current = blocks[bname]
            block_mapped = len([v for v in current.variables if v.rust_field])
            print(f" /{bname}/ → {current.rust_struct or '(无)'} ({block_mapped}/{len(current.variables)})")
    else:
        # Default: overall statistics.
        print("COMMON 变量映射数据库")
        print(f" COMMON 块: {len(blocks)}")
        print(f" COMMON 变量: {len(var_map)}")
        print(f" Rust struct: {len(structs)}")
        mapped_total = len([v for v in var_map.values() if v.rust_field])
        print(f" 已映射: {mapped_total}/{len(var_map)}")


if __name__ == "__main__":
    main()