import os

import numpy as np
import soundfile as sf
import scipy.signal as sg
import scipy as scp
import scipy.fft
import glob

# Working sample rate in Hz: audio read at any other rate is resampled to it,
# and the high-pass filter built in Click.__init__ is designed for it.
_FS = 256_000

def norm(x, eps=1e-10, axis=None):
    """Standardise an array: subtract its mean and divide by its std.

    Args:
        x: NumPy array to standardise.
        eps: Small constant added to the standard deviation so that a
            constant input yields zeros instead of a division by zero.
        axis: Axis (or axes) along which mean and std are computed;
            ``None`` uses the whole array.

    Returns:
        Array of the same shape as ``x``.
    """
    centred = x - x.mean(axis, keepdims=True)
    spread = x.std(axis, keepdims=True)
    return centred / (spread + eps)

class Click:
    """Triplet dataset of short audio windows cut around annotated clicks.

    Indexing returns ``(anchor, positive, negative, labels)``:

    * the anchor is the row at the requested position;
    * the positive is a second, independently drawn view of the *same* row
      (new random position inside the annotation and, in train mode, new
      augmentation);
    * the negative is a random row with the other label, taken from the same
      file when one exists, otherwise from the whole table.

    Each signal is a ``float32`` array of shape ``(1, n)``; ``n`` is 512
    whenever enough audio could be read around the click.

    Args:
        train: If true, apply random augmentation and a random 512-sample
            crop; otherwise take the centred 512-sample crop.
        rank: Stored on the instance; not used by this class.
        df: Table with columns ``File``, ``positif_negatif`` (``'n'`` marks
            the negative class), ``pos_start`` and ``pos_end`` (multiplied
            by the file sample rate, i.e. expressed in seconds). The input
            is copied, never modified.
        data_path: Directory prepended to ``File`` values that do not start
            with ``/nfs``.
        hp: If true, apply a 3rd-order 2500 Hz Butterworth high-pass.
        balanced: If true, randomly drop rows of the majority class so both
            classes have the same number of rows.
        sess_aug: If true, the train-mode noise comes from recordings
            matched by ``noise_glob`` instead of synthetic noise.
        noise_glob: Glob pattern of the recordings used when ``sess_aug``
            is set.

    Raises:
        FileNotFoundError: ``sess_aug`` is set and ``noise_glob`` matches
            no file.
    """

    # Number of samples (at _FS) in every returned window.
    _WIN = 512
    # Experimental: return the FFT phase of the window instead of the waveform.
    _PHASE_FEATURES = False

    def __init__(self, train, rank, df, data_path, hp=False, balanced=False, sess_aug=False,
                 noise_glob="/short/CARIMAM/DATA/DATA_CLEAN/*/*.WAV"):
        self.train = train
        self.rank = rank
        self.df = df.copy()
        self.df['label'] = (self.df.positif_negatif != 'n') * 1
        if balanced:
            self.df = self._balance(self.df)
        # Group only after balancing so that the groups match self.df.
        self.gb = self.df.groupby('File')

        self.data_path = data_path
        self.hp = hp
        self.hp_sos = sg.butter(3, 2500, 'hp', output='sos', fs=_FS)

        # None selects the synthetic-noise augmentation in get_one_item.
        self.list_noise_files = None
        if sess_aug:
            # Use signal from other recording sessions as data augmentation.
            self.list_noise_files = glob.glob(noise_glob)
            self.nb_noise_files = len(self.list_noise_files)
            if not self.list_noise_files:
                raise FileNotFoundError(
                    f"sess_aug is enabled but no noise file matches {noise_glob!r}"
                )

        if self._PHASE_FEATURES:
            self.window = sg.windows.hann(self._WIN)
            self.nfft = 1024

    @staticmethod
    def _balance(df):
        """Return ``df`` with random majority-class rows removed (equal class counts)."""
        is_pos = (df.label == 1).to_numpy()
        surplus = int(is_pos.sum()) - int((~is_pos).sum())
        if surplus == 0:
            return df
        majority = np.flatnonzero(is_pos if surplus > 0 else ~is_pos)
        dropped = np.random.choice(majority, size=abs(surplus), replace=False)
        # Select by position so that duplicated index labels cannot remove extra rows.
        keep = np.ones(len(df), dtype=bool)
        keep[dropped] = False
        return df[keep]

    def __len__(self):
        """Return the number of annotated rows (after optional balancing)."""
        return len(self.df)

    def __getitem__(self, idx_sig):
        """Return ``(sig_anc, sig_pos, sig_neg, [lab_anc, lab_pos, lab_neg])``.

        ``idx_sig`` is a positional row index into the table.
        """
        elem_anc = self.df.iloc[idx_sig]

        # One noise recording is shared by the three signals of the triplet.
        idx_noise = -1
        if self.list_noise_files is not None:
            idx_noise = np.random.randint(self.nb_noise_files)

        # Prefer a negative from the same file; fall back to the whole table.
        same_file = self.gb.get_group(elem_anc['File'])
        elem_neg = same_file[same_file.label != elem_anc.label]
        if len(elem_neg) == 0:
            elem_neg = self.df[self.df.label != elem_anc.label]
        elem_neg = elem_neg.sample(1).iloc[0]

        sig_anc, lab_anc = self.get_one_item(elem_anc, idx_noise)
        # The positive is deliberately a second random view of the anchor row,
        # not another row with the same label.
        sig_pos, lab_pos = self.get_one_item(elem_anc, idx_noise)
        sig_neg, lab_neg = self.get_one_item(elem_neg, idx_noise)

        label = [lab_anc, lab_pos, lab_neg]
        return sig_anc, sig_pos, sig_neg, label

    def get_one_item(self, df_elem_sig, idx_noise):
        """Load, pre-process and (in train mode) augment the window of one row.

        Args:
            df_elem_sig: Row of the table (``File``, ``pos_start``,
                ``pos_end``, ``label``).
            idx_noise: Index into ``self.list_noise_files``; ignored when no
                noise files are configured.

        Returns:
            ``(signal, label)`` where ``signal`` is a ``float32`` array of
            shape ``(1, n)``.

        Raises:
            RuntimeError: The audio file could not be read.
        """
        row = df_elem_sig
        click_start = float(row['pos_start'])
        click_end = float(row['pos_end'])
        if click_start == click_end:
            pos_click = click_start
        else:
            pos_click = np.random.uniform(click_start, click_end)

        file = row['File']
        if not file.startswith('/nfs'):
            file = os.path.join(self.data_path, file)
        file = file.replace(" ", "")

        try:
            file_sr = sf.info(file).samplerate
            centre = int(file_sr * pos_click)
            # Half-width covering _WIN samples at _FS, expressed in file frames.
            half = int((self._WIN / _FS) * file_sr)
            # soundfile counts a negative start from the end of the file, so
            # clamp it for clicks that sit near the beginning.
            sample, sr = sf.read(file, start=max(0, centre - half), stop=centre + half)
        except RuntimeError as e:
            raise RuntimeError(
                f"could not read the window around {pos_click:.6f}s in {file!r}"
            ) from e

        label = row.label
        if sample.ndim > 1:
            sample = sample[:, 0]  # keep the first channel only
        if sr != _FS:
            sample = sg.resample(sample, int(len(sample) * _FS / sr))
            sr = _FS
        if self.hp:
            sample = sg.sosfiltfilt(self.hp_sos, sample)

        sample = norm(sample)

        if self.train:
            if self.list_noise_files is None:
                sample = self._add_synthetic_noise(sample)
            else:
                sample = self._add_recorded_noise(sample, self.list_noise_files[idx_noise])

            if np.random.rand() < 0.5:
                sample = self._random_filter(sample, sr)
            if np.random.rand() < 0.05:
                # Add a circularly shifted copy of the signal (shift of 192..319 samples).
                sample += np.roll(sample, np.random.randint(self._WIN // 2 - 64, self._WIN // 2 + 64))

            if len(sample) > self._WIN:
                # Random offset: in real recordings the clicks are not centred.
                p = np.random.randint(len(sample) - self._WIN)
                sample = sample[p:p + self._WIN]
        else:
            if len(sample) > self._WIN:
                mid = len(sample) // 2
                sample = sample[mid - self._WIN // 2:mid + self._WIN // 2]

        sample = norm(sample)

        if self._PHASE_FEATURES:
            sample = np.angle(scp.fft.fft(sample * self.window, n=self.nfft))[:self.nfft // 2]
        return sample.astype(np.float32)[None], label

    def _add_synthetic_noise(self, sample):
        """Randomly stretch ``sample`` and add white and random-walk noise."""
        # Speed the signal up or slow it down by up to 20 %, never below _WIN samples.
        stretched_len = max(self._WIN, int(len(sample) * np.random.uniform(0.8, 1.2)))
        sample = sg.resample(sample, stretched_len)
        # White noise with a random gain between 10**-2 and 10**-0.5.
        sample += np.random.normal(0, 1, len(sample)) * 10 ** np.random.uniform(-2, -0.5)
        # Standardised random walk (cumulative sum of white noise), same gain range.
        sample += norm(np.random.normal(0, 1, len(sample)).cumsum()) * 10 ** np.random.uniform(-2, -0.5)
        return sample

    def _add_recorded_noise(self, sample, noise_file):
        """Add a random excerpt of ``noise_file`` to ``sample`` with a random gain."""
        info = sf.info(noise_file)
        # Frames needed at the noise file's own rate to span the sample at _FS.
        n_frames = int(np.ceil(len(sample) * info.samplerate / _FS))
        start = np.random.randint(info.frames - n_frames)
        noise, noise_sr = sf.read(noise_file, start=start, stop=start + n_frames)

        if noise.ndim > 1:
            noise = noise[:, 0]
        if noise_sr != _FS:
            # Resample straight to the sample length so the two can be summed.
            noise = sg.resample(noise, len(sample))
        if self.hp:
            noise = sg.sosfiltfilt(self.hp_sos, noise)
        noise = norm(noise)

        # TODO: draw the gain on a dB scale; the min/max values remain to be set.
        noise_gain = np.random.uniform(low=0., high=1.)
        return sample + noise * noise_gain

    @staticmethod
    def _random_filter(sample, sr):
        """Apply a band-stop (or low-pass near Nyquist) filter at a random frequency."""
        f = np.random.uniform(1500, sr / 2)
        if f * 1.1 >= sr / 2:
            # The upper band edge would reach Nyquist: use a low-pass instead.
            sos = sg.butter(3, 2 * f * 0.9 / sr, 'lowpass', output='sos')
        else:
            sos = sg.butter(3, [2 * f * 0.9 / sr, 2 * f * 1.1 / sr], 'bandstop', output='sos')
        return sg.sosfiltfilt(sos, sample)
