import subprocess
import shutil
import pandas as pd
import re
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import math
from pymatgen.core import Structure


# Directory that contains this script; every other path is derived from it.
base_dir = Path(__file__).resolve().parent

# Local RASPA definition files. Each one is copied into a simulation
# workspace only if it actually exists next to this script.
FORCE_FIELD_MIXING_RULES = base_dir.joinpath("force_field_mixing_rules.def")
PSEUDO_ATOMS = base_dir.joinpath("pseudo_atoms.def")
HELIUM_DEF = base_dir.joinpath("helium.def")

# Prefer the conda found on PATH; otherwise fall back to a miniconda3
# installation located beside this script's parent directory.
CONDA_EXE = shutil.which("conda")
if not CONDA_EXE:
    CONDA_EXE = str(base_dir.parent.joinpath("miniconda3", "bin", "conda"))

# Captures the numeric value that follows the
# "[helium] Average Widom Rosenbluth-weight:" label in a RASPA output line.
WIDOM_PATTERN = re.compile(
    r"\[helium\]\s+Average Widom Rosenbluth-weight:"
    r"\s*([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)"
)

def calculate_expansion_multipliers(cif_path, cutoff=12.0):
    """Return how many unit cells to replicate along a, b and c.

    The cell is replicated until each perpendicular width (cell volume divided
    by the area of the opposite face) is at least ``2 * cutoff``.

    Args:
        cif_path: Structure file readable by ``Structure.from_file``.
        cutoff: Interaction cutoff, in the same length unit as the lattice
            parameters.

    Returns:
        A ``(na, nb, nc)`` tuple of integers. If anything goes wrong while
        loading the structure or evaluating the widths, the best-effort
        default ``(2, 2, 2)`` is returned instead of raising.
    """
    try:
        cell = Structure.from_file(cif_path).lattice
        volume = cell.volume
        len_a, len_b, len_c = cell.abc
        ang_alpha, ang_beta, ang_gamma = (math.radians(deg) for deg in cell.angles)

        # Width perpendicular to a face = V / (area of the face spanned by the
        # two other lattice vectors).
        widths = (
            volume / (len_b * len_c * math.sin(ang_alpha)),
            volume / (len_a * len_c * math.sin(ang_beta)),
            volume / (len_a * len_b * math.sin(ang_gamma)),
        )

        required_width = 2.0 * cutoff
        na, nb, nc = (math.ceil(required_width / width) for width in widths)
        return na, nb, nc
    except Exception:
        # Deliberate best-effort fallback: the caller still gets usable values.
        return 2, 2, 2
    

def create_simulation_input(cif_name, write_dir, na, nb, nc, cutoff=12.0):
    """Write the RASPA ``simulation.input`` for a helium Widom-insertion run.

    The file requests a Monte Carlo simulation of 10000 cycles at 298.0 K with
    a single ``helium`` component whose only move is Widom insertion
    (``WidomProbability 1.0``, no molecules created).

    Args:
        cif_name: Framework name written to ``FrameworkName`` (no extension).
        write_dir: Directory (``Path``) in which ``simulation.input`` is created.
        na, nb, nc: Unit-cell replication counts written to ``UnitCells``.
        cutoff: Value written to ``CutOff``. It defaults to 12.0, the same
            default used by ``calculate_expansion_multipliers``; pass the same
            value to both so the replicated box matches the cutoff.

    Returns:
        None. The file is written as UTF-8, overwriting any existing one.
    """
    lines = (
        "",  # the original template starts with an empty line
        "SimulationType                MonteCarlo",
        "NumberOfCycles                10000",
        "PrintEvery                    1000",
        "PrintPropertiesEvery          1000",
        "",
        "Forcefield                    GenericMOFs",
        "ChargeMethod                  Ewald",
        "EwaldPrecision                1e-6",
        f"CutOff                        {cutoff}",
        "Framework                     0",
        f"FrameworkName                 {cif_name}",
        # Two trailing spaces kept so default output matches the old template.
        f"UnitCells                     {na} {nb} {nc}  ",
        "ExternalTemperature           298.0",
        "",
        "Component 0 MoleculeName              helium",
        "            MoleculeDefinition        TraPPE",
        "            WidomProbability          1.0",
        "            CreateNumberOfMolecules   0",
        "",  # yields the final newline
    )
    input_file = write_dir / "simulation.input"
    input_file.write_text("\n".join(lines), encoding="utf-8")


def normalize_material_name(material):
    """Turn a raw ``Material`` cell into a bare CIF stem.

    The value is converted to ``str``, surrounding whitespace is stripped and
    a trailing ``.cif`` (any letter case) is removed.

    Args:
        material: Raw cell value; may be ``None``, NaN, a string or a number.

    Returns:
        The cleaned name, or ``None`` when the value is missing or nothing
        usable remains (blank text, or text that is only ``.cif``).
    """
    if material is None or pd.isna(material):
        return None
    name = str(material).strip()
    if name.lower().endswith(".cif"):
        name = name[:-4]
    # An empty result (blank input or a bare ".cif") is reported as None so
    # callers see a single "unusable" value instead of both "" and None.
    return name or None


def prepare_simulation_workspace(cif_name, cif_source, raspa_dir):
    """Gather the input files one simulation needs into a single working directory.

    Any existing ``simulation_HeVF`` directory under ``raspa_dir`` is deleted
    and recreated, so each call starts from an empty workspace.

    Args:
        cif_name: Material name; the CIF is stored as ``<cif_name>.cif``.
        cif_source: Path of the CIF file to copy.
        raspa_dir: Parent directory (``Path``) of the workspace.

    Returns:
        ``(workspace_dir, copied_cif_path)``.
    """
    workspace = raspa_dir / "simulation_HeVF"
    if workspace.exists():
        shutil.rmtree(workspace)
    workspace.mkdir(parents=True, exist_ok=True)

    staged_cif = workspace / f"{cif_name}.cif"
    shutil.copy(cif_source, staged_cif)

    # Definition files are optional: only the ones present on disk are copied.
    for definition_file in (FORCE_FIELD_MIXING_RULES, PSEUDO_ATOMS, HELIUM_DEF):
        if definition_file.exists():
            shutil.copy(definition_file, workspace)

    return workspace, staged_cif

def process_raspa_he(cif_name, cifs_root_dir, output_root_dir, conda_env="raspa2_env"):
    """Run the RASPA helium void fraction calculation for a single material.

    The workspace is created under ``<output_root_dir>/<cif_name>/raspa`` and
    ``simulate`` is launched inside it through ``conda run``.

    Args:
        cif_name: Material name (CIF stem, without ``.cif``).
        cifs_root_dir: Directory (``Path``) that holds ``<cif_name>.cif``.
        output_root_dir: Root directory (``Path``) for per-material output.
        conda_env: Name of the conda environment that provides ``simulate``.

    Returns:
        ``(cif_name, success, message)``. Failures are reported through the
        tuple rather than raised, so a single bad material cannot take down
        the caller's batch.
    """
    cif_source = cifs_root_dir / f"{cif_name}.cif"
    if not cif_source.exists():
        return cif_name, False, f"未找到原 CIF 文件: {cif_source}"

    # Check for conda before touching the disk: preparing the workspace deletes
    # any previous simulation_HeVF directory, which would throw away earlier
    # results for a run that cannot start anyway.
    if not Path(CONDA_EXE).exists():
        return cif_name, False, f"未找到 conda 可执行文件: {CONDA_EXE}"

    raspa_dir = output_root_dir / cif_name / "raspa"

    try:
        # Inside the try so that filesystem errors become a failure tuple.
        raspa_dir.mkdir(parents=True, exist_ok=True)
        sim_dir, cif_dest = prepare_simulation_workspace(cif_name, cif_source, raspa_dir)

        na, nb, nc = calculate_expansion_multipliers(cif_dest)
        create_simulation_input(cif_name, sim_dir, na, nb, nc)

        subprocess.run(
            [CONDA_EXE, "run", "-n", conda_env, "simulate"],
            cwd=sim_dir,
            check=True,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        # Fall back to stdout when stderr is empty so the message is useful.
        detail = (e.stderr or e.stdout or "").strip()
        return cif_name, False, f"RASPA 运行失败: {detail}"
    except FileNotFoundError:
        return cif_name, False, f"未找到 conda/simulate 命令，请确认 {conda_env} 环境可用"
    except Exception as e:
        return cif_name, False, f"发生系统错误: {e}"
    return cif_name, True, "成功"


def extract_widom_weight(output_data_file):
    """Extract the ``[helium] Average Widom Rosenbluth-weight`` value from a RASPA output file.

    Args:
        output_data_file: ``Path`` of the output file to scan.

    Returns:
        The value on the first matching line as a ``float``, or ``None`` when
        no line matches or the file cannot be read or parsed.
    """
    try:
        content = output_data_file.read_text(encoding="utf-8", errors="ignore")
        # Match one line at a time so the pattern never spans a line break.
        hits = (WIDOM_PATTERN.search(row) for row in content.splitlines())
        first_hit = next((hit for hit in hits if hit), None)
        if first_hit is None:
            return None
        return float(first_hit.group(1))
    except Exception:
        return None


def collect_hevf_results(candidate_cifs, output_dir):
    """Collect the helium void fraction of each material into a DataFrame.

    For every material the last (sorted by name) ``output_*.data`` file under
    ``<output_dir>/<name>/raspa/simulation_HeVF/Output/System_0`` is parsed
    with ``extract_widom_weight``.

    Args:
        candidate_cifs: Iterable of material names; duplicates are ignored.
        output_dir: Root output directory (``Path``).

    Returns:
        A DataFrame with columns ``Material`` and ``HeVF``, one row per
        distinct material in first-seen order. ``HeVF`` is ``None`` when no
        output file exists or no value could be extracted. The two columns
        are present even when there are no candidates.
    """
    rows = []
    # dict.fromkeys de-duplicates while preserving order; duplicate rows would
    # multiply rows when this frame is later left-merged on "Material".
    for cif_name in dict.fromkeys(candidate_cifs):
        system0_dir = output_dir / cif_name / "raspa" / "simulation_HeVF" / "Output" / "System_0"
        output_files = sorted(system0_dir.glob("output_*.data")) if system0_dir.exists() else []
        value = extract_widom_weight(output_files[-1]) if output_files else None
        rows.append({"Material": cif_name, "HeVF": value})

    # Explicit columns keep the schema stable for an empty candidate list.
    return pd.DataFrame(rows, columns=["Material", "HeVF"])


def _material_join_keys(series):
    """Return lookup keys for a ``Material`` column: text, stripped, no ``.cif`` suffix."""
    keys = series.astype(str).str.strip()
    return keys.str.replace(r"\.cif$", "", case=False, regex=True)


def merge_hevf_into_zeo_features(filtered_csv_path, hevf_df, csv_dir):
    """Left-join the HeVF values onto the filtered zeo++ table and save it.

    The join uses a normalised key (same stripping rules as
    ``normalize_material_name``) instead of the raw ``Material`` text, so rows
    such as ``"X.cif"`` still match results stored under ``"X"``. The
    ``Material`` column itself is written back unchanged.

    Args:
        filtered_csv_path: CSV with a ``Material`` column (the base table).
        hevf_df: DataFrame with ``Material`` and ``HeVF`` columns. A frame
            without a ``Material`` column is treated as "no results".
        csv_dir: Directory (``Path``) that receives
            ``rasap_simulate_HeVF_2nd.csv``.

    Returns:
        None. The merged table is written to disk and a summary is printed.
    """
    base_df = pd.read_csv(filtered_csv_path)

    if "Material" in hevf_df.columns:
        hevf = hevf_df.copy()  # never mutate the caller's frame
    else:
        hevf = pd.DataFrame(columns=["Material", "HeVF"])

    join_key = "__hevf_join_key__"
    base_df[join_key] = _material_join_keys(base_df["Material"])
    hevf[join_key] = _material_join_keys(hevf["Material"])
    # One result row per key keeps the left merge from multiplying base rows.
    hevf = hevf.drop(columns="Material").drop_duplicates(subset=join_key)

    merged_df = base_df.merge(hevf, on=join_key, how="left").drop(columns=join_key)
    output_csv = csv_dir / "rasap_simulate_HeVF_2nd.csv"
    merged_df.to_csv(output_csv, index=False)

    hevf_count = int(merged_df["HeVF"].notna().sum()) if "HeVF" in merged_df.columns else 0
    print(f"HeVF 合并完成：{hevf_count}/{len(merged_df)} 条含 HeVF，结果文件: {output_csv}")

def main():
    """Run the helium void-fraction stage for every material in the filtered CSV.

    Steps:
      1. Read ``csvs/zeo++_results_filtered_1st.csv`` next to this script and
         normalise its ``Material`` column into CIF stems.
      2. Run ``process_raspa_he`` for each distinct material in a process pool
         and print one progress line per finished task.
      3. Collect the HeVF values and merge them back into the CSV table.

    Returns:
        None. Returns early, after printing an error, when the filtered CSV
        does not exist.
    """
    # Local name differs from the module-level ``base_dir`` to avoid shadowing.
    project_dir = Path(__file__).resolve().parent
    csv_dir = project_dir / "csvs"
    cif_dir = project_dir / "cifs"
    output_dir = project_dir / "high_through_raspa_simulate_output"

    filtered_csv_path = csv_dir / "zeo++_results_filtered_1st.csv"
    if not filtered_csv_path.exists():
        print(f"错误：未找到过滤阶段文件 {filtered_csv_path}")
        return

    df = pd.read_csv(filtered_csv_path)

    normalized = (normalize_material_name(item) for item in df["Material"].tolist())
    # Ordered de-duplication: two workers handling the same material would
    # delete and rewrite the same workspace directory at the same time.
    candidate_cifs = list(dict.fromkeys(name for name in normalized if name))
    total_candidates = len(candidate_cifs)

    print(f"检测到 {total_candidates} 个候选材料，正在启动 RASPA 氦气孔隙率并行计算...")

    with ProcessPoolExecutor(max_workers=16) as executor:
        # Map each future to its material so a crashed worker can be named.
        futures = {
            executor.submit(process_raspa_he, name, cif_dir, output_dir): name
            for name in candidate_cifs
        }

        for completed_count, future in enumerate(as_completed(futures), start=1):
            try:
                cif_name, success, msg = future.result()
            except Exception as exc:
                # A worker that died or raised must not abort the whole batch
                # before the finished results are collected.
                cif_name, success, msg = futures[future], False, f"发生系统错误: {exc}"

            if success:
                print(f"[{completed_count}/{total_candidates}] 🟢 [成功] 氦气模拟完成: {cif_name}")
            else:
                print(f"[{completed_count}/{total_candidates}] 🔴 [失败] {cif_name} : {msg}")

    print("\n✅ 所有 RASPA 氦气孔隙率计算任务分配完毕。")
    hevf_df = collect_hevf_results(candidate_cifs, output_dir)
    merge_hevf_into_zeo_features(filtered_csv_path, hevf_df, csv_dir)


# Run the pipeline only when this file is executed as a script. The guard also
# keeps main() from running again when the module is merely imported, which is
# what worker processes do under the "spawn" start method.
if __name__ == "__main__":
    main()
