# NF3与N2的亨利系数和吸附热模拟，使用 RASPA2 的 Widom 插入方法

import math
import re
import shutil
import subprocess
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pandas as pd
from pymatgen.core import Structure


# Directory containing this script; every input and output path below is resolved against it.
base_dir = Path(__file__).resolve().parent

# Definition files that prepare_simulation_workspace copies into each simulation directory.
FORCE_FIELD_MIXING_RULES = base_dir / "force_field_mixing_rules.def"
PSEUDO_ATOMS = base_dir / "pseudo_atoms.def"
N2_DEF = base_dir / "N2.def"
NF3_DEF = base_dir / "NF3.def"

# conda executable used to launch `simulate`: the one found on PATH, otherwise
# <parent of base_dir>/miniconda3/bin/conda.
CONDA_EXE = shutil.which("conda") or str(base_dir.parent / "miniconda3" / "bin" / "conda")

# Column names under which the extracted metrics are stored in the merged CSV.
N2_HENRY_COL = "N2_Henry_coefficient [mol/kg/Pa]"
NF3_HENRY_COL = "NF3_Henry_coefficient [mol/kg/Pa]"
N2_QST_COL = "N2_<U_gh>_1-<U_h>_0 [kJ/mol]"
NF3_QST_COL = "NF3_<U_gh>_1-<U_h>_0 [kJ/mol]"

# Patterns applied to the text of a RASPA output file; each captures exactly one number.
# The Henry patterns capture the number right after "Average Henry coefficient:".
# The <U_gh>_1-<U_h>_0 patterns capture the number between the last "(" on that line
# and the following "+/-" (stored under the [kJ/mol] columns above).
N2_HENRY_PATTERN = re.compile(r"\[N2\]\s+Average Henry coefficient:\s*([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)")
NF3_HENRY_PATTERN = re.compile(r"\[NF3\]\s+Average Henry coefficient:\s*([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)")
N2_QST_PATTERN = re.compile(r"\[N2\]\s+Average\s+<U_gh>_1-<U_h>_0:.*\(\s*([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)\s*\+/-")
NF3_QST_PATTERN = re.compile(r"\[NF3\]\s+Average\s+<U_gh>_1-<U_h>_0:.*\(\s*([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)\s*\+/-")


def normalize_material_name(material):
	"""Return a cleaned material name, or None when no usable name remains.

	Surrounding whitespace is removed and a trailing ".cif" extension
	(matched case-insensitively) is dropped.

	Args:
		material: Raw cell value; may be None, NaN or any object convertible to str.

	Returns:
		The cleaned name, or None if the value is missing, blank, or consists
		of nothing but the ".cif" extension.
	"""
	if material is None or pd.isna(material):
		return None
	name = str(material).strip()
	if name.lower().endswith(".cif"):
		# Strip again so "name .cif" does not keep the whitespace before the extension.
		name = name[:-4].strip()
	# An empty result (blank cell or a bare ".cif") must not become a task name.
	return name or None


def calculate_expansion_multipliers(cif_path, cutoff=12.0):
	"""Return supercell multipliers (na, nb, nc) sized for the given cutoff.

	Each multiplier is the smallest integer >= 1 that makes the perpendicular
	width of the supercell along that axis at least ``2 * cutoff``.

	Args:
		cif_path: Path of the CIF file, read with ``Structure.from_file``.
		cutoff: Interaction cutoff, in the same length unit as the lattice
			parameters returned by the structure reader.

	Returns:
		Tuple ``(na, nb, nc)``. If the structure cannot be read or its geometry
		cannot be evaluated, ``(2, 2, 2)`` is returned as a best-effort fallback
		and a ``RuntimeWarning`` is emitted, because that fallback is not
		guaranteed to satisfy the ``2 * cutoff`` requirement.
	"""
	import warnings

	try:
		lattice = Structure.from_file(cif_path).lattice
		volume = lattice.volume
		a, b, c = lattice.abc
		alpha, beta, gamma = (math.radians(angle) for angle in lattice.angles)

		# Perpendicular width along an axis = cell volume / area of the face
		# spanned by the other two lattice vectors.
		a_perp = volume / (b * c * math.sin(alpha))
		b_perp = volume / (a * c * math.sin(beta))
		c_perp = volume / (a * b * math.sin(gamma))

		min_box_length = 2.0 * cutoff
		na, nb, nc = (
			max(1, math.ceil(min_box_length / width))
			for width in (a_perp, b_perp, c_perp)
		)
		return na, nb, nc
	except Exception as exc:
		# Keep the best-effort fallback, but make it visible instead of silent.
		warnings.warn(
			f"Could not derive supercell multipliers from {cif_path} ({exc!r}); "
			"falling back to 2 x 2 x 2, which may be too small for the cutoff.",
			RuntimeWarning,
			stacklevel=2,
		)
		return 2, 2, 2


def create_simulation_input(
	cif_name,
	write_dir,
	na,
	nb,
	nc,
	hevf,
	*,
	cutoff=12.0,
	temperature=298.0,
	cycles=1000,
	init_cycles=1000,
):
	"""Write the ``simulation.input`` file for an N2 + NF3 Widom-insertion run.

	With the default keyword arguments the generated file is identical to the
	previously hard-coded template.

	Args:
		cif_name: Framework name written to ``FrameworkName``.
		write_dir: Directory (a ``Path``) in which ``simulation.input`` is created.
		na, nb, nc: Unit-cell replication counts written to ``UnitCells``.
		hevf: Helium void fraction, written with six decimal places.
		cutoff: Value written to ``CutOff``; keep it equal to the cutoff used
			when the replication counts were computed.
		temperature: Value written to ``ExternalTemperature``.
		cycles: Value written to ``NumberOfCycles``.
		init_cycles: Value written to ``NumberOfInitializationCycles``.
	"""
	# The leading newline and the tab-indented component lines are part of the
	# original template and are kept so the default output does not change.
	input_text = f"""
SimulationType                MonteCarlo
NumberOfCycles                {cycles}
NumberOfInitializationCycles  {init_cycles}
PrintEvery                    0
RestartFile                   no

Forcefield                    GenericMOFs
ChargeMethod                  Ewald
CutOff                        {cutoff}

Framework                     0
FrameworkName                 {cif_name}
UnitCells                     {na} {nb} {nc}
ExternalTemperature           {temperature}
HeliumVoidFraction            {hevf:.6f}
UseChargesFromCIFFile         yes

Component 0 MoleculeName              N2
			MoleculeDefinition        TraPPE
			WidomProbability          1.0
			CreateNumberOfMolecules   0

Component 1 MoleculeName              NF3
			MoleculeDefinition        local
			WidomProbability          1.0
			CreateNumberOfMolecules   0
"""
	(write_dir / "simulation.input").write_text(input_text, encoding="utf-8")


def prepare_simulation_workspace(cif_name, cif_source, sim_root):
	"""Create a fresh, self-contained directory for one simulation.

	The directory ``<sim_root>/<cif_name>/raspa/simulation_NF3-N2_KH-Qst`` is
	recreated from scratch and filled with the CIF file and the required
	definition files.

	Args:
		cif_name: Material name; also the stem of the copied CIF file.
		cif_source: Path of the CIF file to copy.
		sim_root: Root directory (a ``Path``) under which the workspace is created.

	Returns:
		Tuple ``(sim_dir, cif_dest)``: the workspace directory and the path of
		the copied CIF file inside it.

	Raises:
		FileNotFoundError: If one of the required definition files is missing.
	"""
	required_files = [FORCE_FIELD_MIXING_RULES, PSEUDO_ATOMS, N2_DEF, NF3_DEF]
	# Validate before touching the disk so a missing definition file does not
	# cost us the results of a previous run in the same directory.
	for required in required_files:
		if not required.exists():
			raise FileNotFoundError(f"缺少必需文件: {required}")

	sim_dir = sim_root / cif_name / "raspa" / "simulation_NF3-N2_KH-Qst"
	if sim_dir.exists():
		shutil.rmtree(sim_dir)
	sim_dir.mkdir(parents=True, exist_ok=True)

	cif_dest = sim_dir / f"{cif_name}.cif"
	shutil.copy(cif_source, cif_dest)
	for required in required_files:
		shutil.copy(required, sim_dir)

	return sim_dir, cif_dest


def process_single_material(cif_name, hevf, cifs_root_dir, output_root_dir):
	"""Run the Widom-insertion simulation for one material and report the outcome.

	Args:
		cif_name: Material name; ``<cifs_root_dir>/<cif_name>.cif`` must exist.
		hevf: Helium void fraction for the material; must be a finite number > 0.
		cifs_root_dir: Directory (a ``Path``) holding the CIF files.
		output_root_dir: Root directory (a ``Path``) for simulation workspaces.

	Returns:
		Dict with keys ``Material``, ``HeVF``, ``Status`` and ``Message``.
		``Status`` is one of ``ok``, ``missing_cif``, ``invalid_hevf``,
		``missing_conda``, ``simulate_failed`` or ``system_error``. Problems are
		reported through the dict rather than raised.
	"""

	def result(status, message):
		# Every exit path reports the same four fields.
		return {"Material": cif_name, "HeVF": hevf, "Status": status, "Message": message}

	cif_source = cifs_root_dir / f"{cif_name}.cif"
	if not cif_source.exists():
		return result("missing_cif", f"未找到 CIF 文件: {cif_source}")

	# Convert once, inside a guard: a non-numeric cell must yield a status row
	# instead of an exception that escapes the worker.
	try:
		hevf_value = None if pd.isna(hevf) else float(hevf)
	except (TypeError, ValueError):
		return result("invalid_hevf", f"HeVF 不是有效数值: {hevf!r}")
	if hevf_value is None or hevf_value <= 0:
		return result("invalid_hevf", "HeVF 缺失或 <= 0，跳过模拟")
	if not math.isfinite(hevf_value):
		# e.g. the strings "nan"/"inf": float() accepts them, the input file must not.
		return result("invalid_hevf", f"HeVF 不是有效数值: {hevf!r}")

	if not Path(CONDA_EXE).exists():
		return result("missing_conda", f"未找到 conda 可执行文件: {CONDA_EXE}")

	try:
		sim_dir, cif_dest = prepare_simulation_workspace(cif_name, cif_source, output_root_dir)
		na, nb, nc = calculate_expansion_multipliers(cif_dest, cutoff=12.0)
		create_simulation_input(cif_name, sim_dir, na, nb, nc, hevf_value)

		# Run `simulate` from the conda environment inside the workspace directory.
		subprocess.run(
			[CONDA_EXE, "run", "-n", "raspa2_env", "simulate"],
			cwd=sim_dir,
			check=True,
			capture_output=True,
			text=True,
		)
		return result("ok", "成功")
	except subprocess.CalledProcessError as e:
		return result("simulate_failed", (e.stderr or e.stdout or str(e)).strip())
	except Exception as e:
		return result("system_error", str(e))


def extract_nf3_n2_metrics_from_output(output_data_file):
	"""Read the Henry coefficients and <U_gh>_1-<U_h>_0 values for N2 and NF3.

	Args:
		output_data_file: Path of a RASPA output file; read as UTF-8 with
			undecodable bytes ignored.

	Returns:
		Dict keyed by the four metric column names. Each value is the float
		captured by the first match of the corresponding pattern, or None when
		the pattern does not occur in the file.
	"""
	content = output_data_file.read_text(encoding="utf-8", errors="ignore")

	column_patterns = (
		(N2_HENRY_COL, N2_HENRY_PATTERN),
		(NF3_HENRY_COL, NF3_HENRY_PATTERN),
		(N2_QST_COL, N2_QST_PATTERN),
		(NF3_QST_COL, NF3_QST_PATTERN),
	)

	metrics = {}
	for column, pattern in column_patterns:
		found = pattern.search(content)
		metrics[column] = None if found is None else float(found.group(1))
	return metrics


def collect_nf3_n2_metrics(materials, output_root_dir):
	"""Gather the extracted metrics of every material into one DataFrame.

	For each material the ``Output/System_0`` directory of its simulation
	workspace is searched for ``output_*.data`` files; the last one in sorted
	order is parsed. Materials without such a file keep None in every metric
	column.

	Args:
		materials: Iterable of material names.
		output_root_dir: Root directory (a ``Path``) of the simulation workspaces.

	Returns:
		DataFrame with a ``Material`` column followed by the four metric columns,
		one row per entry in ``materials``.
	"""
	metric_columns = (N2_HENRY_COL, NF3_HENRY_COL, N2_QST_COL, NF3_QST_COL)
	records = []
	for cif_name in materials:
		# Start with all metrics empty; overwrite only if an output file is found.
		record = {"Material": cif_name}
		record.update(dict.fromkeys(metric_columns))

		system0_dir = output_root_dir / cif_name / "raspa" / "simulation_NF3-N2_KH-Qst" / "Output" / "System_0"
		if system0_dir.exists():
			candidates = sorted(system0_dir.glob("output_*.data"))
			if candidates:
				record.update(extract_nf3_n2_metrics_from_output(candidates[-1]))

		records.append(record)

	return pd.DataFrame(records)


def merge_metrics_into_new_csv(base_df, metrics_df, output_csv):
	"""Join the metric columns onto the base table and write the result as CSV.

	Metric columns already present in ``base_df`` are replaced by those from
	``metrics_df``. The join is a left join on ``Material``, so every row of
	``base_df`` is kept. A short summary is printed afterwards.

	Args:
		base_df: Table with at least a ``Material`` column; not modified in place.
		metrics_df: Table with a ``Material`` column plus the metric columns.
		output_csv: Destination path of the merged CSV (written without index).
	"""
	metric_columns = [N2_HENRY_COL, NF3_HENRY_COL, N2_QST_COL, NF3_QST_COL]
	base_df = base_df.drop(columns=metric_columns, errors="ignore")

	# A left join against repeated right-hand keys multiplies the matching
	# base rows, so keep a single metrics row per material.
	metrics_df = metrics_df.drop_duplicates(subset="Material", keep="last")

	merged_df = base_df.merge(metrics_df, on="Material", how="left")
	merged_df.to_csv(output_csv, index=False)

	found_n2 = int(merged_df[N2_HENRY_COL].notna().sum()) if N2_HENRY_COL in merged_df.columns else 0
	found_nf3 = int(merged_df[NF3_HENRY_COL].notna().sum()) if NF3_HENRY_COL in merged_df.columns else 0
	print(f"已导出新 CSV: {output_csv}")
	print(f"N2 Henry 提取成功: {found_n2}/{len(merged_df)}")
	print(f"NF3 Henry 提取成功: {found_nf3}/{len(merged_df)}")


def main():
	"""Run the N2 + NF3 Widom simulations for every material in the input CSV.

	Reads ``csvs/rasap_simulate_HeVF_2nd.csv`` (columns ``Material`` and
	``HeVF`` are required), runs one simulation per distinct material in a
	process pool, writes a per-material run log to
	``csvs/raspa_N2_NF3_run_log.csv`` and the input table extended with the
	extracted metrics to ``csvs/rasap_simulate_KH_Qst_3rd.csv``. All paths are
	relative to the directory of this script.
	"""
	csv_dir = base_dir / "csvs"
	cifs_dir = base_dir / "cifs"
	output_dir = base_dir / "high_through_raspa_simulate_output"
	input_csv = csv_dir / "rasap_simulate_HeVF_2nd.csv"
	merged_output_csv = csv_dir / "rasap_simulate_KH_Qst_3rd.csv"

	if not input_csv.exists():
		print(f"错误：未找到输入文件 {input_csv}")
		return

	df = pd.read_csv(input_csv)
	if "Material" not in df.columns or "HeVF" not in df.columns:
		print("错误：输入 CSV 必须包含 Material 和 HeVF 列")
		return

	df["Material"] = df["Material"].apply(normalize_material_name)
	df = df[df["Material"].notna()].copy()

	# One task per material: each material maps to a single workspace directory,
	# so duplicates would make two workers wipe and write the same directory at
	# the same time. The first occurrence supplies the HeVF value.
	unique_df = df.drop_duplicates(subset="Material", keep="first")
	tasks = list(unique_df[["Material", "HeVF"]].itertuples(index=False, name=None))
	total = len(tasks)
	if total == 0:
		print("没有可运行的材料记录。")
		return
	if total < len(df):
		print(f"注意：跳过 {len(df) - total} 条重复的材料记录")

	print(f"共 {total} 个材料，开始 N2+NF3 Widom 模拟...")

	result_rows = []
	with ProcessPoolExecutor(max_workers=16) as executor:
		# Remember which task each future belongs to so a crashed worker can
		# still be attributed to its material.
		futures = {
			executor.submit(process_single_material, m, h, cifs_dir, output_dir): (m, h)
			for m, h in tasks
		}

		for completed, future in enumerate(as_completed(futures), start=1):
			material, hevf = futures[future]
			try:
				row = future.result()
			except Exception as e:
				# Top-level boundary: record the failure and keep the batch
				# (and its run log) alive.
				row = {
					"Material": material,
					"HeVF": hevf,
					"Status": "system_error",
					"Message": str(e),
				}
			result_rows.append(row)
			if row["Status"] == "ok":
				print(f"[{completed}/{total}] 成功: {row['Material']}")
			else:
				print(f"[{completed}/{total}] 失败: {row['Material']} | {row['Status']}")

	result_df = pd.DataFrame(result_rows)
	run_log_csv = csv_dir / "raspa_N2_NF3_run_log.csv"
	result_df.to_csv(run_log_csv, index=False)

	metrics_df = collect_nf3_n2_metrics([m for m, _ in tasks], output_dir)
	merge_metrics_into_new_csv(df, metrics_df, merged_output_csv)

	ok_count = int((result_df["Status"] == "ok").sum()) if not result_df.empty else 0
	print(f"\n任务完成：{ok_count}/{total} 成功")
	print(f"运行日志已保存到: {run_log_csv}")

if __name__ == "__main__":
	main()
