"""NF3/N2(0.1/0.9) 竞争吸附自动化脚本（RASPA2 GCMC）。"""

import math
import re
import shutil
import subprocess
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import pandas as pd
from pymatgen.core import Structure


# Directory that contains this script; every path below is resolved relative to it.
base_dir = Path(__file__).resolve().parent

# Force-field and molecule definition files expected next to this script.
# They are copied into each simulation directory before a run.
FORCE_FIELD_MIXING_RULES = base_dir / "force_field_mixing_rules.def"
PSEUDO_ATOMS = base_dir / "pseudo_atoms.def"
N2_DEF = base_dir / "N2.def"
NF3_DEF = base_dir / "NF3.def"

# Path to the conda launcher: whatever is on PATH, otherwise a miniconda3
# installation located beside the parent directory of this script.
CONDA_EXE = shutil.which("conda") or str(base_dir.parent / "miniconda3" / "bin" / "conda")

# Thermodynamic state written into simulation.input (ExternalTemperature /
# ExternalPressure) and the mixture composition (MolFraction of each component).
TEMPERATURE_K = 298.0
PRESSURES_PA = [100000]
NF3_MOL_FRACTION = 0.10
N2_MOL_FRACTION = 0.90

# Column names used for the extracted metrics in the exported CSV files.
N2_ABS_COL = "N2_loading_absolute [mol/kg]"
NF3_ABS_COL = "NF3_loading_absolute [mol/kg]"
N2_EXCESS_COL = "N2_loading_excess [mol/kg]"
NF3_EXCESS_COL = "NF3_loading_excess [mol/kg]"
SELECTIVITY_COL = "NF3/N2_selectivity"


def normalize_material_name(material):
	"""Return a cleaned material identifier, or None when nothing usable is given.

	The value is converted to ``str`` and stripped of surrounding whitespace.
	A trailing ``.cif`` extension (matched case-insensitively) is removed.
	``None``, pandas-missing values and blank strings all yield ``None``.
	"""
	is_missing = material is None or pd.isna(material)
	if is_missing:
		return None

	cleaned = str(material).strip()
	if cleaned == "":
		return None

	# Only the last four characters are dropped, so the stem keeps its case.
	has_cif_suffix = cleaned.lower().endswith(".cif")
	return cleaned[:-4] if has_cif_suffix else cleaned


def calculate_expansion_multipliers(cif_path, cutoff=12.0):
	"""Return supercell repeat counts ``(na, nb, nc)`` for the structure in ``cif_path``.

	Each count is the smallest integer (at least 1) that makes the
	perpendicular width of the supercell along that axis reach ``2 * cutoff``.
	The perpendicular width is the cell volume divided by the area of the face
	spanned by the other two lattice vectors.

	This is best effort: if anything fails (unreadable file, degenerate cell,
	...) the function returns ``(2, 2, 2)`` instead of raising.
	"""
	try:
		cell = Structure.from_file(cif_path).lattice
		volume = cell.volume
		len_a, len_b, len_c = cell.abc
		alpha, beta, gamma = (math.radians(angle) for angle in cell.angles)

		# Face areas opposite to a, b and c respectively.
		face_areas = (
			len_b * len_c * math.sin(alpha),
			len_a * len_c * math.sin(beta),
			len_a * len_b * math.sin(gamma),
		)
		widths = [volume / area for area in face_areas]

		required_width = 2.0 * cutoff
		repeats = tuple(max(1, math.ceil(required_width / width)) for width in widths)
		return repeats
	except Exception:
		return 2, 2, 2


def create_simulation_input(cif_name, write_dir, na, nb, nc, hevf, pressure_pa):
	"""Write the RASPA ``simulation.input`` file for one NF3/N2 mixture run.

	Args:
		cif_name: Framework name written to ``FrameworkName``.
		write_dir: Directory (path object) that receives ``simulation.input``.
		na, nb, nc: Repeat counts written to ``UnitCells``.
		hevf: Helium void fraction, written with six decimals.
		pressure_pa: Value written to ``ExternalPressure``.

	The temperature and both mole fractions come from the module-level
	constants. The template below is emitted verbatim (including its leading
	blank line and tab-indented component keys), encoded as UTF-8.
	"""
	input_text = f"""
SimulationType                MonteCarlo
NumberOfCycles                50000
NumberOfInitializationCycles  50000
PrintEvery                    0
RestartFile                   no

Forcefield                    GenericMOFs
ChargeMethod                  Ewald
EwaldPrecision                1e-6
CutOff                        12.0

Framework                     0
FrameworkName                 {cif_name}
UnitCells                     {na} {nb} {nc}
ExternalTemperature           {TEMPERATURE_K}
HeliumVoidFraction            {hevf:.6f}
ExternalPressure              {pressure_pa}
UseChargesFromCIFFile         yes

Component 0 MoleculeName              NF3
			MoleculeDefinition        local
			MolFraction               {NF3_MOL_FRACTION:.2f}
			TranslationProbability    1.0
			RotationProbability       1.0
			ReinsertionProbability    1.0
			SwapProbability           1.0
			IdentityChangeProbability 1.0
			  NumberOfIdentityChanges 2
			  IdentityChangesList     0 1
			CreateNumberOfMolecules   0

Component 1 MoleculeName              N2
			MoleculeDefinition        local
			MolFraction               {N2_MOL_FRACTION:.2f}
			TranslationProbability    1.0
			RotationProbability       1.0
			ReinsertionProbability    1.0
			SwapProbability           1.0
			IdentityChangeProbability 1.0
			  NumberOfIdentityChanges 2
			  IdentityChangesList     0 1
			CreateNumberOfMolecules   0
"""
	(write_dir / "simulation.input").write_text(input_text, encoding="utf-8")


def prepare_simulation_workspace(cif_name, cif_source, sim_root, pressure_pa):
	"""Create a fresh, self-contained directory for one simulation run.

	The directory is ``<sim_root>/<cif_name>/raspa/simulation_NF3-N2_competitive_<pressure>Pa``.
	If it already exists it is removed first, so outputs from an earlier run
	are never left behind. The CIF and the four definition files are copied in.

	Returns:
		``(sim_dir, cif_dest)``: the run directory and the copied CIF path.

	Raises:
		FileNotFoundError: If one of the required definition files is missing.
	"""
	run_label = f"simulation_NF3-N2_competitive_{pressure_pa}Pa"
	sim_dir = sim_root / cif_name / "raspa" / run_label

	if sim_dir.exists():
		shutil.rmtree(sim_dir)
	sim_dir.mkdir(parents=True, exist_ok=True)

	cif_dest = sim_dir / f"{cif_name}.cif"
	shutil.copy(cif_source, cif_dest)

	required_files = (FORCE_FIELD_MIXING_RULES, PSEUDO_ATOMS, N2_DEF, NF3_DEF)
	for required in required_files:
		if required.exists():
			shutil.copy(required, sim_dir)
			continue
		raise FileNotFoundError(f"缺少必需文件: {required}")

	return sim_dir, cif_dest


def _task_result(cif_name, hevf, pressure_pa, status, message):
	"""Build one run-log row describing the outcome of a single task."""
	return {
		"Material": cif_name,
		"HeVF": hevf,
		"Pressure_Pa": pressure_pa,
		"Status": status,
		"Message": message,
	}


def process_single_task(cif_name, hevf, pressure_pa, cifs_root_dir, output_root_dir):
	"""Run one GCMC simulation and report the outcome as a run-log row.

	Args:
		cif_name: Material name; the CIF is looked up as ``<cifs_root_dir>/<cif_name>.cif``.
		hevf: Helium void fraction as read from the input table (may be missing
			or non-numeric).
		pressure_pa: External pressure of this run.
		cifs_root_dir: Directory (path object) holding the CIF files.
		output_root_dir: Root directory (path object) for simulation workspaces.

	Returns:
		A dict with keys ``Material``, ``HeVF``, ``Pressure_Pa``, ``Status`` and
		``Message``. ``Status`` is one of ``ok``, ``missing_cif``,
		``invalid_hevf``, ``missing_conda``, ``simulate_failed`` or
		``system_error``. The function never raises for these situations, so a
		single bad row cannot abort the whole batch.
	"""
	cif_source = cifs_root_dir / f"{cif_name}.cif"
	if not cif_source.exists():
		return _task_result(
			cif_name, hevf, pressure_pa, "missing_cif", f"未找到 CIF 文件: {cif_source}"
		)

	# Convert HeVF up front and inside a guard: a non-numeric cell must become
	# an "invalid_hevf" row instead of an exception escaping the worker.
	try:
		hevf_value = None if pd.isna(hevf) else float(hevf)
	except (TypeError, ValueError):
		return _task_result(
			cif_name, hevf, pressure_pa, "invalid_hevf", f"HeVF 无法解析为数值: {hevf!r}"
		)
	# ``not (x > 0)`` also rejects NaN produced by e.g. the string "nan".
	if hevf_value is None or not (hevf_value > 0):
		return _task_result(
			cif_name, hevf, pressure_pa, "invalid_hevf", "HeVF 缺失或 <= 0，跳过模拟"
		)

	if not Path(CONDA_EXE).exists():
		return _task_result(
			cif_name, hevf, pressure_pa, "missing_conda", f"未找到 conda 可执行文件: {CONDA_EXE}"
		)

	try:
		sim_dir, cif_dest = prepare_simulation_workspace(cif_name, cif_source, output_root_dir, pressure_pa)
		na, nb, nc = calculate_expansion_multipliers(cif_dest, cutoff=12.0)
		create_simulation_input(cif_name, sim_dir, na, nb, nc, hevf_value, pressure_pa)

		subprocess.run(
			[CONDA_EXE, "run", "-n", "raspa2_env", "simulate"],
			cwd=sim_dir,
			check=True,
			capture_output=True,
			text=True,
		)
	except subprocess.CalledProcessError as e:
		# Prefer the captured stderr, then stdout, then the generic description.
		return _task_result(
			cif_name, hevf, pressure_pa, "simulate_failed", (e.stderr or e.stdout or str(e)).strip()
		)
	except Exception as e:
		return _task_result(cif_name, hevf, pressure_pa, "system_error", str(e))

	return _task_result(cif_name, hevf, pressure_pa, "ok", "成功")


def extract_component_metric(text, component_name, metric_type):
	"""Return a component's average loading in mol/kg framework, or None.

	Args:
		text: Full contents of a RASPA output file.
		component_name: Molecule name as it appears in ``Component <i> [name]``.
		metric_type: Literal word between ``Average loading`` and the unit,
			e.g. ``"absolute"`` or ``"excess"``.

	The value is looked up only inside the component's own section: from a
	``Component <i> [name]`` header up to the next ``Component <j> [...]``
	header (or the end of the text). An unbounded search would start at the
	first mention of the component and could run into another component's
	block, returning that component's value instead. Sections of the requested
	component are tried in order; the first one containing the metric wins.
	"""
	header_pattern = re.compile(r"Component\s+\d+\s+\[([^\]]*)\]")
	value_pattern = re.compile(
		rf"Average loading {re.escape(metric_type)} \[mol/kg framework\]\s+"
		r"([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)"
	)

	headers = list(header_pattern.finditer(text))
	for index, header in enumerate(headers):
		if header.group(1) != component_name:
			continue
		# The section ends where the next component header begins.
		section_end = headers[index + 1].start() if index + 1 < len(headers) else len(text)
		match = value_pattern.search(text, header.end(), section_end)
		if match:
			return float(match.group(1))
	return None


def extract_competitive_metrics_from_output(output_data_file):
	"""Parse one RASPA output file into the mixture metrics.

	Reads the file as UTF-8 (undecodable bytes are ignored) and extracts the
	absolute and excess loadings of N2 and NF3. The NF3/N2 selectivity is the
	ratio of absolute loadings, each divided by its gas-phase mole fraction.
	It is left as ``None`` unless both absolute loadings were found and the
	N2 loading is positive.

	Returns:
		A dict keyed by the five metric column names; missing values are ``None``.
	"""
	content = output_data_file.read_text(encoding="utf-8", errors="ignore")

	def loading(component, kind):
		return extract_component_metric(content, component, kind)

	n2_absolute = loading("N2", "absolute")
	nf3_absolute = loading("NF3", "absolute")
	n2_excess = loading("N2", "excess")
	nf3_excess = loading("NF3", "excess")

	ratio = None
	have_both = nf3_absolute is not None and n2_absolute is not None
	if have_both and n2_absolute > 0:
		nf3_term = nf3_absolute / NF3_MOL_FRACTION
		n2_term = n2_absolute / N2_MOL_FRACTION
		ratio = nf3_term / n2_term

	metrics = {}
	metrics[N2_ABS_COL] = n2_absolute
	metrics[NF3_ABS_COL] = nf3_absolute
	metrics[N2_EXCESS_COL] = n2_excess
	metrics[NF3_EXCESS_COL] = nf3_excess
	metrics[SELECTIVITY_COL] = ratio
	return metrics


def collect_competitive_metrics(tasks, output_root_dir):
	"""Gather the extracted metrics of every task into one DataFrame.

	Args:
		tasks: Iterable of ``(material, hevf, pressure_pa)`` tuples; the middle
			element is not used here.
		output_root_dir: Root directory (path object) of the simulation workspaces.

	For each task the ``Output/System_0`` folder of its workspace is searched
	for ``output_*.data`` files. The last one in sorted order is parsed. Tasks
	without any output file still produce a row, with all metrics set to ``None``.
	"""
	records = []
	for material, _hevf, pressure in tasks:
		run_label = f"simulation_NF3-N2_competitive_{pressure}Pa"
		result_dir = output_root_dir / material / "raspa" / run_label / "Output" / "System_0"

		candidates = []
		if result_dir.exists():
			candidates = sorted(result_dir.glob("output_*.data"))

		record = {
			"Material": material,
			"Pressure_Pa": pressure,
			**dict.fromkeys(
				(N2_ABS_COL, NF3_ABS_COL, N2_EXCESS_COL, NF3_EXCESS_COL, SELECTIVITY_COL)
			),
		}
		if candidates:
			newest = candidates[-1]
			record.update(extract_competitive_metrics_from_output(newest))
		records.append(record)

	return pd.DataFrame(records)


def merge_metrics_into_new_csv(base_df, metrics_df, output_csv):
	"""Left-join the extracted metrics onto the input table and write a CSV.

	Args:
		base_df: Input table; must contain a ``Material`` column. It is not
			modified in place.
		metrics_df: Long-format metrics with ``Material`` and ``Pressure_Pa``
			columns plus one column per metric.
		output_csv: Destination path of the merged CSV.

	Behaviour:
		* If ``metrics_df`` is empty, ``base_df`` is exported unchanged.
		* With a single configured pressure the metric columns keep their names.
		* With several pressures the metrics are pivoted to wide format and each
		  column is suffixed with ``_<pressure>Pa``.
		* Duplicate keys in ``metrics_df`` are collapsed to their first row.
		  Otherwise the join would multiply rows of ``base_df``, and the pivot
		  would raise on a non-unique index.
		* Columns of ``base_df`` that collide with incoming metric columns are
		  dropped first, so a rerun replaces old values instead of producing
		  ``_x``/``_y`` suffixed pairs.
	"""
	if metrics_df.empty:
		base_df.to_csv(output_csv, index=False)
		print(f"未提取到竞争吸附结果，已导出原始表: {output_csv}")
		return

	if len(PRESSURES_PA) == 1:
		incoming = metrics_df.drop(columns=["Pressure_Pa"]).drop_duplicates(subset="Material")
	else:
		wide_df = metrics_df.drop_duplicates(subset=["Material", "Pressure_Pa"]).copy()
		wide_df["Pressure_Pa"] = wide_df["Pressure_Pa"].astype(int).astype(str) + "Pa"
		wide_df = wide_df.set_index(["Material", "Pressure_Pa"]).unstack("Pressure_Pa")
		wide_df.columns = [f"{name}_{pressure}" for name, pressure in wide_df.columns]
		incoming = wide_df.reset_index()

	stale_columns = [
		col for col in incoming.columns if col != "Material" and col in base_df.columns
	]
	if stale_columns:
		base_df = base_df.drop(columns=stale_columns)
	merged_df = base_df.merge(incoming, on="Material", how="left")

	merged_df.to_csv(output_csv, index=False)

	# In wide format the NF3 column is suffixed per pressure, so match by prefix
	# and count rows that have a value for at least one pressure.
	nf3_columns = [
		col
		for col in merged_df.columns
		if col == NF3_ABS_COL or str(col).startswith(f"{NF3_ABS_COL}_")
	]
	found_nf3 = int(merged_df[nf3_columns].notna().any(axis=1).sum()) if nf3_columns else 0
	print(f"已导出新 CSV: {output_csv}")
	print(f"NF3 绝对吸附量提取成功: {found_nf3}/{len(merged_df)}")


def main():
	"""Run the full screening pipeline.

	Steps:
		1. Load the input CSV (requires ``Material`` and ``HeVF`` columns) and
		   normalise the material names.
		2. Build one task per unique (material, pressure) pair.
		3. Run the tasks in a process pool and write the run log.
		4. Collect the metrics from the simulation outputs, write them, and
		   merge them onto the input table.

	Problems with the input file are reported on stdout and end the function
	early; nothing is raised.
	"""
	csv_dir = base_dir / "csvs"
	cifs_dir = base_dir / "cifs"
	output_dir = base_dir / "high_through_raspa_simulate_output"
	input_csv = csv_dir / "rasap_simulate_KH_Qst_3rd.csv"
	run_log_csv = csv_dir / "raspa_NF3_N2_competitive_run_log.csv"
	metrics_csv = csv_dir / "raspa_NF3_N2_competitive_metrics.csv"
	merged_output_csv = csv_dir / "rasap_simulate_competitive_4th.csv"

	if not input_csv.exists():
		print(f"错误：未找到输入文件 {input_csv}")
		return

	df = pd.read_csv(input_csv)
	if "Material" not in df.columns or "HeVF" not in df.columns:
		print("错误：输入 CSV 必须包含 Material 和 HeVF 列")
		return

	df["Material"] = df["Material"].apply(normalize_material_name)
	df = df[df["Material"].notna()].copy()

	# The workspace of a task is determined by (material, pressure) alone and is
	# wiped when the task starts. Two tasks with the same key running in
	# parallel would delete each other's files, so only the first is kept.
	tasks = []
	seen_keys = set()
	skipped = 0
	for material, hevf in df[["Material", "HeVF"]].itertuples(index=False, name=None):
		for pressure in PRESSURES_PA:
			key = (material, int(pressure))
			if key in seen_keys:
				skipped += 1
				continue
			seen_keys.add(key)
			tasks.append((material, hevf, int(pressure)))

	total = len(tasks)
	if total == 0:
		print("没有可运行的材料记录。")
		return

	if skipped:
		print(f"已跳过 {skipped} 个重复的材料/压力任务。")
	print(f"共 {total} 个任务，开始 NF3/N2={NF3_MOL_FRACTION:.2f}/{N2_MOL_FRACTION:.2f} 竞争吸附模拟...")

	result_rows = []
	with ProcessPoolExecutor(max_workers=16) as executor:
		future_to_task = {
			executor.submit(process_single_task, m, h, p, cifs_dir, output_dir): (m, h, p)
			for m, h, p in tasks
		}

		for completed, future in enumerate(as_completed(future_to_task), start=1):
			try:
				row = future.result()
			except Exception as e:
				# A worker that raised (or a broken pool) must not abort the
				# batch before the run log is written; record it as a failure.
				m, h, p = future_to_task[future]
				row = {
					"Material": m,
					"HeVF": h,
					"Pressure_Pa": p,
					"Status": "system_error",
					"Message": str(e),
				}
			result_rows.append(row)
			if row["Status"] == "ok":
				print(f"[{completed}/{total}] 成功: {row['Material']} @ {row['Pressure_Pa']} Pa")
			else:
				print(f"[{completed}/{total}] 失败: {row['Material']} @ {row['Pressure_Pa']} Pa | {row['Status']}")

	result_df = pd.DataFrame(result_rows)
	result_df.to_csv(run_log_csv, index=False)

	metrics_df = collect_competitive_metrics(tasks, output_dir)
	metrics_df.to_csv(metrics_csv, index=False)
	merge_metrics_into_new_csv(df, metrics_df, merged_output_csv)

	ok_count = int((result_df["Status"] == "ok").sum()) if not result_df.empty else 0
	print(f"\n任务完成：{ok_count}/{total} 成功")
	print(f"运行日志已保存到: {run_log_csv}")
	print(f"指标明细已保存到: {metrics_csv}")

if __name__ == "__main__":
	main()
