import os
import shutil
import pandas as pd
from pathlib import Path

# ================== PARAMÈTRES À MODIFIER ==================

# Chemin du fichier CSV d'entrée
CSV_PATH = "/CIAN/PROCESS/SEGLVIK_EUROPAM/BROWN_UNIV/fin_whale_1/RESOURCES/RESOURCES_SEGLVIK_NEW_ANNOTS/full_manual/detections.csv"

# Nom de la colonne contenant les noms de fichiers
FILENAME_COL = "filename"

# Chemins des deux fichiers CSV de sortie
CSV_SERIE1_PATH = "/CIAN/PROCESS/SEGLVIK_EUROPAM/BROWN_UNIV/fin_whale_1/RESOURCES/RESOURCES_SEGLVIK_NEW_ANNOTS/validset/detections.csv"
CSV_SERIE2_PATH = "/CIAN/PROCESS/SEGLVIK_EUROPAM/BROWN_UNIV/fin_whale_1/RESOURCES/RESOURCES_SEGLVIK_NEW_ANNOTS/testset/detections.csv"

CSV_TRAINSET_PATH = "/CIAN/PROCESS/SEGLVIK_EUROPAM/BROWN_UNIV/fin_whale_1/RESOURCES/RESOURCES_SEGLVIK_NEW_ANNOTS/trainset/detections.csv"

# Dossier source contenant les fichiers à copier
SOURCE_DIR = "/CIAN/PROCESS/SEGLVIK_EUROPAM/BROWN_UNIV/fin_whale_1/RESOURCES/RESOURCES_SEGLVIK_MANUAL_NEW_ANNOTS/full/flacs"

# Dossiers de destination pour chaque série
DEST_DIR_SERIE1 = "/CIAN/PROCESS/SEGLVIK_EUROPAM/BROWN_UNIV/fin_whale_1/RESOURCES/RESOURCES_SEGLVIK_NEW_ANNOTS/validset/flacs"
DEST_DIR_SERIE2 = "/CIAN/PROCESS/SEGLVIK_EUROPAM/BROWN_UNIV/fin_whale_1/RESOURCES/RESOURCES_SEGLVIK_NEW_ANNOTS/testset/flacs"

# ===========================================================


def delete_from_trainset(to_remove):
    df = pd.read_csv(CSV_TRAINSET_PATH)
    print(f"Line amount before cleaning : {len(df)}")
    df_cleaned = df[~df[FILENAME_COL].isin(to_remove)]
    df_cleaned.to_csv(CSV_TRAINSET_PATH, index=False)
    print(f"Line amount after cleaning : {len(df_cleaned)}")

def main():
    # 1) Lecture du fichier CSV
    df = pd.read_csv(CSV_PATH)

    if FILENAME_COL not in df.columns:
        raise ValueError(f"La colonne '{FILENAME_COL}' n'existe pas dans le CSV.")

    # 2) Sélection de la colonne filename et .unique()
    filenames_unique = df[FILENAME_COL].dropna().unique()
    delete_from_trainset(filenames_unique)

    # 3) Division de cette série en 2 séries équitables
    n = len(filenames_unique)
    mid = n // 2  # si n est impair, la 2e série aura un élément de plus
    serie1 = filenames_unique[:mid]
    serie2 = filenames_unique[mid:]

    print(f"Nombre de filenames uniques : {n}")
    print(f"Série 1 : {len(serie1)} fichiers")
    print(f"Série 2 : {len(serie2)} fichiers")

    # 4) Création de deux nouveaux CSV en reprenant les infos initiales

    # Filtrer le dataframe original pour ne garder que les lignes dont le filename est dans chaque série
    df_serie1 = df[df[FILENAME_COL].isin(serie1)]
    df_serie2 = df[df[FILENAME_COL].isin(serie2)]

    # Sauvegarde des deux fichiers CSV
    df_serie1.to_csv(CSV_SERIE1_PATH, index=False)
    df_serie2.to_csv(CSV_SERIE2_PATH, index=False)

    print(f"CSV série 1 sauvegardé dans : {CSV_SERIE1_PATH}")
    print(f"CSV série 2 sauvegardé dans : {CSV_SERIE2_PATH}")

    # 5) Copie des fichiers de la série 1
    copy_files(serie1, SOURCE_DIR, DEST_DIR_SERIE1)

    # 6) Copie des fichiers de la série 2
    copy_files(serie2, SOURCE_DIR, DEST_DIR_SERIE2)


def copy_files(filenames, source_dir, dest_dir):
    """Copie tous les fichiers listés dans 'filenames'
    depuis source_dir vers dest_dir.
    """
    source_dir = Path(source_dir)
    dest_dir = Path(dest_dir)
    dest_dir.mkdir(parents=True, exist_ok=True)

    missing = []

    for fname in filenames:
        src_path = source_dir / str(fname)
        dst_path = dest_dir / str(fname)

        if src_path.is_file():
            # Crée l'arborescence si nécessaire (au cas où il y a des sous-dossiers)
            dst_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src_path, dst_path)
        else:
            missing.append(str(src_path))

    print(f"\nCopie terminée vers : {dest_dir}")
    print(f"Fichiers copiés : {len(filenames) - len(missing)}")
    if missing:
        print("Fichiers manquants (non trouvés dans le dossier source) :")
        for m in missing:
            print(f" - {m}")


if __name__ == "__main__":
    main()
