import yaml
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
import shutil
import argparse

def arg_directory(path):
    """Validate a command-line argument that must name an existing directory.

    Intended to be used as an argparse ``type=`` callable.

    Args:
        path: Candidate directory path, as typed on the command line.

    Returns:
        ``path`` unchanged when it points to an existing directory.

    Raises:
        argparse.ArgumentTypeError: If ``path`` is not an existing directory.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f'`{path}` is not a valid path')
    return path

def create_directory_if_not_exists(directory):
    """Create ``directory`` unless it already exists.

    Missing parent directories are created as well, and the call is safe to
    repeat: an already existing directory is left untouched.

    Args:
        directory: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory`` exists but is not a directory.
    """
    # makedirs(exist_ok=True) replaces the check-then-create sequence, which
    # could fail if the directory appeared between the check and the mkdir,
    # and which could not create nested paths in one call.
    os.makedirs(directory, exist_ok=True)

def copy_files_to_directory(file_list, source_dir, destination_dir, suffix):
    """Copy ``<name>.<suffix>`` from one directory to another for each name.

    Args:
        file_list: Iterable of file names without their extension.
        source_dir: Directory the files are read from.
        destination_dir: Existing directory the files are copied into.
        suffix: Extension (without the leading dot) appended to every name.
    """
    for stem in file_list:
        # Source and destination share the same base name.
        basename = f'{stem}.{suffix}'
        shutil.copy2(
            os.path.join(source_dir, basename),
            os.path.join(destination_dir, basename),
        )

def _pick_file(rows, species):
    """Randomly pick the file of one row of `rows` labelled `species` (None if there is none)."""
    candidates = rows[rows.espece == species]
    if candidates.empty:
        return None
    # Sampling a row (not a file) weights files by their number of
    # annotations of this species.
    return candidates.sample(1).file.iloc[0]


def _take_file(rows, file_name, counts):
    """Remove `file_name` from `rows`; return (its rows, the remaining rows, updated counts)."""
    mask = rows.file == file_name
    taken = rows[mask]
    # Count every species present in the file, aligned on the full class
    # list so that species absent from this file keep their running total.
    per_species = taken.espece.value_counts().reindex(counts.index, fill_value=0)
    return taken, rows[~mask], counts + per_species


def split(df, ratio):
    """Split annotations into two subsets, file by file, balancing each class.

    All rows of a given file always end up in the same subset. The function
    first tries to give each subset at least one file containing every class,
    then distributes the remaining files, always serving the class with the
    fewest remaining rows and sending the file to whichever subset keeps that
    class's train share closest to ``ratio``.

    Files are drawn with ``DataFrame.sample`` without a seed, so the result
    is not reproducible from one call to the next.

    Args:
        df: DataFrame with at least the columns ``file`` and ``espece``
            (one row per annotation).
        ratio: Target fraction of each class's rows placed in the first
            returned subset.

    Returns:
        A ``(train_df, test_df)`` tuple of DataFrames holding the rows of
        ``df`` (original index preserved). A subset that receives no file is
        an empty DataFrame with the columns of ``df``.
    """
    classes = df.espece.unique()
    train_count = pd.Series(0.0, index=classes)
    test_count = pd.Series(0.0, index=classes)
    # Rows are collected in lists and concatenated once at the end.
    train_parts = []
    test_parts = []
    remaining = df

    # Phase 1: make sure every class is represented in both subsets whenever
    # enough files are available. A class already covered by a previously
    # assigned file is skipped.
    for c in classes:
        if train_count.loc[c] == 0:
            f = _pick_file(remaining, c)
            if f is not None:
                taken, remaining, train_count = _take_file(remaining, f, train_count)
                train_parts.append(taken)
        if test_count.loc[c] == 0:
            # Pick among the rows that are still unassigned, so the file just
            # moved to the train subset cannot be selected again.
            f = _pick_file(remaining, c)
            if f is not None:
                taken, remaining, test_count = _take_file(remaining, f, test_count)
                test_parts.append(taken)

    # Phase 2: distribute what is left, rarest remaining class first.
    while len(remaining):
        min_esp = remaining.groupby('espece').file.count().idxmin()
        f = _pick_file(remaining, min_esp)
        seen = train_count.loc[min_esp] + test_count.loc[min_esp]
        # With nothing assigned yet for this class the share is undefined;
        # the file then goes to the train subset.
        if seen > 0 and train_count.loc[min_esp] / seen > ratio:
            taken, remaining, test_count = _take_file(remaining, f, test_count)
            test_parts.append(taken)
        else:
            taken, remaining, train_count = _take_file(remaining, f, train_count)
            train_parts.append(taken)

    print('\nratio', (train_count / (test_count + train_count)).to_frame())
    train_df = pd.concat(train_parts) if train_parts else df.iloc[0:0]
    test_df = pd.concat(test_parts) if test_parts else df.iloc[0:0]
    return train_df, test_df

def process_data(args):
    """Load every label file of ``args.path_to_data`` into a single DataFrame.

    Each label file is read as a space-separated table whose columns are
    named ``espece``, ``x``, ``y``, ``w`` and ``h``. Rows whose ``espece``
    value is the string ``'y'`` are discarded and the remaining ``espece``
    values are converted to float.

    Args:
        args: Parsed command-line arguments; only ``args.path_to_data`` (the
            folder holding the ``.txt`` label files) is used.

    Returns:
        A DataFrame with one row per annotation and the columns ``file``
        (label file name, extension included), ``espece``, ``x``, ``y``,
        ``w`` and ``h``, indexed from 0.

    Raises:
        ValueError: If the folder contains no non-empty ``.txt`` label file.
    """
    path = args.path_to_data

    # Only regular .txt files are label files; sub-directories or other files
    # would make read_csv fail. Sorting makes the row order independent of
    # the order in which the file system lists the folder.
    label_files = sorted(
        name for name in os.listdir(path)
        if name.endswith('.txt') and os.path.isfile(os.path.join(path, name))
    )

    frames = {}
    skipped = 0
    for name in tqdm(label_files):
        try:
            frames[name] = pd.read_csv(os.path.join(path, name), sep=' ',
                                       names=['espece', 'x', 'y', 'w', 'h'])
        except pd.errors.EmptyDataError:
            # A label file without any annotation has no row to contribute.
            skipped += 1
    if skipped:
        print(f'Skipped {skipped} empty label file(s)')
    if not frames:
        raise ValueError(f'No annotation found in `{path}`')

    # The dict keys become the outer index level, named 'file'; turn it into
    # a regular column and renumber the rows from 0.
    df = pd.concat(frames, names=['file'])
    df = df.reset_index(level=0).reset_index(drop=True)

    # Work on a copy so the dtype conversion does not write into a view.
    df = df[df.espece != 'y'].copy()
    df['espece'] = df.espece.astype(float)

    return df


def _file_stems(df):
    """Return the distinct file names of `df` without extension, in order of first appearance."""
    return list(dict.fromkeys(os.path.splitext(name)[0] for name in df.file))


def export_split(entry, path, directory):
    """Copy the label/image files of each subset and write ``custom_data.yaml``.

    For every subset, the ``.txt`` files are copied from ``path`` to
    ``<directory>/labels/<subset>`` and the matching ``.jpg`` files from
    ``<path>/../images/all`` to ``<directory>/images/<subset>``. The input
    DataFrames are not modified.

    Args:
        entry: Sequence of DataFrames ``[val, train]`` or
            ``[val, train, test]``, each with a ``file`` column holding label
            file names (extension included). The test subset is exported
            only when a third element is present.
        path: Folder containing the label files.
        directory: Existing output folder.

    The YAML file lists the image folders of each subset. ``nc`` and
    ``names`` are written only when ``<path>/../liste_especes.csv`` can be
    read; otherwise a message asks the user to add them by hand.
    """
    subsets = {'val': entry[0], 'train': entry[1]}
    # The presence of a third DataFrame is what signals a test subset, so the
    # function does not depend on any module-level state.
    has_test = len(entry) > 2
    if has_test:
        subsets['test'] = entry[2]

    images_source = os.path.join(path, '../images/all')
    create_directory_if_not_exists(os.path.join(directory, 'images'))
    create_directory_if_not_exists(os.path.join(directory, 'labels'))

    for name, subset in subsets.items():
        # A file has one row per annotation; copy it only once.
        stems = _file_stems(subset)
        images_dir = os.path.join(directory, f'images/{name}')
        labels_dir = os.path.join(directory, f'labels/{name}')
        create_directory_if_not_exists(images_dir)
        create_directory_if_not_exists(labels_dir)
        copy_files_to_directory(stems, path, labels_dir, 'txt')
        copy_files_to_directory(stems, images_source, images_dir, 'jpg')

    try:
        liste_espece = pd.read_csv(os.path.join(path, '../liste_especes.csv'))
    except (OSError, ValueError):
        # Missing or unreadable species list: the YAML is still written, but
        # without the class information.
        liste_espece = None
        print('No species list detected, please add it to', os.path.join(directory, 'custom_data.yaml'))

    with open(os.path.join(directory, 'custom_data.yaml'), 'w', encoding='utf-8') as f:
        if has_test:
            f.write(f'test: {os.path.join(directory, "images/test")}\n')
        f.write(f'train: {os.path.join(directory, "images/train")}\n')
        f.write(f'val: {os.path.join(directory, "images/val")}\n')
        if liste_espece is not None:
            f.write(f'nc: {len(liste_espece)}\n')
            f.write(f'names: {liste_espece.espece.tolist()}')

if __name__ == '__main__':
    # Command-line interface: train ratio, label folder, output folder and an
    # optional flag requesting a third (test) subset.
    parser = argparse.ArgumentParser(description='TODO', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-r', '--ratio', default=0.7, type=float, help='Train Ratio (val = 1 - ratio)')
    parser.add_argument('-p', '--path_to_data', required=True, type=arg_directory, help='Path of the folder that contains the .txt (ending with labels/)')
    parser.add_argument('-d', '--directory', required=True, type=arg_directory, help='Directory to which spectrogram and .txt files will be stored (different from -p)')
    parser.add_argument('--test', default=0, type=int, help='1 if True. Split into train/test/val. 1 - Ratio / 2 for test and same for validation')
    # `args` is deliberately kept as a module-level name so that any function
    # of this file looking it up globally keeps working.
    args = parser.parse_args()

    annotations = process_data(args)
    train, val = split(annotations, args.ratio)
    subsets = [val, train]
    if args.test == 1:
        # Halve the held-out part: one half stays validation, the other
        # becomes the test subset.
        val, test = split(val, 0.5)
        subsets = [val, train, test]
    export_split(subsets, args.path_to_data, args.directory)
