import os
import numpy as np
import soundfile as sf
import scipy.signal as sg
import librosa as lr
import glob

# Target sampling rate (Hz): every signal read by AudioSignal is resampled to it.
SR = 96000


def norm(x, eps=1e-10, axis=None):
    """Standardise ``x`` to zero mean and unit standard deviation.

    ``axis`` selects the axes over which mean and standard deviation are
    computed (``None``: the whole array). Both statistics are kept with
    ``keepdims=True`` so they broadcast back over ``x``. ``eps`` is added to
    the standard deviation so constant input maps to zeros instead of
    dividing by zero.
    """
    centred = x - x.mean(axis, keepdims=True)
    scale = x.std(axis, keepdims=True) + eps
    return centred / scale

class AudioSignal:
    """Dataset of click-centred magnitude spectrograms read from annotated audio files.

    Each item is the magnitude STFT of a ``sig_dur``-second window centred on
    the middle of an annotated click (``(pos_start + pos_end) / 2`` seconds),
    resampled to ``SR`` and normalised. In training mode the window is randomly
    shifted (the click always stays inside it) and randomly augmented.

    Parameters
    ----------
    train : bool
        Enable random cropping and data augmentation.
    df : pandas.DataFrame
        Annotation table. Columns read here: ``File``, ``pos_start``,
        ``pos_end`` and ``positif_negatif`` (label 0 when the value is ``'n'``,
        1 otherwise).
    data_path : str
        Directory joined in front of ``File`` entries not starting with ``/nfs``.
    hp : bool
        Apply a 6th-order 2.5 kHz Butterworth high-pass filter (forward-backward).
    sig_dur : float
        Duration in seconds of the returned window.
    noise_data_aug : bool
        In training mode, add an excerpt of another recording (files matched by
        ``noise_glob``) instead of synthetic noise.
    noise_glob : str
        Glob pattern listing the recordings used as additive noise.
    """

    def __init__(self, train, df, data_path, hp=False, sig_dur=5., noise_data_aug=False,
                 noise_glob="/short/CARIMAM/DATA/DATA_CLEAN/*/*.WAV"):
        super().__init__()

        self.train = train
        self.df = df
        self.data_path = data_path
        # Binary labels: any annotation other than 'n' counts as positive.
        self.label = np.array(self.df.positif_negatif != 'n').astype(int)
        self.hp = hp
        # Designed at SR: signals are always resampled to SR before filtering.
        self.hp_sos = sg.butter(6, 2500, 'hp', output='sos', fs=SR)
        self.sig_dur = sig_dur

        # Optional pool of recordings from other sessions used as realistic noise.
        self.list_noise_files = None
        if noise_data_aug:
            self.list_noise_files = glob.glob(noise_glob)
            self.nb_noise_files = len(self.list_noise_files)

        self.window = sg.windows.hann(4096)
        self.nfft = 4096

    def __len__(self):
        """Number of annotations (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, item):
        """Return ``(spectrogram, label)`` for the annotation at position ``item``.

        ``spectrogram`` is the float32 magnitude STFT of the (possibly augmented)
        window with a leading channel axis, shape ``(1, 1 + nfft // 2, n_frames)``.
        ``label`` is ``self.label[item]``.

        Raises
        ------
        RuntimeError
            When reading the audio file fails (path and index are printed first).
        ValueError
            When the requested window is empty.
        """
        row = self.df.iloc[item]
        click_start, click_end = row[['pos_start', 'pos_end']]
        # Centre of the annotated click, in seconds from the start of the file.
        pos_click = (float(click_start) + float(click_end)) / 2

        file = row['File']
        file = file if file.startswith('/nfs') else os.path.join(self.data_path, file)
        # Spaces are stripped from the whole path.
        file = file.replace(" ", "")

        # Training reads 2 * sig_dur seconds so that the random sig_dur crop in
        # _augment always has full length and still contains the click;
        # evaluation reads exactly sig_dur seconds centred on the click.
        half_window = self.sig_dur if self.train else self.sig_dur / 2
        try:
            sig, sr = self._read_centered(file, pos_click, half_window)
        except RuntimeError:
            print(file)
            print(item)
            raise

        label = self.label[item]

        sig = norm(self._to_model_rate(sig, sr))

        if self.train:
            sig = self._augment(sig)

        sig = norm(sig)

        # hop 938 ~= 5 s * 96 kHz / 512: the default 5 s window yields 512 frames.
        tf_sig = np.abs(lr.stft(sig, n_fft=self.nfft, hop_length=938, window=self.window, center=True))

        return tf_sig.astype(np.float32)[None], label

    @staticmethod
    def _read_centered(file, pos_click, half_window):
        """Read first-channel audio spanning +/- half_window s around pos_click s, zero-padded at file edges."""
        samplerate = sf.info(file).samplerate
        center = int(pos_click * samplerate)
        margin = int(half_window * samplerate)
        start = center - margin

        sig, sr = sf.read(file, start=max(0, start), stop=center + margin)
        if sig.ndim > 1:
            # Keep channel 0 BEFORE padding so np.pad only touches the time axis.
            sig = sig[:, 0]
        if start < 0:
            # Window begins before the file: left-pad so the click stays centred.
            sig = np.pad(sig, (-start, 0))
        if len(sig) < 2 * margin:
            # Window runs past the end of the file: right-pad with zeros.
            sig = np.pad(sig, (0, 2 * margin - len(sig)))
        if len(sig) == 0:
            raise ValueError(f"Empty signal read from {file!r} (window of {2 * margin} samples)")
        return sig, sr

    def _to_model_rate(self, sig, sr):
        """Resample a 1-D signal from ``sr`` to SR and apply the optional high-pass."""
        if sr != SR:
            sig = sg.resample(sig, int(len(sig) * SR / sr))
        if self.hp:
            sig = sg.sosfiltfilt(self.hp_sos, sig)
        return sig

    @staticmethod
    def _fit_length(x, n):
        """Truncate or right-zero-pad 1-D ``x`` to exactly ``n`` samples."""
        return np.pad(x[:n], (0, max(0, n - len(x))))

    def _random_noise_excerpt(self):
        """Return a normalised sig_dur-second excerpt of a random noise recording at SR."""
        noise_file = self.list_noise_files[np.random.randint(self.nb_noise_files)]
        # Re-draw until the recording is long enough for a full window.
        while sf.info(noise_file).duration < self.sig_dur:
            noise_file = self.list_noise_files[np.random.randint(self.nb_noise_files)]

        info = sf.info(noise_file)
        margin = int(self.sig_dur / 2 * info.samplerate)
        # randint's upper bound is exclusive: +1 accepts a file exactly 2 * margin long.
        rand_pos = np.random.randint(margin, info.frames - margin + 1)
        noise, noise_sr = sf.read(noise_file, start=rand_pos - margin, stop=rand_pos + margin)
        if noise.ndim > 1:
            noise = noise[:, 0]
        return norm(self._to_model_rate(noise, noise_sr))

    def _augment(self, sig):
        """Randomly crop sig_dur seconds of ``sig`` (at SR) and apply random noise, filtering and echo."""
        crop_len = int(self.sig_dur * SR)
        # sig spans 2 * sig_dur seconds here, so every crop keeps the centred click.
        r_pos = np.random.randint(len(sig) // 2 - 1)
        sig = norm(sig[r_pos:r_pos + crop_len])

        if self.list_noise_files is None:
            # Synthetic noise: white noise plus a normalised random walk,
            # each scaled by a random gain in [10**-2, 10**-0.5].
            sig += np.random.normal(0, 1, len(sig)) * 10 ** np.random.uniform(-2, -0.5)
            sig += norm(np.random.normal(0, 1, len(sig)).cumsum()) * 10 ** np.random.uniform(-2, -0.5)
        else:
            # Resampling can leave the excerpt a few samples off; align lengths.
            noise = self._fit_length(self._random_noise_excerpt(), len(sig))
            noise_gain = np.random.uniform(low=2., high=5.)  # TODO: draw the gain on a dB scale
            sig = sig + noise * noise_gain

        if np.random.rand() < 0.5:
            # Low-pass (near Nyquist) or narrow band-stop filter at a random frequency.
            f = np.random.uniform(1500, SR / 2)
            if f * 1.1 > SR / 2:
                sos = sg.butter(3, 2 * f * 0.9 / SR, 'lowpass', output='sos')
            else:
                sos = sg.butter(3, [2 * f * 0.9 / SR, 2 * f * 1.1 / SR], 'bandstop', output='sos')
            sig = sg.sosfiltfilt(sos, sig)
        if np.random.rand() < 0.05:
            # Add a circularly shifted copy (192-319 samples) to mimic a surface echo.
            sig += np.roll(sig, np.random.randint(256 - 64, 256 + 64))
        return sig


def main():
    """Smoke-test AudioSignal on the CARIMAM click annotations.

    Loads the annotation sheet, drops rows whose path points to NAS3/NAS4,
    plots the spectrogram of item 20, then walks every 10th item printing
    the spectrogram shape.
    """
    import matplotlib.pyplot as plt
    import pandas as pd
    from tqdm import tqdm

    annotations = pd.read_excel(
        "Annotation_CARIMAM_apo_23_05_11.xlsx", 'annot_click_apo', usecols='A:E'
    ).dropna()
    # Exclude rows whose file path contains one of these volume prefixes.
    for excluded_volume in ('/nfs/NAS4/', '/nfs/NAS3/'):
        annotations = annotations[~annotations['File'].str.contains(excluded_volume)]
    data_root = "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/"

    dataset = AudioSignal(True, annotations, data_root, hp=True, sig_dur=5., noise_data_aug=True)
    print(f"Size of dataset : {len(dataset)}")

    spectrogram, label = dataset[20]
    plt.figure(figsize=(16, 9))
    plt.imshow(spectrogram[0], aspect='auto', interpolation=None, origin='lower')
    plt.title(f"Sig label : {label}")
    plt.show()

    for idx in tqdm(range(0, len(dataset), 10)):
        spectrogram, _ = dataset[idx]
        print(spectrogram.shape)
    print("Test is succesfull")


if __name__ == "__main__":
    main()