#!/usr/bin/env python3
"""
Extract array data from Fortran source files and generate the Rust module data.rs.

Usage:
    python3 scripts/extract_fortran_data.py

Output: src/data.rs
"""

import re
|
||
from pathlib import Path
|
||
|
||
|
||
def _merge_fixed_form(text: str) -> str:
    """Drop fixed-form comments and fold continuation lines into their statement."""
    merged = []
    for line in text.split('\n'):
        # A 'C', 'c' or '*' in column 1 marks a whole-line comment.
        if line and line[0] in 'Cc*':
            continue
        # Everything after '!' is treated as an inline comment.
        if '!' in line:
            line = line.split('!')[0]
        # A continuation line has blank columns 1-5 and, in column 6, any
        # character other than blank or '0' (both of which start a new
        # statement in fixed-form source).
        is_continuation = (
            len(line) >= 6
            and line[5] not in ' 0\n\r\t'
            and not line[:5].strip()
        )
        if not is_continuation:
            merged.append(line)
        elif merged:
            # Append the continuation text (columns 7+) to the open statement.
            merged[-1] += ' ' + line[6:].strip()
    return '\n'.join(merged)


def _iter_parameters(content: str):
    """Yield lower-cased (name, value_text) pairs from every PARAMETER (...) list."""
    for match in re.finditer(r'parameter\s*\(([^)]+)\)', content, re.IGNORECASE):
        for item in match.group(1).split(','):
            if '=' in item:
                name, value = item.split('=', 1)
                yield name.strip().lower(), value.strip().lower()


def _parse_dims(dims_str: str, param_table: dict):
    """Resolve a comma-separated dimension list to ints, or None if any part is unknown."""
    dims = []
    for token in dims_str.split(','):
        token = token.strip().lower()
        if token in param_table:
            dims.append(param_table[token])
            continue
        try:
            dims.append(int(token))
        except ValueError:
            # Neither a known PARAMETER nor an integer literal.
            return None
    return dims


def parse_fortran_file(filepath: Path, global_params: dict = None) -> list[dict]:
    """Parse the arrays and scalar parameters declared in one Fortran file.

    The file is read as fixed-form source: comment lines are removed and
    continuation lines are merged before PARAMETER, DIMENSION, type
    declaration and DATA statements are scanned with regular expressions.

    Args:
        filepath: Path of the Fortran source file.
        global_params: Integer PARAMETER values taken from include files,
            used to resolve symbolic array dimensions. It is copied, not
            modified.

    Returns:
        One dict per discovered array or scalar parameter with the keys
        "name", "dims", "data" and "source"; scalars that come from a
        PARAMETER statement additionally carry "is_param": True.
    """
    if global_params is None:
        global_params = {}

    # Undecodable bytes are replaced rather than aborting the whole
    # extraction with UnicodeDecodeError.
    with open(filepath, 'r', errors='replace') as f:
        content = _merge_fixed_form(f.read())

    source = filepath.name
    arrays = {}
    parameters = list(_iter_parameters(content))

    # 1. Build the constant table: include-file parameters overlaid with the
    #    PARAMETER values of this file, all coerced to int because they are
    #    only used as array extents.
    param_table = dict(global_params)
    for name, value in parameters:
        try:
            if '.' not in value and 'e' not in value and 'd' not in value:
                param_table[name] = int(value)
            else:
                # Fortran double-precision exponent 'd' -> 'e' for float().
                param_table[name] = int(float(value.replace('d', 'e')))
        except (ValueError, OverflowError):
            # Expressions and non-numeric values cannot be used as extents.
            continue

    def declare(decl_text: str, keep_existing: bool) -> None:
        """Register every `name(dims)` in decl_text whose extents resolve."""
        for arr_match in re.finditer(r'(\w+)\s*\(([^)]+)\)', decl_text):
            name = arr_match.group(1).lower()
            if keep_existing and name in arrays:
                continue
            dims = _parse_dims(arr_match.group(2), param_table)
            if dims:
                arrays[name] = {"name": name, "dims": dims, "data": None, "source": source}

    # 2. DIMENSION statements, e.g. `dimension p4a(22), p4b(10,28), adi(nni)`.
    #    The character class has no newline, so a match never leaves its line.
    dim_pattern = r'dimension[ \t]+([a-z0-9_,() \t]+)'
    for match in re.finditer(dim_pattern, content, re.IGNORECASE):
        declare(match.group(1), keep_existing=False)

    # 2.5 Arrays declared in type statements, e.g. `REAL frac(MR)`,
    #     `REAL*4 arr(10)`. An earlier definition of the same name wins.
    type_decl_pattern = r'(real(?:\*[\d]+)?|integer(?:\*[\d]+)?|complex(?:\*[\d]+)?|logical(?:\*[\d]+)?|character(?:\*[\d]+)?)[ \t]+([a-z0-9_,() \t]+)'
    for match in re.finditer(type_decl_pattern, content, re.IGNORECASE):
        declare(match.group(2), keep_existing=True)

    # 3. DATA statements of the form `data name / v1, v2, ... /`. Names with
    #    no known declaration are registered with empty dims.
    data_pattern = r'data\s+(\w+)\s*/\s*([^/]+)\s*/'
    for match in re.finditer(data_pattern, content, re.IGNORECASE | re.DOTALL):
        name = match.group(1).lower()
        entry = arrays.setdefault(
            name, {"name": name, "dims": [], "data": None, "source": source}
        )
        entry["data"] = parse_data_values(match.group(2))

    # 4. Export numeric PARAMETER values that are not arrays as scalars.
    for name, value in parameters:
        if name in arrays:
            continue
        try:
            scalar = float(value.replace('d', 'e'))
        except ValueError:
            continue
        arrays[name] = {"name": name, "dims": [], "data": [scalar], "source": source, "is_param": True}

    return list(arrays.values())
|
||
|
||
|
||
def parse_data_values(data_str: str) -> list[float]:
    """Parse the value list of a Fortran DATA statement into floats.

    Handles the repeat syntax ``count*value`` (e.g. ``7*1.387`` or
    ``3*1.5d-3``), Fortran ``d``/``D`` exponents, a detached minus sign
    (``- 14.2``) and a blank before the exponent (``1.48 e-2``).

    Args:
        data_str: Text between the slashes of a DATA statement; it may span
            several lines and may still contain leading ``*`` continuation
            markers.

    Returns:
        The values in source order with repeat counts expanded. Items that
        cannot be read as a number are skipped.
    """

    def to_float(text: str) -> float:
        # Fortran double-precision exponent letter -> Python's 'e'.
        text = text.replace('d', 'e').replace('D', 'e')
        # Remove blanks between a minus sign and the number.
        text = re.sub(r'-\s+([\d.])', r'-\1', text)
        # Remove blanks between the mantissa and the exponent letter.
        text = re.sub(r'([\d.])\s+([eEdD])', r'\1\2', text)
        return float(text)

    # Strip a leading '*' continuation marker from each physical line, then
    # join everything into one comma-separated string.
    pieces = []
    for raw in data_str.split('\n'):
        raw = raw.strip()
        if raw.startswith('*'):
            raw = raw[1:].strip()
        pieces.append(raw)

    values = []
    for part in ' '.join(pieces).split(','):
        # Drop surrounding blanks and a trailing '/' statement terminator.
        part = part.strip().rstrip('/').strip()
        if not part:
            continue

        # `count*value`: the whole remainder after '*' is the value, so that
        # exponents such as `3*1.5d-3` are not truncated to 1.5.
        repeat = re.fullmatch(r'(\d+)\s*\*\s*(.+)', part)
        if repeat:
            count, literal = int(repeat.group(1)), repeat.group(2)
        else:
            count, literal = 1, part

        try:
            number = to_float(literal)
        except ValueError:
            # Best effort: unparseable items are ignored.
            continue
        values.extend([number] * count)

    return values
|
||
|
||
|
||
def generate_data_rs(all_arrays: dict[str, list[dict]]) -> str:
    """Render the parsed arrays as the text of a Rust source file.

    Args:
        all_arrays: Maps a source file name to the list of array dicts
            parsed from it (keys "name", "dims", "data", optional
            "is_param").

    Returns:
        The Rust source text. Every entry that has data becomes a
        ``pub const`` named ``<PREFIX>_<NAME>``, where PREFIX is derived
        from the first eight characters of the file stem. Scalar parameters
        become ``f64``, arrays with zero or one known dimension become
        ``[f64; N]`` and two-dimensional arrays become ``[[f64; ni]; nj]``.
        Entries whose data count does not fit the declared size, or that
        have more than two dimensions, are skipped with a printed warning.
    """
    lines = [
        "//! Fortran 数据数组自动导出",
        "//!",
        "//! 由 extract_fortran_data.py 自动生成,请勿手动修改",
        "",
    ]

    def emit_flat(doc_line: str, const_name: str, values: list) -> None:
        """Append a one-dimensional const array, ten values per line."""
        lines.append(doc_line)
        lines.append(f"pub const {const_name}: [f64; {len(values)}] = [")
        for i, val in enumerate(values):
            if i % 10 == 0:
                lines.append(" ")
            lines[-1] += f"{val},"
        lines.append("];")
        lines.append("")

    # Constant names already handed out, to keep them unique file-wide.
    used_names = set()

    for source, arrays in sorted(all_arrays.items()):
        # Only entries that actually carry data are exported.
        valid_arrays = [a for a in arrays if a.get("data")]
        if not valid_arrays:
            continue

        lines.append(f"// ========== {source} ==========")
        lines.append("")

        # File-stem prefix, reduced to characters that are legal in a Rust
        # identifier (the stem may contain '-', '.', ...).
        prefix = re.sub(r'[^A-Z0-9_]', '', Path(source).stem.upper()[:8])

        for arr in valid_arrays:
            base_name = re.sub(r'[^A-Z0-9_]', '', arr["name"].upper())
            name = f"{prefix}_{base_name}"
            if name[0].isdigit():
                # A Rust identifier must not start with a digit.
                name = f"_{name}"

            # Still clashing after prefixing: append a running number.
            if name in used_names:
                counter = 1
                while f"{name}_{counter}" in used_names:
                    counter += 1
                name = f"{name}_{counter}"
            used_names.add(name)

            dims = arr["dims"]
            data = arr["data"]

            # A single-valued PARAMETER becomes a plain scalar constant.
            if arr.get("is_param") and len(data) == 1:
                lines.append(f"/// {arr['name']} (from {source})")
                lines.append(f"pub const {name}: f64 = {data[0]};")
                lines.append("")
                continue

            if len(dims) == 0:
                # No declaration found: export as a flat array sized by the data.
                emit_flat(
                    f"/// {arr['name']} (from {source}, 未知维度,共 {len(data)} 个值)",
                    name,
                    data,
                )

            elif len(dims) == 1:
                expected_size = dims[0]
                if len(data) != expected_size:
                    print(f"警告: {name} 期望 {expected_size} 个值,实际 {len(data)} 个,跳过")
                    continue
                emit_flat(f"/// {arr['name']}({dims[0]}) from {source}", name, data)

            elif len(dims) == 2:
                nj, ni = dims
                expected_size = nj * ni
                # Surplus values are tolerated here; only a shortfall skips.
                if len(data) < expected_size:
                    print(f"警告: {name} 期望 {expected_size} 个值,实际 {len(data)} 个,跳过")
                    continue

                lines.append(f"/// {arr['name']}({nj}, {ni}) from {source}")
                lines.append("/// 已转换为 Rust 行优先格式")
                lines.append(f"pub const {name}: [[f64; {ni}]; {nj}] = [")
                # Fortran stores a(j, i) column-major at j + i*nj (0-based);
                # each Rust row j collects a(j, 0..ni).
                for j in range(nj):
                    row = [str(data[j + i * nj]) for i in range(ni)]
                    lines.append(f" [{','.join(row)}],")
                lines.append("];")
                lines.append("")

            else:
                # Higher ranks are not exported; say so instead of dropping
                # the array silently.
                print(f"警告: {name} 是 {len(dims)} 维数组,暂不支持,跳过")

    return '\n'.join(lines)
|
||
|
||
|
||
def _clean_include_source(text: str) -> str:
    """Remove fixed-form comments from include text and merge continuation lines."""
    statements = []
    for line in text.split('\n'):
        # 'C', 'c' or '*' in column 1: whole-line comment, e.g. a
        # commented-out PARAMETER statement that must not be picked up.
        if line and line[0] in 'Cc*':
            continue
        if '!' in line:
            line = line.split('!')[0]
        # Continuation line: blank columns 1-5 and a marker other than blank
        # or '0' in column 6. Its text belongs to the previous statement.
        continued = (
            len(line) >= 6
            and line[5] not in ' 0\n\r\t'
            and not line[:5].strip()
        )
        if not continued:
            statements.append(line)
        elif statements:
            statements[-1] += ' ' + line[6:].strip()
    return '\n'.join(statements)


def parse_include_files(extracted_dir: Path) -> dict:
    """Collect integer PARAMETER values from the ``*.FOR`` include files.

    Args:
        extracted_dir: Directory that is searched for ``*.FOR`` files. A
            file that cannot be read there is retried under ``tlusty/``
            with the same file name.

    Returns:
        Dict mapping the lower-cased parameter name to its value as an int
        (floating-point values are truncated). Entries whose value is not a
        plain number are omitted. Files are processed in sorted order, so a
        later file overrides an earlier one deterministically.
    """
    global_params = {}

    for for_file in sorted(extracted_dir.glob("*.FOR")):
        content = None
        # Try the file itself first, then the copy in the tlusty/ root.
        for candidate in (for_file, Path("tlusty") / for_file.name):
            try:
                content = candidate.read_text()
                break
            except (OSError, UnicodeError):
                continue
        if content is None:
            continue

        content = _clean_include_source(content)

        for match in re.finditer(r'parameter\s*\(([^)]+)\)', content, re.IGNORECASE):
            for item in match.group(1).split(','):
                if '=' not in item:
                    continue
                name, val = item.split('=', 1)
                name = name.strip().lower()
                val = val.strip().lower()
                try:
                    if '.' not in val and 'e' not in val and 'd' not in val:
                        global_params[name] = int(val)
                    else:
                        # Fortran 'd' exponent -> 'e' so float() accepts it.
                        global_params[name] = int(float(val.replace('d', 'e')))
                except (ValueError, OverflowError):
                    # Expressions or symbolic values are not resolved here.
                    continue

    return global_params
|
||
|
||
|
||
def main():
    """Run the extraction: parse ``tlusty/extracted`` and write ``src/data.rs``.

    Paths are relative to the current working directory. Progress and a
    summary are printed to stdout.

    Returns:
        0 on success, 1 when the input directory does not exist.
    """
    extracted_dir = Path("tlusty/extracted")

    if not extracted_dir.exists():
        print(f"错误: 目录不存在: {extracted_dir}")
        return 1

    # Global parameters from the include files are needed first, because the
    # array dimensions in the .f files may refer to them.
    global_params = parse_include_files(extracted_dir)
    print(f"从 .FOR include 文件中提取了 {len(global_params)} 个全局参数")

    all_arrays = {}
    for fortran_file in sorted(extracted_dir.glob("*.f")):
        arrays = parse_fortran_file(fortran_file, global_params)
        if arrays:
            all_arrays[fortran_file.name] = arrays
            print(f"解析: {fortran_file.name} -> {len(arrays)} 个数组")

    # Summary statistics over everything that was parsed.
    every_array = [a for arrs in all_arrays.values() for a in arrs]
    total_arrays = len(every_array)
    arrays_with_data = sum(1 for a in every_array if a.get("data"))
    arrays_2d = sum(
        1 for a in every_array
        if len(a.get("dims", [])) == 2 and a.get("data")
    )

    print()
    print("=" * 60)
    print(f"总计: {total_arrays} 个数组, {arrays_with_data} 个有数据, {arrays_2d} 个 2D 数组")
    print("=" * 60)

    output_path = Path("src/data.rs")
    rust_code = generate_data_rs(all_arrays)

    # The generated file contains non-ASCII comments and Rust sources must be
    # UTF-8, so do not depend on the platform's default encoding.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(rust_code)

    print(f"已生成: {output_path}")
    print()
    print("在 lib.rs 或 main.rs 中添加:")
    print(" pub mod data;")
    print()
    print("使用方法:")
    # Constants are emitted as <FILE PREFIX>_<NAME>; no getter functions are
    # generated, so the hint must not advertise any.
    print(" use crate::data::<文件名前缀>_<数组名>;")
    print(" // 常量名 = 源文件名前 8 个字符 + '_' + 数组名 (全部大写)")
    return 0
|
||
|
||
|
||
if __name__ == "__main__":
    # Hand whatever main() returns to the interpreter as the exit status;
    # a return value of None is reported as success (status 0).
    raise SystemExit(main())
|