492 lines
17 KiB
Python
492 lines
17 KiB
Python
#!/usr/bin/env python3
"""
COMMON variable mapping database.

Parses Fortran COMMON block definitions and Rust struct fields, and builds the
complete mapping between Fortran COMMON variables and Rust struct fields.

Core functions:
- parse_all_commons() — parse the Fortran COMMON definitions
- parse_rust_structs() — parse the Rust struct fields
- build_mapping() — cross-reference both sides into the full mapping
- get_vars_for_module(module_name, var_map) — return every COMMON variable used by a module
"""
|
||
|
||
import os
|
||
import re
|
||
import sys
|
||
from typing import Dict, List, Optional, Tuple, Set
|
||
from dataclasses import dataclass, field
|
||
|
||
# ============================================================================
# Path configuration
# ============================================================================

# Each directory can be overridden through an environment variable so the tool
# also works outside the original author's machine; the defaults are unchanged.
FORTRAN_COMMON_DIR = os.environ.get(
    "TLUSTY_FORTRAN_COMMON_DIR",
    "/home/fmq/program/tlusty/tl208-s54/tlusty",
)
RUST_STATE_DIR = os.environ.get(
    "SPECTRARUST_STATE_DIR",
    "/home/fmq/.zeroclaw/workspace/SpectraRust/src/tlusty/state",
)
EXTRACTED_DIR = os.environ.get(
    "TLUSTY_EXTRACTED_DIR",
    "/home/fmq/program/tlusty/tl208-s54/rust/tlusty/extracted",
)

# Fortran files (inside FORTRAN_COMMON_DIR) that hold the COMMON definitions
COMMON_FILES = [
    "BASICS.FOR", "ATOMIC.FOR", "MODELQ.FOR", "ARRAY1.FOR",
    "ITERAT.FOR", "ALIPAR.FOR", "ODFPAR.FOR",
]
|
||
|
||
# ============================================================================
|
||
# 数据结构
|
||
# ============================================================================
|
||
|
||
@dataclass
class CommonVar:
    """One variable declared inside a Fortran COMMON block.

    The ``rust_*`` attributes start out as ``None`` and are filled in by
    build_mapping() when a matching Rust struct / field is found.
    """

    # Fortran variable name, upper-cased
    name: str
    # Name of the COMMON block that declares this variable
    common_block: str
    # Upper-cased dimension expressions, e.g. ['MTRANS']; empty for scalars
    dims: List[str] = field(default_factory=list)
    # Matching Rust field name, when one exists
    rust_field: Optional[str] = None
    # Name of the Rust struct mirroring the COMMON block
    rust_struct: Optional[str] = None
    # Path of the Rust source file defining that struct
    rust_file: Optional[str] = None
    # True when the variable has two or more dimensions
    is_2d: bool = False
    # Dimension text exactly as declared, e.g. "3,MHOD"
    fortran_dims_raw: str = ""
|
||
|
||
@dataclass
class CommonBlock:
    """A named Fortran COMMON block together with the variables it declares.

    ``rust_struct`` / ``rust_file`` stay ``None`` until build_mapping() finds a
    Rust struct annotated as mirroring this block.
    """

    # Block name (the text between the slashes), upper-cased
    name: str
    # Name of the Fortran file the block was parsed from, e.g. "BASICS.FOR"
    file: str
    # Declared variables, in declaration order
    variables: List[CommonVar] = field(default_factory=list)
    # Name of the mirroring Rust struct, if any
    rust_struct: Optional[str] = None
    # Path of the Rust file that defines that struct
    rust_file: Optional[str] = None
|
||
|
||
@dataclass
class RustStruct:
    """A Rust struct that may mirror a Fortran COMMON block."""

    # Struct identifier
    name: str
    # Path of the .rs file that defines the struct
    file: str
    # Upper-cased name of the COMMON block it mirrors, if known
    common_name: Optional[str] = None
    # Public fields: field name -> type text that follows the colon
    fields: Dict[str, str] = field(default_factory=dict)
|
||
|
||
# ============================================================================
|
||
# Fortran COMMON 解析
|
||
# ============================================================================
|
||
|
||
def _strip_inline_comment(text: str) -> str:
    """Remove a trailing ``!`` comment, ignoring ``!`` inside '...' or "..." literals."""
    quote = None
    for pos, ch in enumerate(text):
        if quote is not None:
            if ch == quote:
                quote = None
        elif ch in ("'", '"'):
            quote = ch
        elif ch == '!':
            return text[:pos].rstrip()
    return text


def _continuation_text(line: str) -> Tuple[bool, str]:
    """Return (is_continuation, statement_text) for one fixed-form source line."""
    if line.startswith('\t'):
        # Tab format: TAB + digit 1-9 marks a continuation; TAB + anything else
        # starts a new statement, so the column-6 rule must not be applied.
        if len(line) > 1 and line[1] in '123456789':
            return True, line[2:]
        return False, line
    # Classic fixed form: any character other than blank/'0' in column 6
    if len(line) >= 6 and line[5] not in (' ', '0', '\n'):
        return True, line[6:]
    return False, line


def _join_continuation_lines(content: str) -> str:
    """Join Fortran fixed-form continuation lines into single logical lines.

    - Empty lines and lines with ``C``/``c``/``!``/``*`` in column 1 are dropped,
      as are indented lines that contain only a ``!`` comment.
    - Trailing ``!`` comments are removed *before* joining, so a comment on one
      line can no longer swallow the continuation text appended after it.
    - A continuation line (column 6 not blank/``0``, or TAB + digit 1-9) is
      appended to the previous logical line, separated by a single space.

    Args:
        content: full text of a fixed-form Fortran source file.

    Returns:
        The logical lines joined with newlines.
    """
    joined: List[str] = []
    for line in content.split('\n'):
        if not line:
            continue
        # Column-1 comment lines
        if line[0].upper() in ('C', '!', '*'):
            continue
        is_cont, cont_text = _continuation_text(line)
        if joined and is_cont:
            joined[-1] = joined[-1].rstrip() + ' ' + _strip_inline_comment(cont_text).strip()
            continue
        code = _strip_inline_comment(line)
        # An indented line holding only a "!" comment must not become a statement,
        # otherwise a following continuation line would be attached to it.
        if line.strip() and not code.strip():
            continue
        joined.append(code)
    return '\n'.join(joined)
|
||
|
||
|
||
def _split_common_groups(body: str) -> List[Tuple[str, str]]:
    """Split the text after ``COMMON`` into (BLOCK_NAME, variable_text) pairs.

    Only slashes outside parentheses delimit block names, so a dimension like
    ``X(N/2)`` is not taken for a delimiter. Text before the first ``/NAME/``
    (blank COMMON) and empty or non-identifier names are skipped.
    """
    slashes: List[int] = []
    depth = 0
    for pos, ch in enumerate(body):
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == '/' and depth == 0:
            slashes.append(pos)

    groups: List[Tuple[str, str]] = []
    # Slashes come in pairs around each name: /NAME1/ vars1 /NAME2/ vars2 ...
    for k in range(0, len(slashes) - 1, 2):
        name = body[slashes[k] + 1:slashes[k + 1]].strip().upper()
        end = slashes[k + 2] if k + 2 < len(slashes) else len(body)
        text = body[slashes[k + 1] + 1:end]
        if re.fullmatch(r'\w+', name):
            groups.append((name, text))
    return groups


def parse_common_block(content: str, filename: str) -> List[CommonBlock]:
    """Parse every named COMMON block declared in one Fortran source file.

    Continuation lines are joined first; each logical line that starts with
    ``COMMON`` is then split into its block groups. One statement may declare
    several blocks (``COMMON /A/ X, Y /B/ Z``), and later statements may extend
    a block; all variables of a block are collected in declaration order.
    Blank (unnamed) COMMON is ignored.

    Args:
        content: full text of the Fortran file.
        filename: file name recorded in each resulting CommonBlock.

    Returns:
        One CommonBlock per distinct block name, in order of first appearance.
    """
    joined = _join_continuation_lines(content)
    # One COMMON statement per logical line. The previous lookahead-based pattern
    # only matched when the *next* statement was COMMON/PARAMETER/REAL/INTEGER/
    # LOGICAL/CHARACTER/blank, silently dropping blocks followed by e.g. DIMENSION.
    stmt_re = re.compile(r'^[ \t]*(?:\d+[ \t]+)?COMMON\b(.*)$', re.IGNORECASE | re.MULTILINE)

    block_vars: Dict[str, List[str]] = {}
    for match in stmt_re.finditer(joined):
        body = match.group(1)
        # Drop any trailing Fortran comment that survived joining
        if '!' in body:
            body = body[:body.index('!')]
        for block_name, vars_str in _split_common_groups(body):
            block_vars.setdefault(block_name, []).append(vars_str.strip())

    return [
        CommonBlock(
            name=block_name,
            file=filename,
            variables=_parse_var_list(','.join(var_lists), block_name),
        )
        for block_name, var_lists in block_vars.items()
    ]
|
||
|
||
|
||
def _parse_var_list(vars_str: str, block_name: str) -> List[CommonVar]:
    """Parse the variable list of a COMMON statement into CommonVar objects.

    ``vars_str`` is the comma-separated text that follows ``/BLOCK/``, e.g.
    ``"NLEVEL, ENION(MLEVEL), G(3, MDEPTH)"``. Arrays keep their upper-cased
    dimension expressions in ``dims`` and the untouched text in
    ``fortran_dims_raw``; ``is_2d`` is set for two or more dimensions.
    Scalar fragments that are not valid Fortran names are skipped.

    Args:
        vars_str: variable list text (commas inside parentheses are respected).
        block_name: COMMON block name stored on every returned variable.

    Returns:
        The parsed variables in declaration order.
    """
    variables: List[CommonVar] = []
    for part in _split_respecting_parens(vars_str):
        part = part.strip()
        if not part:
            continue

        # NAME(dims): allow blanks before "(" and nested parentheses in the dims,
        # e.g. "X (MDEPTH)" or "Y(2*(MLEVEL+1))" — the old [^)]+ pattern dropped both.
        m = re.match(r'^(\w+)\s*\((.+)\)$', part)
        if m:
            dims_str = m.group(2)
            # Split at top-level commas only, so "MAX(1,N)" stays one dimension
            dims = [d.strip().upper() for d in _split_respecting_parens(dims_str)]
            variables.append(CommonVar(
                name=m.group(1).upper(),
                common_block=block_name,
                dims=dims,
                is_2d=len(dims) >= 2,
                fortran_dims_raw=dims_str,
            ))
            continue

        name = part.upper()
        # Skip anything that is not a plain Fortran identifier
        if re.match(r'^[A-Z]\w*$', name):
            variables.append(CommonVar(name=name, common_block=block_name))

    return variables
|
||
|
||
|
||
def _split_respecting_parens(s: str) -> List[str]:
    """Split *s* on commas that are not nested inside parentheses.

    Pieces are returned verbatim (not stripped). Empty pieces between commas
    are kept, but an empty trailing piece is not, so ``"A,"`` gives ``["A"]``
    and ``""`` gives ``[]``.
    """
    pieces: List[str] = []
    level = 0
    start = 0
    for idx, ch in enumerate(s):
        if ch == '(':
            level += 1
        elif ch == ')':
            level -= 1
        elif ch == ',' and level == 0:
            pieces.append(s[start:idx])
            start = idx + 1
    tail = s[start:]
    if tail:
        pieces.append(tail)
    return pieces
|
||
|
||
|
||
def parse_all_commons() -> Dict[str, CommonBlock]:
    """Parse every file in COMMON_FILES and return ``{block_name: CommonBlock}``.

    Files are looked up in FORTRAN_COMMON_DIR; missing files are skipped
    silently. When the same block name occurs in more than one file, the later
    file's variables are appended to the block created for the first one.
    """
    merged: Dict[str, CommonBlock] = {}

    for fname in COMMON_FILES:
        path = os.path.join(FORTRAN_COMMON_DIR, fname)
        if not os.path.exists(path):
            continue
        with open(path, 'r', encoding='utf-8', errors='ignore') as handle:
            text = handle.read()

        for blk in parse_common_block(text, fname):
            existing = merged.get(blk.name)
            if existing is None:
                merged[blk.name] = blk
            else:
                # Supplementary definition of an already-seen block
                existing.variables.extend(blk.variables)

    return merged
|
||
|
||
|
||
# ============================================================================
|
||
# Rust Struct 解析
|
||
# ============================================================================
|
||
|
||
def parse_rust_structs() -> List[RustStruct]:
    """Parse the Rust state structs that declare which COMMON block they mirror.

    Every ``*.rs`` file directly inside RUST_STATE_DIR is scanned (in sorted
    order; ``[]`` is returned when the directory does not exist) for a struct
    preceded by a doc comment of the form ``/// 对应 COMMON /NAME/`` (trailing
    description text is allowed). Attribute lines (``#[...]``), other comment
    lines and blank lines may sit between that comment and the struct header.

    Returns:
        One RustStruct per annotated struct, holding the upper-cased COMMON
        name and its public fields (``pub`` or ``pub(...)``) mapped to their
        type text.
    """
    structs: List[RustStruct] = []

    if not os.path.isdir(RUST_STATE_DIR):
        return structs

    # Each repetition of the middle group consumes exactly one line (blank,
    # attribute or comment). The old `\s*...\n` alternatives overlapped, which
    # made failed matches backtrack exponentially over runs of blank lines.
    header_re = re.compile(
        r'///[ \t]*对应[ \t]*COMMON[ \t]*/[ \t]*(\w+)[ \t]*/[^\n]*\n'
        r'(?:[ \t]*(?:#[^\n]*|//[^\n]*)?\r?\n)*'
        r'[ \t]*pub(?:\([^)]*\))?[ \t]+struct[ \t]+(\w+)[^{;]*\{',
        re.IGNORECASE,
    )
    # `pub name: Type` / `pub(crate) name: Type`; the type runs to the next comma or EOL
    field_re = re.compile(r'pub(?:\([^)]*\))?\s+(\w+)\s*:\s*([^,\n]+)')
    line_comment_re = re.compile(r'//[^\n]*')

    for fname in sorted(os.listdir(RUST_STATE_DIR)):
        if not fname.endswith('.rs'):
            continue
        fpath = os.path.join(RUST_STATE_DIR, fname)
        with open(fpath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()

        for match in header_re.finditer(content):
            # Struct body: everything up to the matching closing brace
            body = _extract_braced_body(content, match.end())
            # Remove // and /// comments so doc text cannot yield phantom fields
            # and trailing comments do not leak into the recorded type text.
            body = line_comment_re.sub('', body)
            fields = {fm.group(1): fm.group(2).strip() for fm in field_re.finditer(body)}

            structs.append(RustStruct(
                name=match.group(2),
                file=fpath,
                common_name=match.group(1).upper(),
                fields=fields,
            ))

    return structs
|
||
|
||
|
||
def _extract_braced_body(content: str, start: int) -> str:
    """Return the text between an opening ``{`` and its matching ``}``.

    *start* is the index just after the opening brace. Nested braces are
    balanced. If no matching ``}`` is found, the rest of *content* is returned.
    """
    level = 1
    for pos in range(start, len(content)):
        ch = content[pos]
        if ch == '{':
            level += 1
        elif ch == '}':
            level -= 1
            if level == 0:
                return content[start:pos]
    return content[start:]
|
||
|
||
|
||
# ============================================================================
|
||
# 映射构建
|
||
# ============================================================================
|
||
|
||
def _fortran_to_rust_name(fortran_name: str) -> str:
    """Derive the Rust field name for a Fortran variable by lower-casing it.

    Example: ``"ENION"`` -> ``"enion"``.
    """
    rust_name = str.lower(fortran_name)
    return rust_name
|
||
|
||
|
||
def build_mapping(
    common_blocks: Dict[str, CommonBlock],
    rust_structs: List[RustStruct]
) -> Dict[str, CommonVar]:
    """Cross-reference Fortran COMMON blocks with Rust structs.

    Every CommonVar and CommonBlock is annotated in place with the name and
    file of the Rust struct mirroring its block. A variable additionally gets
    ``rust_field`` when that struct has a field named like the lower-cased
    Fortran variable.

    Args:
        common_blocks: ``{block_name: CommonBlock}``, e.g. from parse_all_commons().
        rust_structs: structs from parse_rust_structs(). Structs without a
            ``common_name`` are ignored; if several claim the same block, the
            last one wins.

    Returns:
        ``{FORTRAN_VAR_NAME: CommonVar}``. If a name is declared in more than
        one block, the entry from the block iterated last is returned; every
        copy is still annotated.
    """
    struct_by_common: Dict[str, RustStruct] = {
        rs.common_name.upper(): rs for rs in rust_structs if rs.common_name
    }

    var_map: Dict[str, CommonVar] = {}
    for block_name, block in common_blocks.items():
        block_struct = struct_by_common.get(block_name)
        if block_struct is not None:
            block.rust_struct = block_struct.name
            block.rust_file = block_struct.file

        for var in block.variables:
            var_map[var.name] = var
            # Annotate every variable, not only the last one per name: previously a
            # name repeated in two blocks left the earlier copy unmapped although it
            # is still listed (and counted) in block.variables.
            rs = struct_by_common.get(var.common_block)
            if rs is None:
                continue
            var.rust_struct = rs.name
            var.rust_file = rs.file
            rust_field_name = _fortran_to_rust_name(var.name)
            if rust_field_name in rs.fields:
                var.rust_field = rust_field_name

    return var_map
|
||
|
||
|
||
# ============================================================================
|
||
# 模块级查询
|
||
# ============================================================================
|
||
|
||
def get_includes_for_module(module_name: str, extracted_dir: Optional[str] = None) -> List[str]:
    """Return the upper-cased base names of the ``*.FOR`` files a module INCLUDEs.

    The module source is read from ``<extracted_dir>/<module_name lower>.f``;
    an empty list is returned when that file does not exist. ``IMPLIC`` is
    excluded. Duplicates are kept, in source order.

    Args:
        module_name: Fortran routine name (any case).
        extracted_dir: directory holding the extracted ``.f`` files; defaults
            to EXTRACTED_DIR.
    """
    base_dir = EXTRACTED_DIR if extracted_dir is None else extracted_dir
    fpath = os.path.join(base_dir, f"{module_name.lower()}.f")
    if not os.path.exists(fpath):
        return []
    with open(fpath, 'r', encoding='utf-8', errors='ignore') as f:
        content = f.read()

    # Anchor at line start: fixed-form comment lines ("C     INCLUDE '...'") start
    # with a non-blank character in column 1 and are therefore not counted.
    # Both quote styles are accepted.
    include_re = re.compile(
        r"^[ \t]*INCLUDE\s*['\"]([^'\"]+)\.FOR['\"]",
        re.IGNORECASE | re.MULTILINE,
    )
    includes = include_re.findall(content)
    return [inc.upper() for inc in includes if inc.upper() != 'IMPLIC']
|
||
|
||
|
||
def get_commons_for_module(module_name: str) -> List[str]:
    """Return the sorted, de-duplicated COMMON block names visible to a module.

    The module's INCLUDE files (see get_includes_for_module) are read from
    FORTRAN_COMMON_DIR; includes that do not exist there are skipped.
    """
    # Accept blanks inside the slashes (COMMON / X /), like parse_common_block
    # does, so every block that ends up in the variable map is also found here.
    stmt_re = re.compile(r'^[ \t]*(?:\d+[ \t]+)?COMMON\b(.*)$', re.IGNORECASE | re.MULTILINE)
    name_re = re.compile(r'/\s*(\w+)\s*/')

    commons: Set[str] = set()
    for inc in get_includes_for_module(module_name):
        fpath = os.path.join(FORTRAN_COMMON_DIR, f"{inc}.FOR")
        if not os.path.exists(fpath):
            continue
        with open(fpath, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()
        # Joining continuation lines drops comment lines (so commented-out COMMONs
        # are ignored) and puts every /NAME/ of a statement on one logical line.
        joined = _join_continuation_lines(content)
        for stmt in stmt_re.findall(joined):
            commons.update(name.upper() for name in name_re.findall(stmt))
    return sorted(commons)
|
||
|
||
|
||
def get_vars_for_module(
    module_name: str,
    var_map: Dict[str, CommonVar]
) -> Dict[str, CommonVar]:
    """Select the COMMON variables visible to one Fortran module.

    Args:
        module_name: Fortran routine name.
        var_map: the mapping returned by build_mapping().

    Returns:
        ``{VAR_NAME: CommonVar}`` restricted to variables whose block is one of
        get_commons_for_module(module_name), in var_map's iteration order.
    """
    wanted_blocks = set(get_commons_for_module(module_name))
    return {
        name: cvar
        for name, cvar in var_map.items()
        if cvar.common_block in wanted_blocks
    }
|
||
|
||
|
||
def get_rust_structs_for_module(
    module_name: str,
    rust_structs: List[RustStruct]
) -> List[str]:
    """Return the sorted, de-duplicated Rust file paths a module needs to ``use``.

    A struct's file is included when the struct mirrors one of the COMMON
    blocks returned by get_commons_for_module(module_name).
    """
    wanted_blocks = set(get_commons_for_module(module_name))
    needed = {
        rs.file
        for rs in rust_structs
        if rs.common_name and rs.common_name.upper() in wanted_blocks
    }
    return sorted(needed)
|
||
|
||
|
||
# ============================================================================
|
||
# 缓存单例
|
||
# ============================================================================
|
||
|
||
# Lazily populated module-level caches, filled together by get_mapping()
_cached_mapping = None  # {VAR_NAME: CommonVar} from build_mapping()
_cached_structs = None  # list of RustStruct from parse_rust_structs()
_cached_blocks = None   # {BLOCK_NAME: CommonBlock} from parse_all_commons()


def get_mapping():
    """Return the variable mapping, parsing both sides on the first call.

    The first call fills all three caches; later calls return the cached
    mapping without re-reading any file.
    """
    global _cached_mapping, _cached_structs, _cached_blocks
    if _cached_mapping is not None:
        return _cached_mapping
    _cached_blocks = parse_all_commons()
    _cached_structs = parse_rust_structs()
    _cached_mapping = build_mapping(_cached_blocks, _cached_structs)
    return _cached_mapping
|
||
|
||
|
||
def get_structs():
    """Return the cached Rust struct list, building all caches first if needed."""
    if _cached_structs is None:
        get_mapping()
    return _cached_structs
|
||
|
||
|
||
def get_blocks():
    """Return the cached ``{block_name: CommonBlock}`` dict, building caches if needed."""
    if _cached_blocks is None:
        get_mapping()
    return _cached_blocks
|
||
|
||
|
||
# ============================================================================
|
||
# CLI
|
||
# ============================================================================
|
||
|
||
def _format_dims(var: CommonVar) -> str:
    """Render a variable's dimensions as "(D1, D2)", or "" for scalars."""
    return f"({', '.join(var.dims)})" if var.dims else ""


def _print_module_vars(module_name: str, var_map: Dict[str, CommonVar]) -> None:
    """--module: list a module's COMMON variables and their Rust struct.field targets."""
    module = module_name.upper()
    module_vars = get_vars_for_module(module, var_map)
    print(f"模块 {module} 使用的 COMMON 变量:")
    print(f" 总计: {len(module_vars)} 个变量")
    for vname, var in sorted(module_vars.items()):
        rust_str = f"→ {var.rust_struct}.{var.rust_field}" if var.rust_field else "→ (未映射)"
        print(f" {vname:20s} {_format_dims(var):20s} {rust_str}")


def _print_block_vars(block_name: str, blocks: Dict[str, CommonBlock]) -> None:
    """--block: list the variables of one COMMON block and their Rust fields."""
    block = blocks.get(block_name.upper())
    if not block:
        print(f"COMMON 块 {block_name} 未找到")
        return
    print(f"COMMON /{block.name}/ (文件: {block.file})")
    for var in block.variables:
        rust_str = f"→ {var.rust_field}" if var.rust_field else "→ (未映射)"
        print(f" {var.name:20s} {_format_dims(var):20s} {rust_str}")


def _print_unmapped(var_map: Dict[str, CommonVar]) -> None:
    """--unmapped: list variables that have no Rust field."""
    unmapped = {k: v for k, v in var_map.items() if not v.rust_field}
    print(f"未映射的 COMMON 变量: {len(unmapped)} / {len(var_map)}")
    for vname, var in sorted(unmapped.items()):
        print(f" /{var.common_block}/ {vname:20s} {_format_dims(var)}")


def _print_mapping_stats(var_map: Dict[str, CommonVar], blocks: Dict[str, CommonBlock]) -> None:
    """--mapping: overall mapped/unmapped counts plus a per-block summary."""
    print("COMMON 变量映射统计:")
    mapped = sum(1 for v in var_map.values() if v.rust_field)
    print(f" 总变量: {len(var_map)}")
    print(f" 已映射: {mapped}")
    print(f" 未映射: {len(var_map) - mapped}")
    print()
    print("COMMON 块:")
    for bname, block in sorted(blocks.items()):
        n_mapped = sum(1 for v in block.variables if v.rust_field)
        print(f" /{bname}/ → {block.rust_struct or '(无)'} ({n_mapped}/{len(block.variables)})")


def _print_overview(var_map: Dict[str, CommonVar], blocks: Dict[str, CommonBlock], structs) -> None:
    """Default output: counts of blocks, variables, structs and mapped variables."""
    print("COMMON 变量映射数据库")
    print(f" COMMON 块: {len(blocks)}")
    print(f" COMMON 变量: {len(var_map)}")
    print(f" Rust struct: {len(structs)}")
    mapped = sum(1 for v in var_map.values() if v.rust_field)
    print(f" 已映射: {mapped}/{len(var_map)}")


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point.

    Args:
        argv: argument list to parse; ``None`` (the default) means sys.argv[1:].

    Only one report is printed. Options are checked in the order --module,
    --block, --unmapped, --mapping; with none of them an overview is shown.
    """
    import argparse

    parser = argparse.ArgumentParser(description='COMMON 变量映射数据库')
    parser.add_argument('--module', help='显示某模块使用的 COMMON 变量')
    parser.add_argument('--block', help='显示某 COMMON 块的变量')
    parser.add_argument('--mapping', action='store_true', help='显示完整映射')
    parser.add_argument('--unmapped', action='store_true', help='显示未映射的变量')
    args = parser.parse_args(argv)

    var_map = get_mapping()
    blocks = get_blocks()
    structs = get_structs()

    if args.module:
        _print_module_vars(args.module, var_map)
    elif args.block:
        _print_block_vars(args.block, blocks)
    elif args.unmapped:
        _print_unmapped(var_map)
    elif args.mapping:
        _print_mapping_stats(var_map, blocks)
    else:
        _print_overview(var_map, blocks, structs)
|
||
|
||
|
||
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|