import torch
from torch.utils.data import Dataset
import scipy.signal as sg
import numpy as np
import soundfile as sf
import glob
import mdct
import os

class AudioDataset(Dataset):
    """Dataset of paired clean / noisy audio files.

    Clean files are collected from ``<root_dir>/clean/*.wav`` and noisy files
    from ``<root_dir>/noisy_25/*.wav``. Both lists are sorted by file name so
    that a given index designates the same file name in both lists.
    """

    def __init__(self, root_dir, transform=None):
        """
        Initializes the dataset
        Args:
            root_dir (string): Directory with all the audio.
            transform (callable, optional): Optional transform to be applied on a data.
        """
        self.root_dir = root_dir
        # glob.glob gives no ordering guarantee: sort both lists by file name
        # so that clean_audiofile[i] and noisy_audiofile[i] form a pair.
        self.clean_audiofile = sorted(glob.glob(root_dir + "/clean/*.wav"),
                                      key=os.path.basename)
        self.noisy_audiofile = sorted(glob.glob(root_dir + "/noisy_25/*.wav"),
                                      key=os.path.basename)

        # Check that the audio files really form pairs. Problems are only
        # reported (not raised) so that a partially valid dataset stays usable.
        if len(self.clean_audiofile) != len(self.noisy_audiofile):
            print("ERROR - number of clean and noisy files differ: %d %d"
                  % (len(self.clean_audiofile), len(self.noisy_audiofile)))
        for filename, noisy_filename in zip(self.clean_audiofile, self.noisy_audiofile):
            if os.path.basename(filename) != os.path.basename(noisy_filename):
                print("ERROR - error loading file: %s"%(filename))

        self.transform = transform

    def __len__(self):
        """Return the number of clean audio files found."""
        return len(self.clean_audiofile)

    def __getitem__(self, idx_f):
        """
        Args:
            idx_f (int): index of audio file in the dataset.
        Return:
            data (dict): dictionary containing the pair of signals under the
                keys 'clean' and 'noisy', after ``transform`` if one was given
        """
        clean_sig, clean_sr = sf.read(self.clean_audiofile[idx_f])
        noisy_sig, noisy_sr = sf.read(self.noisy_audiofile[idx_f])

        # Inconsistent pairs are reported but still returned (best effort).
        if clean_sr != noisy_sr:
            print("ERROR - file don't have the same sampling rate : %s %s"%(self.clean_audiofile[idx_f], self.noisy_audiofile[idx_f]))
        if len(clean_sig) != len(noisy_sig):
            print("ERROR - file don't have the same length : %s %s"%(self.clean_audiofile[idx_f], self.noisy_audiofile[idx_f]))

        data = {'clean': clean_sig, 'noisy': noisy_sig}
        if self.transform:
            data = self.transform(data)

        return data


class Mdct(object):
    """Transform applying ``mdct.mdct`` to both signals of a clean / noisy pair."""

    def __init__(self, n_mdct=2**8):
        """
        Args:
            n_mdct (int): value forwarded as ``framelength`` to ``mdct.mdct``
        """
        self.n_mdct = n_mdct

    def __call__(self, data):
        """
        Compute the Modified Discrete Cosine Transform of the two signals
        Args:
            data (dict): dictionary containing the pair of signals ('clean' and 'noisy')
        Return:
            data (dict): dictionary containing the pair of transforms ('clean' and 'noisy')
        """
        clean_in = data['clean']
        noisy_in = data['noisy']
        frame_len = self.n_mdct

        # The clean signal is transformed first, then the noisy one.
        transformed = {}
        transformed['clean'] = mdct.mdct(clean_in, framelength=frame_len)
        transformed['noisy'] = mdct.mdct(noisy_in, framelength=frame_len)
        return transformed

class Imdct(object):
    """Transform applying ``mdct.imdct`` to both entries of a clean / noisy pair."""

    def __init__(self, n_mdct=2**8):
        """
        Args:
            n_mdct (int): value forwarded as ``framelength`` to ``mdct.imdct``
        """
        self.n_mdct = n_mdct

    def __call__(self, data):
        """
        Compute the Inverse Modified Discrete Cosine Transform of the two inputs
        Args:
            data (dict): dictionary containing the pair of transforms ('clean' and 'noisy')
        Return:
            data (dict): dictionary containing the pair of signals ('clean' and 'noisy')
        """
        clean_in = data['clean']
        noisy_in = data['noisy']
        frame_len = self.n_mdct

        # The clean entry is inverted first, then the noisy one.
        restored = {}
        restored['clean'] = mdct.imdct(clean_in, framelength=frame_len)
        restored['noisy'] = mdct.imdct(noisy_in, framelength=frame_len)
        return restored


class RandomCrop(object):
    """Transform cropping a clean / noisy pair to a fixed number of samples."""

    def __init__(self, output_duration, sr):
        """
        Args:
            output_duration (number): duration of the cropped signals
                (presumably in seconds when sr is in Hz)
            sr (int): sampling rate of the signals
        """
        self.crop_nb_sample = int(output_duration * sr)

    def __call__(self, data):
        """
        Crop randomly the signals. The same crop is applied to the two signals.
        Args:
            data (dict): dictionary containing the pair of signals ('clean' and 'noisy')
        Return:
            data (dict): dictionary containing the pair of cropped signals ('clean' and 'noisy')
        Raises:
            ValueError: if the clean signal has fewer samples than the crop length
        """
        clean_sig, noisy_sig = data['clean'], data['noisy']

        # NOTE(review): the length is read on the last axis while the slice
        # below is taken on the first one, so this assumes 1-D signals.
        nb_sample = clean_sig.shape[-1]
        if nb_sample < self.crop_nb_sample:
            raise ValueError("signal too short to be cropped: %d samples < %d samples"
                             % (nb_sample, self.crop_nb_sample))

        # The upper bound of randint is exclusive: +1 makes the last valid
        # start reachable and allows a signal of exactly the crop length.
        beg = np.random.randint(0, nb_sample - self.crop_nb_sample + 1)
        end = beg + self.crop_nb_sample

        return {'clean': clean_sig[beg:end], 'noisy': noisy_sig[beg:end]}

class FadeInOut(object):
    """Transform applying a linear fade in and fade out to a clean / noisy pair."""

    def __init__(self, fade_duration, sr):
        """
        Args:
            fade_duration (number): duration of the fades
            sr (int): sampling rate of the signals
        """
        self.fade_nb_sample = int(sr * fade_duration)
        self.fade_in = np.linspace(0, 1, self.fade_nb_sample)
        self.fade_out = np.linspace(1, 0, self.fade_nb_sample)

    def __call__(self, data):
        """
        Apply a fade in and a fade out to the signals. The same fades are applied to the two signals.
        The signals are modified in place and returned.
        Args:
            data (dict): dictionary containing the pair of signals ('clean' and 'noisy')
        Return:
            data (dict): dictionary containing the pair of faded signals ('clean' and 'noisy')
        """
        clean_sig, noisy_sig = data['clean'], data['noisy']

        nb_fade = self.fade_nb_sample
        # With a zero-length fade, sig[-0:] would select the whole signal and
        # the multiplication by an empty ramp would fail: skip the fades.
        if nb_fade > 0:
            for sig in (clean_sig, noisy_sig):
                sig[:nb_fade] *= self.fade_in
                sig[-nb_fade:] *= self.fade_out

        return {'clean': clean_sig, 'noisy': noisy_sig}

class ToTensor(object):
    """Transform converting a clean / noisy pair of numpy arrays to float tensors."""

    def __call__(self, data):
        """
        Convert numpy arrays to torch tensors
        Args:
            data (dict): dictionary containing the pair of numpy.array signals ('clean' and 'noisy')
        Return:
            data (dict): dictionary containing the pair of torch.tensor signals, cast with ``.float()``
        """
        clean_arr = data['clean']
        noisy_arr = data['noisy']

        clean_tensor = torch.from_numpy(clean_arr).float()
        noisy_tensor = torch.from_numpy(noisy_arr).float()

        return {'clean': clean_tensor, 'noisy': noisy_tensor}
