SpectraRust/scripts/fortran_to_rust_array.py
2026-03-19 22:16:23 +08:00

206 lines
6.3 KiB
Python

#!/usr/bin/env python3
"""
将 Fortran DATA 语句转换为 Rust 2D 数组
用法:
python3 scripts/fortran_to_rust_array.py tlusty/extracted/pffe.f
输出: Rust 代码片段,可直接复制到 .rs 文件中
"""
import re
import sys
from pathlib import Path
def parse_fortran_arrays(filepath: str) -> list[dict]:
    """Parse array declarations and DATA statements from a Fortran source file.

    Two passes are made over the whole file text:

    1. ``dimension`` statements supply each array's name and extents, e.g.
       ``dimension p4a(22), p4b(10,28)``.
    2. ``data NAME / ... /`` statements supply the values; the text between
       the slashes is handed to ``parse_data_values``.

    Args:
        filepath: Path of the Fortran source file to read.

    Returns:
        One dict per array, in order of first appearance, shaped like
        ``{"name": "p4a", "dims": [22], "data": [...]}``. ``name`` is
        lower-cased, ``dims`` is empty for arrays seen only in a DATA
        statement, and ``data`` is ``None`` for arrays without one.
    """
    # Legacy Fortran sources are not reliably UTF-8. Replace undecodable
    # bytes instead of aborting: only the ASCII statements matter here.
    with open(filepath, 'r', encoding='utf-8', errors='replace') as f:
        content = f.read()

    arrays = {}

    # 1. dimension statements:  dimension p4a(22), p4b(10,28), ...
    # \b keeps the keyword from matching inside a longer identifier.
    dim_pattern = r'\bdimension\s+([a-z0-9_,()\s]+)'
    arr_pattern = r'(\w+)\s*\(([^)]+)\)'
    for match in re.finditer(dim_pattern, content, re.IGNORECASE):
        for arr_match in re.finditer(arr_pattern, match.group(1)):
            name = arr_match.group(1).lower()
            try:
                dims = [int(d.strip()) for d in arr_match.group(2).split(',')]
            except ValueError:
                # Non-literal extent such as a(n). Because the pattern above
                # spans newlines it can also pick up a following
                # "call foo(n)". Neither describes a fixed-size array, so
                # skip the declarator instead of crashing.
                continue
            arrays[name] = {"name": name, "dims": dims, "data": None}

    # 2. data statements, single- or multi-line:
    #      data p4a / val1, val2, ... /
    #      data p4b /
    #     *   val1, val2, ...,
    #     *   val3, ... /
    data_pattern = r'\bdata\s+(\w+)\s*/\s*([^/]+)\s*/'
    for match in re.finditer(data_pattern, content, re.IGNORECASE | re.DOTALL):
        name = match.group(1).lower()
        if name not in arrays:
            # DATA without a matching dimension statement: extents unknown.
            arrays[name] = {"name": name, "dims": [], "data": None}
        arrays[name]["data"] = parse_data_values(match.group(2))

    return list(arrays.values())
# Fortran numeric literal: optional sign, digits with an optional decimal
# point, and an optional exponent introduced by E (real) or D (double).
_FORTRAN_NUMBER = r'[+-]?(?:\d+\.?\d*|\.\d+)(?:[eEdD][+-]?\d+)?'
_FORTRAN_NUMBER_RE = re.compile(_FORTRAN_NUMBER)
# Repeat syntax "count*value", e.g. 7*1.387 or 3*-1.5d-3.
_FORTRAN_REPEAT_RE = re.compile(r'(\d+)\s*\*\s*(' + _FORTRAN_NUMBER + r')')


def _fortran_float(literal: str) -> float:
    """Convert an already-validated Fortran numeric literal to a float."""
    # Python's float() only understands "e" as the exponent marker.
    return float(literal.replace('d', 'e').replace('D', 'e'))


def parse_data_values(data_str: str) -> list[float]:
    """Parse the value list of a Fortran DATA statement into floats.

    Handles comma-separated numeric literals, the repeat syntax
    ``count*value`` (``7*1.387`` expands to seven copies), E and D exponents
    (``1.5e-3``, ``1.0d0``), leading continuation markers, and comment text
    embedded in a multi-line value list.

    Args:
        data_str: The text found between the slashes of a DATA statement;
            may span several lines.

    Returns:
        The values in source order with repeats expanded. Tokens that are
        not numeric literals are skipped.
    """
    values = []

    # Kept from the previous implementation: drop a stray "c" left at the
    # very end of the text.
    data_str = re.sub(r'[cC]\s*$', '', data_str)

    cleaned_lines = []
    for index, raw_line in enumerate(data_str.split('\n')):
        # After the first line, a "c"/"C" in column 1 marks a fixed-form
        # comment line; joining it in would corrupt the neighbouring value.
        if index > 0 and raw_line[:1] in ('c', 'C'):
            continue
        # "!" starts an inline comment that runs to the end of the line.
        line = raw_line.split('!', 1)[0].strip()
        # Drop a leading continuation marker. "*" inside "7*1.387" is
        # unaffected because only the first character is inspected.
        if line[:1] in ('*', '&', '$'):
            line = line[1:].strip()
        cleaned_lines.append(line)
    data_str = ' '.join(cleaned_lines)

    for part in data_str.split(','):
        part = part.strip()
        if not part:
            continue
        # fullmatch (not a prefix match) so that "3*1.5e-3" cannot be
        # misread as three copies of 1.5.
        repeat = _FORTRAN_REPEAT_RE.fullmatch(part)
        if repeat:
            count = int(repeat.group(1))
            values.extend([_fortran_float(repeat.group(2))] * count)
        elif _FORTRAN_NUMBER_RE.fullmatch(part):
            values.append(_fortran_float(part))
        # Anything else (named constants, strings, junk) is skipped.
    return values
def generate_rust_code(arrays: list[dict]) -> str:
    """Render parsed Fortran arrays as a Rust source snippet.

    For every array that has data and one or two dimensions this emits a
    flat ``const NAME_RAW: [f64; N]`` holding the values in DATA order and a
    ``static NAME: OnceLock<...>`` of the declared shape. Arrays without
    data, or with any other number of dimensions, are not emitted.

    After the arrays come a ``fortran_to_rust_2d`` helper that turns the
    flat column-major data into a row-major ``[[f64; NI]; NJ]``, and one
    ``get_or_init`` line per 2D array.

    Args:
        arrays: Dicts with ``"name"``, ``"dims"`` and ``"data"`` keys, as
            returned by ``parse_fortran_arrays``.

    Returns:
        The generated Rust code as a single newline-joined string.
    """
    lines = ["// 自动生成的数组数据", ""]

    for arr in arrays:
        name = arr["name"].upper()
        dims = arr["dims"]
        data = arr["data"]
        if not data or len(dims) not in (1, 2):
            continue

        if len(dims) == 2:
            nj, ni = dims
            lines.append(f"// {name}: Fortran {name.lower()}({nj}, {ni})")
            cell_type = f"[[f64; {ni}]; {nj}]"
        else:
            cell_type = f"[f64; {dims[0]}]"

        # Flat values, ten per output line.
        lines.append(f"const {name}_RAW: [f64; {len(data)}] = [")
        for start in range(0, len(data), 10):
            chunk = data[start:start + 10]
            lines.append(" " + "".join(f"{val}," for val in chunk))
        lines.append("];")
        lines.append(f"static {name}: OnceLock<{cell_type}> = OnceLock::new();")
        lines.append("")

    # Conversion helper. The parameter is a slice, not `&[f64; NJ * NI]`:
    # a const-generic expression used as an array length is rejected by
    # stable Rust, so the previous signature did not compile. The assert
    # replaces the length check that the array type was meant to provide.
    lines.extend([
        "",
        "/// Fortran 列优先 → Rust 行优先",
        "const fn fortran_to_rust_2d<const NJ: usize, const NI: usize>(",
        " data: &[f64],",
        ") -> [[f64; NI]; NJ] {",
        " assert!(data.len() == NJ * NI);",
        " let mut result = [[0.0; NI]; NJ];",
        " let mut i = 0;",
        " while i < NI {",
        " let mut j = 0;",
        " while j < NJ {",
        " result[j][i] = data[j + i * NJ];",
        " j += 1;",
        " }",
        " i += 1;",
        " }",
        " result",
        "}",
        "",
    ])

    # One initialisation line per 2D array that has data.
    lines.append("// 初始化函数中调用:")
    for arr in arrays:
        name = arr["name"].upper()
        dims = arr["dims"]
        if len(dims) == 2 and arr["data"]:
            nj, ni = dims[0], dims[1]
            lines.append(f"let {name.lower()} = {name}.get_or_init(|| fortran_to_rust_2d::<{nj}, {ni}>(&{name}_RAW));")

    return '\n'.join(lines)
def main():
    """Command-line entry point.

    Expects the Fortran file path as the first argument. Prints the usage
    text and exits with status 1 when it is missing, and an error message
    with status 1 when the path does not exist. Otherwise prints a summary
    of the arrays found, followed by the generated Rust code.
    """
    cli_args = sys.argv[1:]
    if not cli_args:
        print("用法: python3 scripts/fortran_to_rust_array.py <fortran_file>")
        print("示例: python3 scripts/fortran_to_rust_array.py tlusty/extracted/pffe.f")
        sys.exit(1)

    filepath = cli_args[0]
    source = Path(filepath)
    if not source.exists():
        print(f"错误: 文件不存在: {filepath}")
        sys.exit(1)

    separator = "=" * 60

    print(f"解析: {filepath}")
    print(separator)

    arrays = parse_fortran_arrays(filepath)

    # Summary: one line per array with its extents and value count.
    print(f"找到 {len(arrays)} 个数组:")
    for arr in arrays:
        dims_str = ', '.join(map(str, arr['dims']))
        data_count = len(arr['data'] or [])
        print(f" - {arr['name']}({dims_str}): {data_count} 个值")

    print()
    print(separator)
    print("生成的 Rust 代码:")
    print(separator)
    print(generate_rust_code(arrays))


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()