SpectraRust/scripts/extract_fortran_data.py
2026-03-19 22:16:23 +08:00

439 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Extract array data from Fortran source files and generate the Rust module data.rs.

Usage:
    python3 scripts/extract_fortran_data.py

Output: src/data.rs
"""
import re
from pathlib import Path
# Regular expressions for the statements this extractor understands.
# [ \t] is used instead of \s in the declaration patterns so that a match
# never runs across a line break.
_PARAMETER_STMT = re.compile(r'parameter\s*\(([^)]+)\)', re.IGNORECASE)
_DIMENSION_STMT = re.compile(r'dimension[ \t]+([a-z0-9_,() \t]+)', re.IGNORECASE)
_TYPE_DECL_STMT = re.compile(
    r'(real(?:\*[\d]+)?|integer(?:\*[\d]+)?|complex(?:\*[\d]+)?|logical(?:\*[\d]+)?|character(?:\*[\d]+)?)[ \t]+([a-z0-9_,() \t]+)',
    re.IGNORECASE,
)
_DATA_STMT = re.compile(r'data\s+(\w+)\s*/\s*([^/]+)\s*/', re.IGNORECASE | re.DOTALL)
_ARRAY_DECL = re.compile(r'(\w+)\s*\(([^)]+)\)')


def _fold_fixed_form(text: str) -> str:
    """Strip comments from fixed-form Fortran and fold continuation lines."""
    statements = []
    for raw in text.split('\n'):
        # A C, c or * in column 1 marks a whole-line comment.
        if raw[:1] in ('C', 'c', '*'):
            continue
        # Everything after "!" is an inline comment.
        code = raw.split('!')[0]
        # Any character in column 6 other than blank or zero marks a
        # continuation line; a zero there still denotes an initial line.
        if len(code) >= 6 and code[5] not in ' 0\n\r\t':
            if statements:
                statements[-1] += ' ' + code[6:].strip()
            # A continuation with nothing to continue is dropped.
        else:
            statements.append(code)
    return '\n'.join(statements)


def _parameter_items(content: str):
    """Yield lower-cased (name, value_text) pairs from every PARAMETER statement."""
    for stmt in _PARAMETER_STMT.finditer(content):
        for item in stmt.group(1).split(','):
            if '=' in item:
                name, value = item.split('=', 1)
                yield name.strip().lower(), value.strip().lower()


def _resolve_dims(dims_text: str, param_table: dict):
    """Return the list of extents in *dims_text*, or None if any is unresolved."""
    dims = []
    for token in dims_text.split(','):
        token = token.strip().lower()
        if token in param_table:
            dims.append(param_table[token])
            continue
        try:
            dims.append(int(token))
        except ValueError:
            return None
    return dims


def _declared_arrays(decl_text: str, param_table: dict):
    """Yield (name, dims) for every ``name(dims)`` in *decl_text* that resolves."""
    for decl in _ARRAY_DECL.finditer(decl_text):
        dims = _resolve_dims(decl.group(2), param_table)
        if dims:
            yield decl.group(1).lower(), dims


def parse_fortran_file(filepath: Path, global_params: dict = None) -> list[dict]:
    """Parse the arrays and scalar parameters declared in one Fortran file.

    The file is treated as fixed-form source: comment lines and inline
    comments are removed and continuation lines are folded into the
    statement they continue before any pattern matching is done.

    Args:
        filepath: Path of the Fortran file.
        global_params: Parameter table extracted from include files; it is
            copied, never modified. Parameters found in this file override
            entries with the same name.

    Returns:
        A list of dicts with the keys ``name``, ``dims``, ``data`` and
        ``source``. Entries created from PARAMETER statements additionally
        carry ``is_param: True`` and hold the value as a one-element list.
    """
    if global_params is None:
        global_params = {}

    # Legacy sources may contain bytes that are not valid UTF-8 (usually in
    # comments); replace them instead of aborting the whole extraction.
    with open(filepath, 'r', encoding='utf-8', errors='replace') as f:
        content = _fold_fixed_form(f.read())

    source = Path(filepath).name

    def new_entry(name, dims, data=None, **extra):
        return {"name": name, "dims": dims, "data": data, "source": source, **extra}

    # 1. Build the constant table used to resolve symbolic array extents.
    param_table = dict(global_params)
    for name, value in _parameter_items(content):
        try:
            if '.' not in value and 'e' not in value and 'd' not in value:
                param_table[name] = int(value)
            else:
                # Floating-point value: truncate, it is only used as an extent.
                param_table[name] = int(float(value.replace('d', 'e')))
        except (ValueError, OverflowError):
            # Expressions and other non-literal values are ignored.
            pass

    arrays = {}

    # 2. DIMENSION statements, e.g. "dimension p4a(22), p4b(10,28), adi(nni)".
    for stmt in _DIMENSION_STMT.finditer(content):
        for name, dims in _declared_arrays(stmt.group(1), param_table):
            arrays[name] = new_entry(name, dims)

    # 2.5 Arrays declared in type statements, e.g. "REAL frac(MR)" or
    # "REAL*4 arr(10)". A DIMENSION entry for the same name wins.
    for stmt in _TYPE_DECL_STMT.finditer(content):
        for name, dims in _declared_arrays(stmt.group(2), param_table):
            if name not in arrays:
                arrays[name] = new_entry(name, dims)

    # 3. DATA statements of the form "data name / v1, v2, ... /".
    for stmt in _DATA_STMT.finditer(content):
        name = stmt.group(1).lower()
        if name not in arrays:
            arrays[name] = new_entry(name, [])
        arrays[name]["data"] = parse_data_values(stmt.group(2))

    # 4. Export PARAMETER constants as scalars unless the name is already
    # taken by an array or an earlier parameter.
    for name, value in _parameter_items(content):
        if name in arrays:
            continue
        try:
            number = float(value.replace('d', 'e'))
        except ValueError:
            continue
        arrays[name] = new_entry(name, [], [number], is_param=True)

    return list(arrays.values())
# "count*value" repeat syntax; the value may carry a Fortran exponent
# (e or d, either case), optionally separated by blanks.
_REPEAT_VALUE = re.compile(r'(\d+)\s*\*\s*([-+]?[\d.]+(?:\s*[eEdD]\s*[-+]?\d+)?)')


def _fortran_float(token: str) -> float:
    """Convert one Fortran numeric literal to float; raise ValueError if it is not one."""
    text = token.replace('d', 'e').replace('D', 'e')
    # Close the gap in literals such as "- 14.2" or "- .5".
    text = re.sub(r'([-+])\s+(?=[\d.])', r'\1', text)
    # Close the gap in literals such as "1.48 e-2".
    text = re.sub(r'(\d)\s+([eE])', r'\1\2', text)
    return float(text)


def parse_data_values(data_str: str) -> list[float]:
    """Parse the value list of a DATA statement into a flat list of floats.

    Handles the repeat syntax ``7*1.387`` (including values with an exponent
    such as ``3*1.5e-3`` or ``2*1.0d-3``), Fortran ``d`` exponents, and blanks
    between a sign or exponent marker and the digits.

    Args:
        data_str: Text found between the slashes of a DATA statement.

    Returns:
        The parsed values in source order. Items that cannot be read as a
        number are skipped.
    """
    # Remove a "*" left at the start of a physical line, then flatten.
    pieces = []
    for raw_line in data_str.split('\n'):
        text = raw_line.strip()
        if text.startswith('*'):
            text = text[1:].strip()
        pieces.append(text)
    flat = ' '.join(pieces)

    values = []
    for part in flat.split(','):
        part = part.strip()
        if not part:
            continue
        # Drop a trailing "/" (terminator of the DATA statement).
        part = part.rstrip('/')

        if '*' in part and not part.startswith('-'):
            repeat = _REPEAT_VALUE.match(part)
            if repeat:
                try:
                    value = _fortran_float(re.sub(r'\s+', '', repeat.group(2)))
                except ValueError:
                    value = None
                if value is not None:
                    values.extend([value] * int(repeat.group(1)))
                    continue

        try:
            values.append(_fortran_float(part))
        except ValueError:
            # Best effort: unparseable items are ignored.
            pass
    return values
def _append_wrapped_values(lines: list, data: list) -> None:
    """Append *data* to *lines* as comma-terminated values, ten per line."""
    for start in range(0, len(data), 10):
        lines.append(" " + "".join(f"{val}," for val in data[start:start + 10]))


def generate_data_rs(all_arrays: dict[str, list[dict]]) -> str:
    """Render the contents of src/data.rs from the parsed arrays.

    Args:
        all_arrays: Mapping from source file name to the list of array dicts
            produced for that file (keys ``name``, ``dims``, ``data`` and
            optionally ``is_param``).

    Returns:
        The Rust source text. Every constant is named
        ``<FILE-STEM-PREFIX>_<ARRAY-NAME>``; a numeric suffix is added when
        that name is already taken. Entries without data are omitted, 1-D
        arrays whose value count differs from the declared extent and 2-D
        arrays with too few values are skipped with a warning, and arrays
        with more than two dimensions are skipped with a warning.
    """
    lines = [
        "//! Fortran 数据数组自动导出",
        "//!",
        "//! 由 extract_fortran_data.py 自动生成,请勿手动修改",
        "",
    ]
    # Names handed out so far, to keep the generated constants unique.
    used_names = set()

    for source, arrays in sorted(all_arrays.items()):
        valid_arrays = [a for a in arrays if a.get("data")]
        if not valid_arrays:
            continue
        lines.append(f"// ========== {source} ==========")
        lines.append("")

        # First eight characters of the file stem, reduced to characters
        # that are legal in a Rust identifier.
        prefix = re.sub(r'[^A-Z0-9_]', '', Path(source).stem.upper()[:8])

        for arr in valid_arrays:
            base_name = re.sub(r'[^A-Z0-9_]', '', arr["name"].upper())
            name = f"{prefix}_{base_name}"
            if name in used_names:
                counter = 1
                while f"{name}_{counter}" in used_names:
                    counter += 1
                name = f"{name}_{counter}"
            used_names.add(name)

            dims = arr["dims"]
            data = arr["data"]

            # A single-valued PARAMETER becomes a scalar constant.
            if arr.get("is_param") and len(data) == 1:
                lines.append(f"/// {arr['name']} (from {source})")
                lines.append(f"pub const {name}: f64 = {data[0]};")
                lines.append("")
                continue

            rank = len(dims)
            if rank == 0:
                # Unknown shape: emit a flat array sized by the value count.
                lines.append(f"/// {arr['name']} (from {source}, 未知维度,共 {len(data)} 个值)")
                lines.append(f"pub const {name}: [f64; {len(data)}] = [")
                _append_wrapped_values(lines, data)
                lines.append("];")
                lines.append("")
            elif rank == 1:
                expected_size = dims[0]
                if len(data) != expected_size:
                    print(f"警告: {name} 期望 {expected_size} 个值,实际 {len(data)} 个,跳过")
                    continue
                lines.append(f"/// {arr['name']}({dims[0]}) from {source}")
                lines.append(f"pub const {name}: [f64; {dims[0]}] = [")
                _append_wrapped_values(lines, data)
                lines.append("];")
                lines.append("")
            elif rank == 2:
                nj, ni = dims[0], dims[1]
                expected_size = nj * ni
                if len(data) < expected_size:
                    print(f"警告: {name} 期望 {expected_size} 个值,实际 {len(data)} 个,跳过")
                    continue
                lines.append(f"/// {arr['name']}({nj}, {ni}) from {source}")
                lines.append(f"/// 已转换为 Rust 行优先格式")
                lines.append(f"pub const {name}: [[f64; {ni}]; {nj}] = [")
                # Fortran stores the first index fastest (column-major), so
                # element (j, i) sits at offset j + i * nj of the flat data.
                for j in range(nj):
                    row = [str(data[j + i * nj]) for i in range(ni)]
                    lines.append(f" [{','.join(row)}],")
                lines.append("];")
                lines.append("")
            else:
                # Only ranks 0-2 can be rendered; report instead of
                # dropping the array silently.
                print(f"警告: {name} 的维度数为 {rank},暂不支持,跳过")

    return '\n'.join(lines)
def parse_include_files(extracted_dir: Path) -> dict:
    """Collect integer PARAMETER constants from the ``*.FOR`` include files.

    Each file is treated as fixed-form Fortran: comment lines (C, c or * in
    column 1) are skipped, inline ``!`` comments are removed, and
    continuation lines are folded into their statement before the PARAMETER
    statements are read. Files are visited in sorted order, so when two
    files define the same name the later file wins.

    Args:
        extracted_dir: Directory scanned for ``*.FOR`` files. A file that
            cannot be read there is retried under ``tlusty/``.

    Returns:
        Mapping from lower-cased parameter name to its integer value.
        Floating-point values are truncated; values that are not plain
        numeric literals are ignored.
    """
    param_pattern = re.compile(r'parameter\s*\(([^)]+)\)', re.IGNORECASE)
    global_params = {}

    for for_file in sorted(extracted_dir.glob("*.FOR")):
        try:
            content = for_file.read_text(encoding='utf-8', errors='replace')
        except OSError:
            # Also try the tlusty/ root directory.
            fallback = Path("tlusty") / for_file.name
            try:
                content = fallback.read_text(encoding='utf-8', errors='replace')
            except OSError:
                continue

        # Strip comments and fold continuation lines.
        statements = []
        for raw in content.split('\n'):
            if raw[:1] in ('C', 'c', '*'):
                continue
            code = raw.split('!')[0]
            # Column 6 holding anything but blank or zero marks a continuation.
            if len(code) >= 6 and code[5] not in ' 0\n\r\t':
                if statements:
                    statements[-1] += ' ' + code[6:].strip()
            else:
                statements.append(code)
        content = '\n'.join(statements)

        for stmt in param_pattern.finditer(content):
            for item in stmt.group(1).split(','):
                if '=' not in item:
                    continue
                name, val = item.split('=', 1)
                name = name.strip().lower()
                val = val.strip().lower()
                try:
                    if '.' not in val and 'e' not in val and 'd' not in val:
                        global_params[name] = int(val)
                    else:
                        global_params[name] = int(float(val.replace('d', 'e')))
                except (ValueError, OverflowError):
                    # Expressions and other non-literal values are ignored.
                    pass
    return global_params
def main():
    """Extract data from ``tlusty/extracted`` and write ``src/data.rs``.

    Reads the global parameters from the ``*.FOR`` include files, parses
    every ``*.f`` file in sorted order, prints a summary, and writes the
    generated Rust module. Returns without writing anything when the input
    directory does not exist. Paths are relative to the current working
    directory.
    """
    extracted_dir = Path("tlusty/extracted")
    if not extracted_dir.exists():
        print(f"错误: 目录不存在: {extracted_dir}")
        return

    # Global parameters from the include files are needed first, because
    # array extents in the .f files may refer to them.
    global_params = parse_include_files(extracted_dir)
    print(f"从 .FOR include 文件中提取了 {len(global_params)} 个全局参数")

    all_arrays = {}
    for fortran_file in sorted(extracted_dir.glob("*.f")):
        arrays = parse_fortran_file(fortran_file, global_params)
        if arrays:
            all_arrays[fortran_file.name] = arrays
            print(f"解析: {fortran_file.name} -> {len(arrays)} 个数组")

    # Summary statistics.
    every_array = [a for arrs in all_arrays.values() for a in arrs]
    total_arrays = len(every_array)
    arrays_with_data = sum(
        1 for a in every_array if a.get("data") and len(a["data"]) > 0
    )
    arrays_2d = sum(
        1 for a in every_array if len(a.get("dims", [])) == 2 and a.get("data")
    )
    print()
    print("=" * 60)
    print(f"总计: {total_arrays} 个数组, {arrays_with_data} 个有数据, {arrays_2d} 个 2D 数组")
    print("=" * 60)

    # The generated module contains non-ASCII doc comments and Rust source
    # files must be UTF-8, so the encoding is fixed rather than taken from
    # the locale.
    output_path = Path("src/data.rs")
    rust_code = generate_data_rs(all_arrays)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(rust_code)
    print(f"已生成: {output_path}")
    print()
    print("在 lib.rs 或 main.rs 中添加:")
    print(" pub mod data;")
    print()
    # NOTE(review): the hint below mentions unprefixed names and a get_p4b()
    # getter; the generator emits file-prefixed constants and no getter
    # functions, so this text looks stale - confirm and update.
    print("使用方法:")
    print(" use crate::data::{TT, PN, get_p4b};")
    print(" let p4b = get_p4b(); // 自动初始化并返回 2D 数组")
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()