#!/usr/bin/env python3
"""
Extract array data from Fortran source files and generate the Rust module ``src/data.rs``.

Usage:
    python3 scripts/extract_fortran_data.py

Output:
    src/data.rs

Only fixed-form (FORTRAN 77 style) sources are handled: column-1 ``C``/``c``/``*``
comment lines, ``!`` inline comments, and column-6 (or DEC tab-format)
continuation lines.
"""

import ast
import math
import re
from collections import ChainMap
from pathlib import Path
from typing import Iterator, Mapping, Optional, Union

_Number = Union[int, float]

# Digits that mark a continuation line in DEC/VAX "tab format" (a tab in
# columns 1-6 followed by a non-zero digit).
_TAB_CONTINUATION_DIGITS = frozenset("123456789")

# Characters allowed in a DIMENSION / type-declaration entity list: names,
# extents (including "lo:hi" bounds and simple integer expressions) and
# separators.  Newlines are excluded so a match never crosses a statement.
_DECL_LIST = r"([a-z0-9_,()+\-*/: \t]+)"

_PARAMETER_RE = re.compile(r"\bparameter\s*\(([^)]+)\)", re.IGNORECASE)
_DIMENSION_RE = re.compile(r"\bdimension[ \t]+" + _DECL_LIST, re.IGNORECASE)
# Anchored at the start of a statement so intrinsic calls such as
# "x = real (i)" are not mistaken for declarations.
_TYPE_DECL_RE = re.compile(
    r"(?:^|;)[ \t]*"
    r"(?:real|integer|complex|logical|character|double[ \t]+precision)"
    r"(?:\*\d+)?[ \t]+" + _DECL_LIST,
    re.IGNORECASE | re.MULTILINE,
)
# One "name(extents)" entity inside a declaration list.
_ENTITY_RE = re.compile(r"(\w+)\s*\(([^)]+)\)")
# DATA statement: keyword, first "name /values/" group, then ", name /values/" groups.
_DATA_KEYWORD_RE = re.compile(r"\bdata\s+", re.IGNORECASE)
_DATA_FIRST_GROUP_RE = re.compile(r"(\w+)\s*/([^/]+)/")
_DATA_NEXT_GROUP_RE = re.compile(r"\s*,\s*(\w+)\s*/([^/]+)/")
# "count*value" repeat syntax inside a DATA value list (blanks already removed).
_REPEAT_RE = re.compile(r"(\d+)\*(.+)")
# A complete Fortran numeric literal, e.g. "-1.5d-3", "2.", ".5e2", "1.0_8".
_REAL_LITERAL_RE = re.compile(
    r"([+-]?(?:\d+\.?\d*|\.\d+))(?:[edq]([+-]?\d+))?(?:_\w+)?", re.IGNORECASE
)
# An unsigned numeric literal embedded in an expression (not part of a name).
_NUMBER_TOKEN_RE = re.compile(
    r"(?<![\w.])(\d+\.?\d*|\.\d+)(?:[edq]([+-]?\d+))?(?:_\w+)?", re.IGNORECASE
)

# Errors that mean "this expression cannot be evaluated" rather than a bug.
_EVAL_ERRORS = (ValueError, OverflowError, ZeroDivisionError)
# Guard against pathological integer powers such as 10**10**10.
_MAX_INT_EXPONENT = 4096


# --------------------------------------------------------------------------
# Source preprocessing
# --------------------------------------------------------------------------

def _strip_inline_comment(line: str) -> str:
    """Remove a ``!`` comment from *line*, ignoring ``!`` inside quoted strings."""
    if "!" not in line:
        return line
    quote = None
    for index, char in enumerate(line):
        if quote:
            if char == quote:
                quote = None  # a doubled quote ('') simply re-opens on the next char
        elif char in "'\"":
            quote = char
        elif char == "!":
            return line[:index]
    return line


def _continuation_body(line: str) -> Optional[str]:
    """Return the statement text of a continuation line, or None for an initial line."""
    if "\t" in line[:6]:
        # Tab format: only "<tab><1-9>" continues the previous statement.
        after_tab = line[line.index("\t") + 1:]
        if after_tab[:1] in _TAB_CONTINUATION_DIGITS:
            return after_tab[1:]
        return None
    # Standard fixed form: any character other than blank or zero in column 6.
    if len(line) >= 6 and line[5] not in " 0":
        return line[6:]
    return None


def _preprocess_fortran(content: str) -> str:
    """Drop comments and blank lines and join continuation lines.

    Returns one logical statement per output line.
    """
    merged: list = []
    for raw_line in content.split("\n"):
        if raw_line[:1] in ("C", "c", "*"):
            continue  # whole-line FORTRAN 77 comment
        line = _strip_inline_comment(raw_line)
        if not line.strip():
            continue  # blank lines count as comment lines and never break continuation
        body = _continuation_body(line)
        if body is None:
            merged.append(line)
        elif merged:
            merged[-1] += " " + body.strip()
    return "\n".join(merged)


# --------------------------------------------------------------------------
# Numbers and constant expressions
# --------------------------------------------------------------------------

def _parse_fortran_real(text: str) -> float:
    """Parse one Fortran numeric literal (blanks ignored, d/q exponents, kind suffix).

    Raises:
        ValueError: if *text* is not a numeric literal.
    """
    match = _REAL_LITERAL_RE.fullmatch(re.sub(r"\s+", "", text))
    if match is None:
        raise ValueError(f"not a Fortran numeric literal: {text!r}")
    mantissa, exponent = match.groups()
    return float(f"{mantissa}e{exponent}" if exponent else mantissa)


def _normalize_number(match: "re.Match") -> str:
    """Rewrite one Fortran numeric token as a Python literal."""
    mantissa, exponent = match.group(1), match.group(2)
    if exponent is None and "." not in mantissa:
        return str(int(mantissa))  # also strips leading zeros, which Python rejects
    return repr(float(f"{mantissa}e{exponent or '0'}"))


def _fortran_div(left: _Number, right: _Number) -> _Number:
    """Divide with Fortran semantics (integer division truncates toward zero)."""
    if right == 0:
        raise ValueError("division by zero")
    if isinstance(left, int) and isinstance(right, int):
        quotient = abs(left) // abs(right)
        return quotient if (left < 0) == (right < 0) else -quotient
    return left / right


def _fortran_pow(base: _Number, exponent: _Number) -> _Number:
    """Exponentiate with Fortran semantics for negative integer exponents."""
    if isinstance(exponent, int) and abs(exponent) > _MAX_INT_EXPONENT:
        raise ValueError("exponent too large")
    if isinstance(base, int) and isinstance(exponent, int) and exponent < 0:
        if base == 0:
            raise ValueError("zero raised to a negative power")
        return _fortran_div(1, base ** -exponent)
    result = base ** exponent
    if isinstance(result, complex):
        raise ValueError("complex result")
    return result


def _eval_node(node: ast.AST, lookup: Mapping) -> _Number:
    """Evaluate a whitelisted arithmetic AST node (numbers, names, + - * / **)."""
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.Name):
        try:
            return lookup[node.id]
        except KeyError:
            raise ValueError(f"unknown name {node.id!r}") from None
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.UAdd, ast.USub)):
        operand = _eval_node(node.operand, lookup)
        return -operand if isinstance(node.op, ast.USub) else operand
    if isinstance(node, ast.BinOp):
        left = _eval_node(node.left, lookup)
        right = _eval_node(node.right, lookup)
        if isinstance(node.op, ast.Add):
            return left + right
        if isinstance(node.op, ast.Sub):
            return left - right
        if isinstance(node.op, ast.Mult):
            return left * right
        if isinstance(node.op, ast.Div):
            return _fortran_div(left, right)
        if isinstance(node.op, ast.Pow):
            return _fortran_pow(left, right)
    raise ValueError(f"unsupported expression node: {type(node).__name__}")


def _eval_numeric_expr(text: str, lookup: Mapping) -> _Number:
    """Evaluate a Fortran constant expression such as ``2*mlevel+1`` or ``1.5d0``.

    Names are resolved (lower-cased) through *lookup*. Nothing is executed:
    only numeric literals, names and + - * / ** are accepted.

    Raises:
        ValueError, OverflowError or ZeroDivisionError if the expression cannot
        be evaluated.
    """
    compact = re.sub(r"\s+", "", text).lower()  # fixed form is blank-insensitive
    if not compact:
        raise ValueError("empty expression")
    normalized = _NUMBER_TOKEN_RE.sub(_normalize_number, compact)
    try:
        tree = ast.parse(normalized, mode="eval")
    except SyntaxError as exc:
        raise ValueError(f"cannot parse Fortran expression {text!r}") from exc
    return _eval_node(tree.body, lookup)


# --------------------------------------------------------------------------
# Statement scanners
# --------------------------------------------------------------------------

def _iter_parameter_assignments(content: str) -> Iterator:
    """Yield ``(name, value_text)`` pairs, lower-cased, from every PARAMETER statement."""
    for match in _PARAMETER_RE.finditer(content):
        for item in match.group(1).split(","):
            name, sep, value = item.partition("=")
            if sep:
                yield name.strip().lower(), value.strip().lower()


def _evaluate_parameters(content: str, known: Mapping) -> dict:
    """Evaluate the PARAMETER statements of *content* in order.

    Each value may reference earlier parameters or names in *known*.
    Unevaluable or non-finite values are skipped.

    Returns:
        Mapping of the newly defined names to their exact (int or float) values.
    """
    local: dict = {}
    lookup = ChainMap(local, known)
    for name, text in _iter_parameter_assignments(content):
        try:
            value = _eval_numeric_expr(text, lookup)
        except _EVAL_ERRORS:
            continue
        if isinstance(value, float) and not math.isfinite(value):
            continue
        local[name] = value
    return local


def _parse_dims(dims_str: str, lookup: Mapping) -> Optional[list]:
    """Resolve a comma-separated extent list like ``n, 0:m+1`` to array sizes.

    Returns None if any extent cannot be evaluated or is smaller than 1.
    """
    dims = []
    for extent in dims_str.split(","):
        lower, colon, upper = extent.partition(":")
        try:
            if colon:
                size = (int(_eval_numeric_expr(upper, lookup))
                        - int(_eval_numeric_expr(lower, lookup)) + 1)
            else:
                size = int(_eval_numeric_expr(extent, lookup))
        except _EVAL_ERRORS:
            return None  # e.g. assumed-size "*" or an unknown parameter
        if size < 1:
            return None
        dims.append(size)
    return dims or None


def _new_entry(name: str, dims: list, source: str) -> dict:
    """Create an array record without data."""
    return {"name": name, "dims": dims, "data": None, "source": source}


def _collect_declared_arrays(pattern, content: str, lookup: Mapping, arrays: dict,
                             source: str, *, overwrite: bool) -> None:
    """Record the array entities declared by *pattern* (DIMENSION or type statements)."""
    for decl in pattern.finditer(content):
        for entity in _ENTITY_RE.finditer(decl.group(1)):
            name = entity.group(1).lower()
            if name in arrays and not overwrite:
                continue  # already declared (DIMENSION takes precedence)
            dims = _parse_dims(entity.group(2), lookup)
            if dims:
                arrays[name] = _new_entry(name, dims, source)


def _collect_data_statements(content: str, arrays: dict, source: str) -> None:
    """Attach values from ``DATA name /values/ [, name /values/ ...]`` statements."""
    for keyword in _DATA_KEYWORD_RE.finditer(content):
        group = _DATA_FIRST_GROUP_RE.match(content, keyword.end())
        while group:
            name = group.group(1).lower()
            entry = arrays.setdefault(name, _new_entry(name, [], source))
            entry["data"] = parse_data_values(group.group(2))
            group = _DATA_NEXT_GROUP_RE.match(content, group.end())


def _collect_parameter_scalars(content: str, arrays: dict, source: str) -> None:
    """Export PARAMETER constants whose value is a plain numeric literal."""
    for name, text in _iter_parameter_assignments(content):
        if name in arrays:
            continue
        try:
            value = _parse_fortran_real(text)
        except ValueError:
            continue  # expressions and non-numeric constants are not exported
        arrays[name] = {"name": name, "dims": [], "data": [value],
                        "source": source, "is_param": True}


# --------------------------------------------------------------------------
# Public API
# --------------------------------------------------------------------------

def parse_fortran_file(filepath: Path, global_params: Optional[dict] = None) -> list[dict]:
    """Parse the arrays declared and initialised in one fixed-form Fortran file.

    Args:
        filepath: path of the Fortran source file.
        global_params: global parameter table (lower-case name -> int) extracted
            from the include files; used to resolve array extents.

    Returns:
        One record per array/constant, in discovery order, with keys ``name``,
        ``dims`` (list of sizes, empty if unknown), ``data`` (list of floats or
        None), ``source`` (file name) and, for exported PARAMETER scalars,
        ``is_param``.
    """
    if global_params is None:
        global_params = {}
    filepath = Path(filepath)
    content = _preprocess_fortran(filepath.read_text(encoding="utf-8", errors="replace"))
    source = filepath.name

    # 1. Local PARAMETER values (exact), layered over the global table.
    local_params = _evaluate_parameters(content, global_params)
    lookup = ChainMap(local_params, global_params)

    arrays: dict = {}
    # 2. DIMENSION statements; a later declaration replaces an earlier one.
    _collect_declared_arrays(_DIMENSION_RE, content, lookup, arrays, source, overwrite=True)
    # 2.5 Arrays declared in type statements (REAL x(n), INTEGER*4 k(0:9), ...).
    _collect_declared_arrays(_TYPE_DECL_RE, content, lookup, arrays, source, overwrite=False)
    # 3. DATA statements.
    _collect_data_statements(content, arrays, source)
    # 4. Scalar PARAMETER constants for export.
    _collect_parameter_scalars(content, arrays, source)
    return list(arrays.values())


def parse_data_values(data_str: str) -> list[float]:
    """Parse the value list of a DATA statement, expanding ``count*value`` repeats.

    Blanks are ignored (``- 14.2`` and ``1.48 e-2`` are valid), ``d``/``q``
    exponents are accepted, and a leading ``*`` continuation marker on each
    line is dropped. Non-numeric items (strings, logicals) are skipped.
    """
    cleaned_lines = []
    for line in data_str.split("\n"):
        line = line.strip()
        if line.startswith("*"):
            line = line[1:].strip()
        cleaned_lines.append(line)

    values: list = []
    for raw_item in " ".join(cleaned_lines).split(","):
        item = re.sub(r"\s+", "", raw_item).rstrip("/")  # trailing "/" ends the list
        if not item:
            continue
        count = 1
        repeat = _REPEAT_RE.fullmatch(item)
        if repeat:
            count, item = int(repeat.group(1)), repeat.group(2)
        try:
            value = _parse_fortran_real(item)
        except ValueError:
            continue
        values.extend([value] * count)
    return values


def _rust_f64(value) -> str:
    """Format *value* as a Rust f64 expression (non-finite values use f64 constants)."""
    number = float(value)
    if math.isnan(number):
        return "f64::NAN"
    if math.isinf(number):
        return "f64::INFINITY" if number > 0 else "f64::NEG_INFINITY"
    return str(number)


def _rust_prefix(source: str) -> str:
    """Build a valid Rust identifier prefix from the first 8 characters of the file stem."""
    prefix = re.sub(r"[^A-Z0-9_]", "", Path(source).stem.upper()[:8])
    if prefix[:1].isdigit():
        prefix = "F" + prefix  # Rust identifiers cannot start with a digit
    return prefix


def _unique_rust_name(candidate: str, used_names: set) -> str:
    """Return *candidate*, or *candidate*_N if it is already taken, and reserve it."""
    name = candidate
    if name in used_names:
        counter = 1
        while f"{candidate}_{counter}" in used_names:
            counter += 1
        name = f"{candidate}_{counter}"
    used_names.add(name)
    return name


def _append_flat_values(lines: list, data: list) -> None:
    """Append the values of a flat array, ten per line."""
    for start in range(0, len(data), 10):
        lines.append(" " + "".join(f"{_rust_f64(v)}," for v in data[start:start + 10]))


def _append_rust_const(lines: list, name: str, arr: dict, source: str) -> None:
    """Append the Rust const for one array record, or print why it is skipped."""
    dims = arr["dims"]
    data = arr["data"]

    if arr.get("is_param") and len(data) == 1:
        # Single-value PARAMETER: emit a scalar const.
        lines.append(f"/// {arr['name']} (from {source})")
        lines.append(f"pub const {name}: f64 = {_rust_f64(data[0])};")
        lines.append("")
    elif not dims:
        # Unknown dimensions: emit a flat array so the values can be inspected.
        lines.append(f"/// {arr['name']} (from {source}, 未知维度,共 {len(data)} 个值)")
        lines.append(f"pub const {name}: [f64; {len(data)}] = [")
        _append_flat_values(lines, data)
        lines.append("];")
        lines.append("")
    elif len(dims) == 1:
        expected_size = dims[0]
        if len(data) != expected_size:
            print(f"警告: {name} 期望 {expected_size} 个值,实际 {len(data)} 个,跳过")
            return
        lines.append(f"/// {arr['name']}({dims[0]}) from {source}")
        lines.append(f"pub const {name}: [f64; {dims[0]}] = [")
        _append_flat_values(lines, data)
        lines.append("];")
        lines.append("")
    elif len(dims) == 2:
        nj, ni = dims
        expected_size = nj * ni
        # Same strictness as the 1-D case: excess values mean the DATA list was
        # misparsed, so truncating silently would emit wrong numbers.
        if len(data) != expected_size:
            print(f"警告: {name} 期望 {expected_size} 个值,实际 {len(data)} 个,跳过")
            return
        lines.append(f"/// {arr['name']}({nj}, {ni}) from {source}")
        lines.append("/// 已转换为 Rust 行优先格式")
        lines.append(f"pub const {name}: [[f64; {ni}]; {nj}] = [")
        # Column-major -> row-major: Rust NAME[j][i] == Fortran A(j+1, i+1).
        for j in range(nj):
            row = ",".join(_rust_f64(data[j + i * nj]) for i in range(ni))
            lines.append(f" [{row}],")
        lines.append("];")
        lines.append("")
    else:
        print(f"警告: {name} 是 {len(dims)} 维数组,暂不支持,跳过")


def generate_data_rs(all_arrays: dict[str, list[dict]]) -> str:
    """Render the contents of ``src/data.rs``.

    Args:
        all_arrays: mapping of source file name -> array records as returned by
            :func:`parse_fortran_file`.

    Returns:
        Rust source text. Each const is named ``<FILE-STEM-PREFIX>_<ARRAY>``, with
        a numeric suffix when that name is already taken. 2-D arrays keep the
        Fortran index order.
    """
    lines = [
        "//! Fortran 数据数组自动导出",
        "//!",
        "//! 由 extract_fortran_data.py 自动生成,请勿手动修改",
        "",
    ]
    used_names: set = set()

    for source, arrays in sorted(all_arrays.items()):
        valid_arrays = [a for a in arrays if a.get("data")]
        if not valid_arrays:
            continue
        lines.append(f"// ========== {source} ==========")
        lines.append("")
        prefix = _rust_prefix(source)
        for arr in valid_arrays:
            base_name = re.sub(r"[^A-Z0-9_]", "", arr["name"].upper())
            name = _unique_rust_name(f"{prefix}_{base_name}", used_names)
            _append_rust_const(lines, name, arr, source)

    return "\n".join(lines)


def _read_include_file(path: Path) -> Optional[str]:
    """Read an include file, falling back to ``tlusty/<name>``; None if neither is readable."""
    try:
        return path.read_text(encoding="utf-8", errors="replace")
    except OSError:
        pass
    fallback = Path("tlusty") / path.name
    try:
        return fallback.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return None


def parse_include_files(extracted_dir: Path) -> dict:
    """Collect global PARAMETER values from the ``*.FOR`` include files.

    Files are processed in sorted order, and later definitions override earlier
    ones. Values may reference parameters defined earlier, in the same or a
    previous file.

    Returns:
        Mapping of lower-case parameter name -> value truncated to int. The
        values are meant for array extents.
    """
    exact: dict = {}
    for for_file in sorted(Path(extracted_dir).glob("*.FOR")):
        content = _read_include_file(for_file)
        if content is None:
            continue
        exact.update(_evaluate_parameters(_preprocess_fortran(content), exact))
    return {name: int(value) for name, value in exact.items()}


def main() -> int:
    """Extract all arrays under ``tlusty/extracted`` and write ``src/data.rs``.

    Returns:
        Process exit status: 0 on success, 1 if the input directory is missing.
    """
    extracted_dir = Path("tlusty/extracted")
    if not extracted_dir.exists():
        print(f"错误: 目录不存在: {extracted_dir}")
        return 1

    # Global parameters from the include files are needed to size arrays.
    global_params = parse_include_files(extracted_dir)
    print(f"从 .FOR include 文件中提取了 {len(global_params)} 个全局参数")

    all_arrays = {}
    for fortran_file in sorted(extracted_dir.glob("*.f")):
        arrays = parse_fortran_file(fortran_file, global_params)
        if arrays:
            all_arrays[fortran_file.name] = arrays
            print(f"解析: {fortran_file.name} -> {len(arrays)} 个数组")

    total_arrays = sum(len(arrs) for arrs in all_arrays.values())
    arrays_with_data = sum(
        1 for arrs in all_arrays.values() for a in arrs if a.get("data")
    )
    arrays_2d = sum(
        1 for arrs in all_arrays.values() for a in arrs
        if len(a.get("dims", [])) == 2 and a.get("data")
    )
    print()
    print("=" * 60)
    print(f"总计: {total_arrays} 个数组, {arrays_with_data} 个有数据, {arrays_2d} 个 2D 数组")
    print("=" * 60)

    output_path = Path("src/data.rs")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Rust source files must be UTF-8 regardless of the platform locale.
    output_path.write_text(generate_data_rs(all_arrays), encoding="utf-8")
    print(f"已生成: {output_path}")
    print()
    print("在 lib.rs 或 main.rs 中添加:")
    print(" pub mod data;")
    print()
    print("使用方法:")
    print(" use crate::data::*;")
    print(" // 常量名为 <文件名前缀>_<数组名>; 2D 数组为 const,下标顺序与 Fortran 相同 (从 0 开始)")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())