206 lines
6.3 KiB
Python
206 lines
6.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
将 Fortran DATA 语句转换为 Rust 2D 数组
|
|
|
|
用法:
|
|
python3 scripts/fortran_to_rust_array.py tlusty/extracted/pffe.f
|
|
|
|
输出: Rust 代码片段,可直接复制到 .rs 文件中
|
|
"""
|
|
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
|
|
def parse_fortran_arrays(filepath: str) -> list[dict]:
|
|
"""
|
|
解析 Fortran 文件中的数组定义和 DATA 语句
|
|
|
|
返回: [{"name": "p4a", "dims": [22], "data": [...]}, ...]
|
|
"""
|
|
with open(filepath, 'r') as f:
|
|
content = f.read()
|
|
|
|
arrays = {}
|
|
|
|
# 1. 解析 dimension 语句
|
|
# dimension p4a(22), p4b(10,28), ...
|
|
dim_pattern = r'dimension\s+([a-z0-9_,()\s]+)'
|
|
for match in re.finditer(dim_pattern, content, re.IGNORECASE):
|
|
dim_str = match.group(1)
|
|
# 解析每个数组
|
|
arr_pattern = r'(\w+)\s*\(([^)]+)\)'
|
|
for arr_match in re.finditer(arr_pattern, dim_str):
|
|
name = arr_match.group(1).lower()
|
|
dims = [int(d.strip()) for d in arr_match.group(2).split(',')]
|
|
arrays[name] = {"name": name, "dims": dims, "data": None}
|
|
|
|
# 2. 解析 data 语句
|
|
# data p4a / val1, val2, ... /
|
|
# 或多行:
|
|
# data p4b /
|
|
# * val1, val2, ...,
|
|
# * val3, ... /
|
|
|
|
# 先找到所有 data 块
|
|
data_pattern = r'data\s+(\w+)\s*/\s*([^/]+)\s*/'
|
|
for match in re.finditer(data_pattern, content, re.IGNORECASE | re.DOTALL):
|
|
name = match.group(1).lower()
|
|
data_str = match.group(2)
|
|
|
|
if name not in arrays:
|
|
arrays[name] = {"name": name, "dims": [], "data": None}
|
|
|
|
# 解析数值
|
|
values = parse_data_values(data_str)
|
|
arrays[name]["data"] = values
|
|
|
|
return list(arrays.values())
|
|
|
|
|
|
# A Fortran numeric literal: optional sign, mantissa, optional exponent
# introduced by E, D (double precision) or Q (quad precision).
_FORTRAN_NUMBER = r'[+-]?(?:\d+\.?\d*|\.\d+)(?:[eEdDqQ][+-]?\d+)?'
_NUMBER_RE = re.compile(_FORTRAN_NUMBER)
# Repeat syntax `r*c`, e.g. `7*1.387` or `3*1.0d-5`.
_REPEAT_RE = re.compile(r'(\d+)\s*\*\s*(' + _FORTRAN_NUMBER + r')')
_EXPONENT_TO_E = str.maketrans('dDqQ', 'eeee')


def _fortran_float(literal: str) -> float:
    """Convert a literal already matched by _NUMBER_RE to float ('1.5d0' -> 1.5)."""
    return float(literal.translate(_EXPONENT_TO_E))


def parse_data_values(data_str: str) -> list[float]:
    """Parse the value list found between the slashes of a Fortran DATA statement.

    Handles:

    * comma-separated reals/integers, including D/Q exponents (``1.5d0``);
    * the repeat form ``r*c`` (``7*1.387`` -> seven copies of 1.387), where the
      value may carry an exponent (``3*1.0e-5``);
    * inline ``!`` comments and fixed-form ``c``/``C`` comment lines;
    * continuation markers: a leading ``*`` or ``&``, or a trailing ``&``.

    Items that are not numeric are skipped, with a warning on stderr.
    """
    cleaned_lines = []
    for index, raw_line in enumerate(data_str.split('\n')):
        # 'c'/'C' in column 1 marks a fixed-form comment line. Line 0 begins
        # right after the opening '/', so it has no column-1 position to test.
        if index > 0 and raw_line[:1] in ('c', 'C'):
            continue
        line = raw_line.split('!', 1)[0].strip()
        if line.startswith(('*', '&')):
            line = line[1:].strip()
        if line.endswith('&'):
            line = line[:-1].strip()
        cleaned_lines.append(line)

    values = []
    for item in ' '.join(cleaned_lines).split(','):
        item = item.strip()
        if not item:
            continue
        # fullmatch (not match) so an exponent is never silently truncated.
        repeat = _REPEAT_RE.fullmatch(item)
        if repeat:
            values.extend([_fortran_float(repeat.group(2))] * int(repeat.group(1)))
        elif _NUMBER_RE.fullmatch(item):
            values.append(_fortran_float(item))
        else:
            # Best effort: keep going, but make the dropped item visible.
            print(f"警告: 无法解析的 DATA 值已跳过: {item!r}", file=sys.stderr)
    return values
|
|
|
|
|
|
def _rust_f64(value: float) -> str:
    """Format *value* as a Rust ``f64`` literal.

    ``repr`` of a finite Python float is valid Rust (``1.0``, ``-2.5``, ``1e-05``).
    The value is coerced with ``float`` first so an int ``3`` becomes ``3.0``,
    not an integer literal that would not type-check in an ``[f64; N]``. NaN and
    infinities have no literal form and map to the ``f64`` associated constants.
    """
    value = float(value)
    if value != value:  # NaN is the only float not equal to itself
        return "f64::NAN"
    if value == float("inf"):
        return "f64::INFINITY"
    if value == float("-inf"):
        return "f64::NEG_INFINITY"
    return repr(value)


def _append_raw_const(lines: list[str], name: str, data: list[float], per_line: int = 10) -> None:
    """Append ``const {name}_RAW: [f64; N] = [...];`` to *lines*, *per_line* values per row."""
    lines.append(f"const {name}_RAW: [f64; {len(data)}] = [")
    for start in range(0, len(data), per_line):
        row = ", ".join(_rust_f64(v) for v in data[start:start + per_line])
        lines.append(f"    {row},")
    lines.append("];")


# Rust helper emitted when at least one 2-D array is generated. It takes a
# slice, not `&[f64; NJ * NI]`: using an arithmetic expression over const
# generic parameters as an array length requires the unstable
# `generic_const_exprs` feature, so that signature fails to compile on stable
# Rust. The length is checked with `assert!` instead (stable in const fn since
# Rust 1.57). Callers pass `&X_RAW`, which coerces from `&[f64; N]` to `&[f64]`.
_FORTRAN_TO_RUST_2D = """\
/// Fortran 列优先 → Rust 行优先
const fn fortran_to_rust_2d<const NJ: usize, const NI: usize>(
    data: &[f64],
) -> [[f64; NI]; NJ] {
    assert!(data.len() == NJ * NI, "fortran_to_rust_2d: data.len() != NJ * NI");
    let mut result = [[0.0; NI]; NJ];
    let mut i = 0;
    while i < NI {
        let mut j = 0;
        while j < NJ {
            result[j][i] = data[j + i * NJ];
            j += 1;
        }
        i += 1;
    }
    result
}"""


def generate_rust_code(arrays: list[dict]) -> str:
    """Render parsed arrays (see ``parse_fortran_arrays``) as a Rust snippet.

    Output for each array that has data:

    * ``const NAME_RAW: [f64; N]`` with the values in file order (Fortran
      column-major).
    * For 1-D arrays, a ``static NAME: OnceLock<[f64; n]>``.
    * For 2-D arrays ``name(nj, ni)``, a ``static NAME: OnceLock<[[f64; ni]; nj]>``.
      The ``fortran_to_rust_2d`` helper is also emitted; it transposes the
      values into row-major ``[j][i]`` indexing.
    * For any other rank, only the raw constant, so no data is dropped.
    * A comment flagging the array when the product of the declared dims
      differs from the number of values.

    The snippet ends with suggested ``get_or_init`` calls for the statics.
    """
    lines = ["// 自动生成的数组数据", ""]
    init_lines = []
    has_2d = False

    for arr in arrays:
        data = arr["data"]
        if not data:
            continue
        name = arr["name"].upper()
        dims = arr["dims"]

        declared = 1
        for extent in dims:
            declared *= extent
        if dims and declared != len(data):
            # The generated code would not type-check; make the cause obvious.
            lines.append(f"// 警告: {name} 声明大小 {declared} 与数据个数 {len(data)} 不一致")

        if len(dims) == 1:
            _append_raw_const(lines, name, data)
            lines.append(f"static {name}: OnceLock<[f64; {dims[0]}]> = OnceLock::new();")
            init_lines.append(f"let {name.lower()} = {name}.get_or_init(|| {name}_RAW);")
        elif len(dims) == 2:
            nj, ni = dims
            has_2d = True
            lines.append(f"// {name}: Fortran {name.lower()}({nj}, {ni})")
            _append_raw_const(lines, name, data)
            lines.append(f"static {name}: OnceLock<[[f64; {ni}]; {nj}]> = OnceLock::new();")
            init_lines.append(
                f"let {name.lower()} = {name}.get_or_init(|| fortran_to_rust_2d::<{nj}, {ni}>(&{name}_RAW));"
            )
        else:
            # Rank 0 (no DIMENSION seen) or rank > 2: keep the flat data.
            lines.append(f"// {name}: 维度 {dims} 不支持, 仅输出扁平数据")
            _append_raw_const(lines, name, data)
        lines.append("")

    if has_2d:
        lines.append(_FORTRAN_TO_RUST_2D)
        lines.append("")

    if init_lines:
        lines.append("// 初始化函数中调用:")
        lines.extend(init_lines)

    return "\n".join(lines)
|
|
|
|
|
|
def main():
    """Command-line entry point: convert the Fortran file named by argv[1].

    Usage and file errors are printed to stderr and exit with status 1, so
    they never mix into redirected output. The parse summary and the generated
    Rust snippet go to stdout.
    """
    if len(sys.argv) < 2:
        print("用法: python3 scripts/fortran_to_rust_array.py <fortran_file>", file=sys.stderr)
        print("示例: python3 scripts/fortran_to_rust_array.py tlusty/extracted/pffe.f", file=sys.stderr)
        sys.exit(1)

    filepath = sys.argv[1]

    # is_file() rather than exists(): a directory passes exists() and would
    # then crash inside open() with IsADirectoryError.
    if not Path(filepath).is_file():
        print(f"错误: 文件不存在: {filepath}", file=sys.stderr)
        sys.exit(1)

    print(f"解析: {filepath}")
    print("=" * 60)

    arrays = parse_fortran_arrays(filepath)

    print(f"找到 {len(arrays)} 个数组:")
    for arr in arrays:
        dims_str = ', '.join(str(d) for d in arr['dims'])
        data_count = len(arr['data']) if arr['data'] else 0
        print(f" - {arr['name']}({dims_str}): {data_count} 个值")

    print()
    print("=" * 60)
    print("生成的 Rust 代码:")
    print("=" * 60)
    print(generate_rust_code(arrays))


if __name__ == "__main__":
    main()
|