import subprocess
import multiprocessing
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import csv
import shutil
import tempfile
import pandas as pd



def _read_zeo_tokens(path):
    """Return the whitespace-separated tokens of a Zeo++ output file, or [] if it is absent."""
    if not path.exists():
        return []
    # errors="replace": file names echoed into the output may not be valid UTF-8;
    # that must not make the numeric fields unreadable.
    return path.read_text(encoding="utf-8", errors="replace").split()


def _value_after(tokens, key):
    """Return the token immediately following the exact token *key*, or None."""
    try:
        idx = tokens.index(key)
    except ValueError:
        return None
    return tokens[idx + 1] if idx + 1 < len(tokens) else None


def parse_zeo_outputs(cif_name, zeo_dir):
    """Collect the Zeo++ pore descriptors of one structure into a CSV-ready dict.

    Reads ``<cif_name>.res``, ``<cif_name>.sa`` and ``<cif_name>.vol`` from
    *zeo_dir*; any of them may be missing, in which case the corresponding
    fields stay ``None``. Values are kept as the raw strings Zeo++ printed.

    Args:
        cif_name: Structure name (CIF file stem); also used as the 'Material' value.
        zeo_dir: Directory (``Path`` or ``str``) holding the Zeo++ output files.

    Returns:
        dict with keys 'Material', 'LCD', 'PLD', 'ASA_m2_g', 'Density',
        'AV_cm3_g' and 'Error' (the latter is never filled here).
    """
    zeo_dir = Path(zeo_dir)
    data = {
        'Material': cif_name,
        'LCD': None,
        'PLD': None,
        'ASA_m2_g': None,
        'Density': None,
        'AV_cm3_g': None,
        'Error': None,
    }

    # .res holds "<input file> Di Df Dif": Di (largest included sphere) is the
    # LCD and Df (largest free sphere) is the PLD. Index from the end so an
    # input path containing spaces does not shift the numeric columns.
    res_tokens = _read_zeo_tokens(zeo_dir / f"{cif_name}.res")
    if len(res_tokens) >= 4:
        data['LCD'] = res_tokens[-3]
        data['PLD'] = res_tokens[-2]

    # Exact-token lookup: a substring search for "ASA_m^2/g:" would also match
    # inside "NASA_m^2/g:" (and "AV_cm^3/g:" inside "NAV_cm^3/g:").
    sa_tokens = _read_zeo_tokens(zeo_dir / f"{cif_name}.sa")
    data['ASA_m2_g'] = _value_after(sa_tokens, "ASA_m^2/g:")

    vol_tokens = _read_zeo_tokens(zeo_dir / f"{cif_name}.vol")
    data['Density'] = _value_after(vol_tokens, "Density:")
    data['AV_cm3_g'] = _value_after(vol_tokens, "AV_cm^3/g:")

    return data
    


def process_single_cif(cif_file, zeo_temp_root, network_exe, timeout=None):
    """Run Zeo++ ``network`` on one CIF file and parse its outputs.

    Intended to run inside a worker process. Per-file problems are reported in
    the return value instead of raised, so one bad structure cannot abort the
    whole batch.

    Args:
        cif_file: Path to the input ``.cif`` file.
        zeo_temp_root: Directory under which ``<zeo_temp_root>/<cif stem>`` is
            created to hold this structure's Zeo++ output files.
        network_exe: Path (or name) of the Zeo++ ``network`` executable.
        timeout: Optional wall-clock limit in seconds for the Zeo++ run;
            ``None`` (default) waits indefinitely.

    Returns:
        tuple ``(cif_name, success, message, parsed_data)``. ``parsed_data`` is
        the dict from ``parse_zeo_outputs`` on success and ``None`` on failure.
    """
    cif_file = Path(cif_file)
    cif_name = cif_file.stem
    out_dir = Path(zeo_temp_root) / cif_name

    cmd = [
        str(network_exe),
        "-ha",
        "-res", str(out_dir / f"{cif_name}.res"),
        # Per the Zeo++ docs the three numbers after -sa / -vol are
        # channel radius, probe radius and Monte Carlo sample count.
        "-sa", "1.66", "1.66", "10000", str(out_dir / f"{cif_name}.sa"),
        "-vol", "1.66", "1.66", "10000", str(out_dir / f"{cif_name}.vol"),
        str(cif_file.resolve()),
    ]

    try:
        out_dir.mkdir(parents=True, exist_ok=True)
        subprocess.run(
            cmd,
            cwd=out_dir,
            check=True,
            capture_output=True,
            text=True,
            errors="replace",  # undecodable tool output must not crash the worker
            timeout=timeout,
        )
        parsed_data = parse_zeo_outputs(cif_name, out_dir)
    except subprocess.CalledProcessError as e:
        # stderr may be empty when the tool reports on stdout; fall back to it,
        # then to the exit code, so the failure message is never blank.
        detail = (
            (e.stderr or "").strip()
            or (e.stdout or "").strip()
            or f"exit code {e.returncode}"
        )
        return cif_name, False, f"错误: {detail}", None
    except subprocess.TimeoutExpired:
        return cif_name, False, f"错误: 超时 ({timeout} s)", None
    except OSError as e:
        # Covers a missing executable (FileNotFoundError), a non-executable
        # file (PermissionError) and I/O errors while creating/reading outputs.
        return cif_name, False, f"错误: {e}", None
    return cif_name, True, "成功", parsed_data




def _run_zeopp_batch(cif_files, zeo_temp_root, network_exe, max_workers):
    """Run Zeo++ on every CIF in a process pool; return the parsed rows of successful runs."""
    rows = []
    total = len(cif_files)
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        future_to_cif = {
            executor.submit(process_single_cif, cif_file, zeo_temp_root, network_exe): cif_file
            for cif_file in cif_files
        }
        for completed_count, future in enumerate(as_completed(future_to_cif), start=1):
            try:
                cif_name, success, message, parsed_data = future.result()
            except Exception as e:
                # An unexpected worker failure (e.g. a crashed pool process)
                # costs only this structure, not the results of the whole batch.
                cif_name, success, message, parsed_data = (
                    future_to_cif[future].stem, False, f"错误: {e}", None
                )

            if success:
                print(f"[{completed_count}/{total}] 完成：{cif_name}: {message}")
                if parsed_data:
                    rows.append(parsed_data)
            else:
                print(f"[{completed_count}/{total}] 失败：{cif_name}: {message}")
    return rows


def _result_fieldnames(rows):
    """Return CSV column names: 'Material' first, the remaining keys sorted."""
    keys = {key for row in rows for key in row}
    return ["Material"] + sorted(keys - {"Material"})


def _write_results_csv(rows, fieldnames, csv_path):
    """Write all parsed rows to *csv_path* as UTF-8 CSV."""
    with open(csv_path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def _filter_results(rows, fieldnames, csv_path, min_lcd):
    """Save rows with LCD >= *min_lcd* and all descriptors > 0 to *csv_path*; return (total, kept)."""
    numeric_cols = ["LCD", "PLD", "ASA_m2_g", "Density", "AV_cm3_g"]
    # Build the frame from the in-memory rows rather than re-reading the CSV:
    # read_csv type inference would alter names such as "0012" (-> 12) or "NA" (-> NaN).
    df = pd.DataFrame(rows, columns=fieldnames)
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")

    # Every descriptor must be present (NaN compares False) and strictly positive.
    non_zero_mask = (df[numeric_cols] > 0).all(axis=1)
    df_filtered = df[(df["LCD"] >= min_lcd) & non_zero_mask]
    df_filtered.to_csv(csv_path, index=False)
    return len(df), len(df_filtered)


def main(*, nf3_diameter=4.5, max_workers=16):
    """Run Zeo++ on every CIF in ``./cifs`` and write summary CSVs to ``./csvs``.

    Steps: locate the ``network`` executable (PATH first, then
    ``../zeo++-0.3/network``), run it on all CIFs in a process pool with
    outputs in a temporary directory, write ``zeo++_results_all.csv``, then
    write the LCD-filtered subset to ``zeo++_results_filtered_1st.csv``.

    Args:
        nf3_diameter: Minimum LCD in Å a material must have to pass the filter.
        max_workers: Number of worker processes running Zeo++ concurrently.

    Raises:
        FileNotFoundError: if the Zeo++ ``network`` executable cannot be found.
    """
    base_dir = Path(__file__).resolve().parent
    csv_dir = base_dir / "csvs"
    cif_dir = base_dir / "cifs"
    script_name = Path(__file__).stem
    default_network = base_dir.parent / "zeo++-0.3" / "network"
    network_exe = shutil.which("network") or str(default_network)

    if not Path(network_exe).exists():
        raise FileNotFoundError(f"未找到 Zeo++ network 可执行文件: {network_exe}")

    csv_dir.mkdir(exist_ok=True)
    # Sorted so the processing order does not depend on file-system listing order.
    cif_files = sorted(cif_dir.glob("*.cif"))
    print(f"共发现 {len(cif_files)} 个CIF 文件")

    with tempfile.TemporaryDirectory(prefix=f"{script_name}_zeopp_", dir=base_dir) as zeo_temp_root:
        print(f"Zeo++ 临时输出目录: {zeo_temp_root}")
        results_list = _run_zeopp_batch(cif_files, zeo_temp_root, network_exe, max_workers)
    # The temporary directory is removed on leaving the `with`, on every path.
    print("Zeo++ 临时输出文件已自动删除。")

    if not results_list:
        print("\n全部计算完成，但没有提取到任何有效数据。")
        return

    # Completion order is nondeterministic; sort so the CSVs are reproducible.
    results_list.sort(key=lambda row: row["Material"])
    fieldnames = _result_fieldnames(results_list)
    csv_file_path1 = csv_dir / "zeo++_results_all.csv"
    _write_results_csv(results_list, fieldnames, csv_file_path1)
    print(f"\n全部计算完成！结果已汇总至: {csv_file_path1}")

    # NOTE(review): the filter is on LCD (largest cavity diameter). PLD is the
    # usual accessibility criterion for a guest molecule; confirm which is intended.
    print(f"NF3 的分子直径设定为: {nf3_diameter} Å,筛选条件为 LCD >= {nf3_diameter} Å的材料。")
    csv_file_path2 = csv_dir / "zeo++_results_filtered_1st.csv"
    try:
        total, kept = _filter_results(results_list, fieldnames, csv_file_path2, nf3_diameter)
    except Exception as e:
        # Best-effort post-processing: the full results CSV is already written.
        print(f"筛选过程中发生错误: {e}")
        return
    print(f"筛选完成！初始总量: {total} 个，满足条件的材料数量: {kept} 个。")
    print(f"满足条件的材料已保存至: {csv_file_path2}")


    
# Script entry point: run the Zeo++ screening pipeline only when this file is
# executed directly. The guard also keeps ProcessPoolExecutor workers, which may
# re-import this module, from re-running main().
if __name__ == "__main__":
    main()








