import os
import pandas as pd
import librosa
import numpy as np
import matplotlib.pyplot as plt
import random
from datetime import date
import argparse
from p_tqdm import p_map
import cv2
from mycolorpy import colorlist as mcp

def arg_directory(path):
    """Validate *path* as an existing directory, for use as an argparse ``type=`` callback.

    Returns *path* unchanged when ``os.path.isdir`` accepts it; otherwise raises
    ``argparse.ArgumentTypeError`` so that argparse reports a clean usage error.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f'`{path}` is not a valid path')
    return path

def _window_offset(first_start, first_mid, unique):
    """Return the offset (in seconds) at which the audio excerpt for one window is loaded.

    ``'unique'`` mode, and windows whose first annotation starts at or before 3 s,
    are loaded from the beginning of the file. Otherwise the excerpt starts 3.5 s
    before the midpoint of the first annotation, shifted by a random jitter drawn
    from [-1.5, 1.5] s.
    NOTE(review): 3 and 3.5 are hard-coded and look tuned for the default 8 s
    duration - confirm before changing `--duration` substantially.
    """
    if unique == 'unique' or first_start <= 3:
        return 0
    jitter = round(random.uniform(-1.5, 1.5), 2)
    # The jitter can push the start below zero; an excerpt cannot start before the file does.
    return max(0.0, first_mid - 3.5 + jitter)


def _bounding_box(row, offset, duration, nyquist, mode):
    """Return ``(x, y, width, height)`` of one annotation as fractions of the image size.

    x/width are expressed relative to `duration`, y/height relative to `nyquist`.
    y is measured from the top of the image because the spectrogram is drawn
    flipped (high frequencies at the top).
    """
    x_center = (row.midl - offset) / duration
    box_width = (row.stop - row.start) / duration

    if mode == 'uniform':
        # Fixed vertical extent, centred on the image.
        return x_center, 0.5, box_width, 0.8

    y_center = 1 - (row.midl_y / nyquist)
    box_height = (row.max_freq - row.min_freq) / nyquist
    if box_height > 1:
        box_height = 1
    elif box_height > y_center * 2:
        # The box would stick out above the top edge: move its centre down by the overshoot.
        y_center = y_center + 0.5 * (box_height - y_center * 2)
    return x_center, y_center, box_width, box_height


def _save_annotated(image_path, boxes, colors, out_path):
    """Re-read a saved spectrogram, draw every labelled box on it and save the result to `out_path`."""
    im = cv2.imread(image_path)
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    img_h, img_w = im.shape[0], im.shape[1]

    for box in boxes.itertuples(index=False):
        cx, cy = box.x * img_w, box.y * img_h
        w, h = box.width * img_w, box.height * img_h
        bottom_left = (int(cx - 0.5 * w), int(cy + 0.5 * h))
        top_right = (int(cx + 0.5 * w), int(cy - 0.5 * h))
        # Small filled tag hanging from the top-right corner, holding the class id.
        tag_corner = (top_right[0] - 10, top_right[1] + 20)
        text_origin = (tag_corner[0], tag_corner[1] - 5)
        # Colour follows the class of *this* box, consistent with the label drawn next to it.
        color = colors[int(box.id)]
        cv2.rectangle(im, pt1=bottom_left, pt2=top_right, color=color, thickness=1)
        cv2.rectangle(im, pt1=tag_corner, pt2=top_right, color=color, thickness=-1)
        cv2.putText(im, str(box.id), text_origin, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    plt.imshow(im)
    plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
    plt.savefig(out_path)
    plt.close()


def process_annotations(file_path, duration, sr, overlap, mode, unique, columns_name, export, window, hop, cpu,
                        directory=None):
    """Turn a CSV of annotations into spectrogram images plus one label file per image.

    The CSV columns ``label``, ``annotation_initial_time``, ``annotation_final_time``,
    ``duree``, ``min_frequency``, ``max_frequency`` and ``avg_frequency`` are renamed to
    ``Code``, ``start``, ``stop``, ``d_annot``, ``min_freq``, ``max_freq`` and ``midl_y``.
    Annotations lasting `duration` or longer are discarded. Class ids are the rank of
    each ``Code`` when codes are sorted by decreasing number of rows.

    Output, all below `directory` and suffixed with today's ``<day>_<month>``:
      * ``images_*/<Code>/<name>.jpg`` and ``images_*/all/<name>.jpg`` - the spectrogram,
        filed under the code of the last annotation of the window;
      * ``labels_*/<name>.txt`` - one ``id x y width height`` line per annotation;
      * ``images_annotes_*/<name>.jpg`` - spectrogram with boxes drawn, only if `export`.
    ``<name>`` is the recording path (``/`` and ``.`` replaced by ``_``) followed by the
    file index and the window index, so successive windows of one recording do not
    overwrite each other.

    Args:
        file_path: path of the annotation CSV.
        duration: length in seconds of each audio excerpt.
        sr: sampling rate handed to ``librosa.load`` (``None`` keeps the file's own rate).
        overlap: seconds subtracted from `duration` when grouping annotations into a window.
        mode: ``'uniform'`` gives every box a fixed vertical extent; any other value
            derives it from the frequency columns.
        unique: ``'unique'`` writes a single window per recording, starting at 0 s and
            holding all of its annotations; ``'multiple'`` walks through the annotations
            window by window.
        columns_name: CSV column holding the path of each recording.
        export: when truthy, also write the images with boxes drawn.
        window: FFT size; a Hann window of that length is used.
        hop: hop length of the STFT.
        cpu: number of worker processes given to ``p_map``.
        directory: output root. When omitted, the module-level name ``directory`` is
            used, for callers that set it as a global before calling.

    Raises:
        ValueError: if `unique` is not one of the two supported values or if no output
            directory can be determined.
    """
    if unique not in ('unique', 'multiple'):
        raise ValueError(f"`unique` must be 'unique' or 'multiple', got {unique!r}")
    if directory is None:
        directory = globals().get('directory')
    if directory is None:
        raise ValueError('an output `directory` is required')

    today = date.today()
    stamp = f'{today.day}_{today.month}'
    images_dir = os.path.join(directory, f'images_{stamp}')
    labels_dir = os.path.join(directory, f'labels_{stamp}')
    annotated_dir = os.path.join(directory, f'images_annotes_{stamp}')

    df = pd.read_csv(file_path, low_memory=False)
    df = df.rename(columns={'label': 'Code', 'annotation_initial_time': 'start', 'annotation_final_time': 'stop',
                            'duree': 'd_annot', 'min_frequency': 'min_freq', 'max_frequency': 'max_freq',
                            'avg_frequency': 'midl_y'})

    # Classes ordered from most to least frequent; the position in this table is the class id.
    data = df.groupby('Code').count()
    data = data.sort_values(data.columns[0], ascending=False)
    data.reset_index(inplace=True)
    class_ids = {code: idx for idx, code in enumerate(data.Code)}

    # .copy() so the derived columns are written to an independent frame, not a slice view.
    df = df[df.Code.isin(data.Code)].copy()
    df['d_annot'] = df.stop - df.start
    df['midl'] = (df.stop + df.start) / 2
    df['Path'] = df[columns_name]
    df = df[df.d_annot < duration].reset_index(drop=True)

    # One random colour per class, so any number of classes can be drawn.
    colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) for _ in class_ids]

    # Create the whole output tree once, before the workers start, so that parallel
    # workers never race on directory creation.
    os.makedirs(labels_dir, exist_ok=True)
    os.makedirs(os.path.join(images_dir, 'all'), exist_ok=True)
    for code in class_ids:
        os.makedirs(os.path.join(images_dir, str(code)), exist_ok=True)
    if export:
        os.makedirs(annotated_dir, exist_ok=True)

    def process(x):
        count, (f, grp) = x
        filename = str(f)
        stem = filename.replace('/', '_').replace('.', '_')
        window_idx = 0

        while len(grp) != 0:
            if unique == 'unique':
                # Single spectrogram per recording: it carries every annotation of the file.
                tab = grp
            else:
                tab = grp[grp.midl <= grp.start.iloc[0] + (duration - overlap)]
                if len(tab) == 0:
                    # Nothing fits the window: take everything left so the loop always advances.
                    tab = grp

            offset = _window_offset(tab.start.iloc[0], tab.midl.iloc[0], unique)

            # Distinct local names: rebinding `sr`/`window` here would shadow the outer arguments.
            y, file_sr = librosa.load(filename, offset=offset, duration=duration, sr=sr)
            stft = librosa.stft(y, n_fft=window, hop_length=hop, window=np.hanning(window))
            # Flipped so that row 0 of the image is the highest frequency.
            log_spec = np.flipud(np.log10(np.abs(stft)))

            plt.imshow(log_spec, aspect='auto', interpolation=None, cmap='jet',
                       vmin=log_spec.mean(), vmax=log_spec.max())
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1)

            nyquist = file_sr / 2
            boxes = [[str(class_ids[row.Code]), *_bounding_box(row, offset, duration, nyquist, mode)]
                     for _, row in tab.iterrows()]
            fin = pd.DataFrame(boxes, columns=['id', 'x', 'y', 'width', 'height'])

            # The image is filed under the code of the last annotation of the window.
            last_code = str(tab.Code.iloc[-1])
            grp = grp.drop(tab.index)

            name = f'{stem}_{count}_{window_idx}'
            window_idx += 1

            all_path = os.path.join(images_dir, 'all', name + '.jpg')
            fin.to_csv(os.path.join(labels_dir, name + '.txt'), sep=' ', header=False, index=False)
            plt.savefig(os.path.join(images_dir, last_code, name + '.jpg'))
            plt.savefig(all_path)
            plt.close()

            if export:
                _save_annotated(all_path, fin, colors, os.path.join(annotated_dir, name + '.jpg'))

    groups = df.groupby('Path')
    p_map(process, enumerate(groups), num_cpus=cpu, total=len(groups))
    print('saved to', directory)

if __name__ == '__main__':
    # Command-line entry point: parse the options and run the annotation export.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description='Create .txt and .jpg to each annotation frome a csv')
    # Positional arguments are mandatory by construction; argparse raises a TypeError
    # if `required=` is passed for them, so it is only kept on the `-c` option.
    parser.add_argument('filename_path', type=str, help='Path and name of the file containing the annotations')
    parser.add_argument('-c', '--columns_name', type=str, help='Name of the column that contains the path',
                        required=True)
    # NOTE(review): `path_to_data` is validated here but process_annotations has no
    # parameter that receives it - confirm whether recordings should be resolved against it.
    parser.add_argument('path_to_data', type=arg_directory, help='Path of the folder that contains the recordings')
    parser.add_argument('directory', type=arg_directory,
                        help='Directory to which spectrograms and .txt files will be stored')
    parser.add_argument('-m', '--mode', type=str, choices=['uniform', 'personalized'],
                        help='Choose the mode to calculate the y and height value', default='personalized')
    parser.add_argument('-u', '--unique', type=str, choices=['unique', 'multiple'],
                        help='unique for only one spectrogram per file, multiple for multiple spectrograms',
                        default='multiple')
    parser.add_argument('--export', type=str, default=None,
                        help='To export the position of the bounding box on the spectrogram', required=False)
    parser.add_argument('--duration', type=int, help='Duration for each spectrogram', default=8)
    parser.add_argument('--overlap', type=int, help='Overlap between 2 spectrograms', default=2)
    parser.add_argument('--sr', type=int, help='Sampling rate for the spectrogram. If no argument, '
                                            'SR will be original SR of the recording', default=None)
    parser.add_argument('--window', type=int, help='Window size for the Fourier Transform', default=1024)
    parser.add_argument('--hop', type=int, help='Hop lenght for the Fourier Transform', default=512)
    parser.add_argument('--cpu', type=int, help='To speed up the process, write 2 or more', default=1)

    args = parser.parse_args()

    # process_annotations looks up the output location through the module-level
    # name `directory`, so it has to exist before the call.
    directory = args.directory

    process_annotations(args.filename_path, duration=args.duration, sr=args.sr, overlap=args.overlap,
                        mode=args.mode, unique=args.unique, columns_name=args.columns_name, export=args.export,
                        window=args.window, hop=args.hop, cpu=args.cpu)
