import torch
from torch import nn, tensor
#import torchaudio
from torch.utils import data
from scipy import signal
import numpy as np
import wave
import socket
import sys
import soundfile as sf
import pandas as pd
import glob

EPS = torch.finfo(torch.float32).eps
folder = "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/"  #'' # REF '/nfs/NAS5/SABIOD/SITE/BOMBYX/'
folder_noise = "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/DATA_CLEAN/*/" #'/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA_CLEAN/ceta-cnns/Train/wav-mono/' #REF '/nfs/NAS3/SABIOD/SITE/BOMBYX_MONACO_2022-07/wav/'

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, x):
        return x.view(x.shape[0], -1)

def norm(arr):
    # standardize to zero mean and unit standard deviation (EPS guards against silent clips)
    return (arr - np.mean(arr)) / (np.std(arr) + EPS)

class Dataset(data.Dataset):
    def __init__(self, df, test_bool=False, filename=False, white_noise=False, pink_noise=False, rnd_shift=False, mirror=False, brown_noise=False, reef_noise=False, fe=50000, norm=True, int16=False, sampleDur=2.5):
        super(Dataset, self).__init__()
        self.df = df                          # annotations: one row per sample, with wavpath, time and annot columns
        self.filename = filename              # if True, __getitem__ also returns the wav path
        self.rng = np.random.RandomState(42)
        self.pink_noise = pink_noise          # augmentation: add pink noise
        self.white_noise = white_noise        # augmentation: add white noise
        self.brown_noise = brown_noise        # augmentation: add brown noise
        self.rnd_shift = rnd_shift            # augmentation: random circular time shift
        self.mirror = mirror                  # augmentation: random time reversal
        self.fe = fe                          # target sample rate (Hz)
        self.sampleDur = sampleDur            # clip duration (s)
        self.norm = norm                      # standardize each clip (zero mean, unit std)
        self.int16 = int16                    # read wav files as int16 instead of int32

        self.reef_noise = reef_noise          # augmentation: add real background noise from recordings
        if self.reef_noise:
            noise_files = np.array(glob.glob(folder_noise + '*.wav'))
            # split the noise files deterministically: even indices for train, odd indices for test
            if not test_bool:
                self.list_noise_files = noise_files[np.arange(0, len(noise_files), 2)]
            else:
                self.list_noise_files = noise_files[np.arange(1, len(noise_files), 2)]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        fs = sf.info(folder + row.wavpath).samplerate
        clickpad = int(self.sampleDur * fs)  # clip length in samples at the file's native rate

        # read a sampleDur-long window centred on the annotated time
        sig, fs = sf.read(folder + row.wavpath, start=max(0, int(fs*row.time)-clickpad//2),
                          stop=int(fs*row.time)+clickpad//2, dtype='int16' if self.int16 else 'int32')

        # scale integer samples to floats in [-1, 1)
        sig = sig / (2**15 if self.int16 else 2**31)

        if sig.ndim > 1:
            sig = sig[:,0]  # keep the first channel of multi-channel recordings
            #sig = sig[:,np.random.randint(sig.shape[1])]

        if self.norm:
            sig = norm(sig)
        
        if len(sig) < clickpad:
            # zero-pad clips that were cut short at the start or end of the file
            sig = np.concatenate([sig, np.zeros(clickpad - len(sig))])
        if fs != self.fe:
            sig = signal.resample(sig, int(self.sampleDur*self.fe))  # resample to the target rate
        
        if self.reef_noise:
            # pick a random background-noise recording and a random start position in it
            noise_file = self.list_noise_files[np.random.randint(len(self.list_noise_files))]

            info_noise = sf.info(noise_file)
            fs_noise = info_noise.samplerate
            dur_noise = info_noise.duration
            
            len_noise = int(dur_noise * fs_noise)
            rand_pos = np.random.randint(len_noise - int(self.sampleDur*fs_noise))
            
            noise, fs_noise = sf.read(noise_file, start=rand_pos, stop=rand_pos+int(self.sampleDur*fs_noise), dtype= 'int16' if self.int16 else 'int32')
            noise = noise / (2**15 if self.int16 else 2**31)

            if noise.ndim > 1:
                noise = noise[:, 0]
            if fs_noise != self.fe:
                noise = signal.resample(noise, int(self.sampleDur*self.fe))
                #resample_poly is faster but does not work with integer signal
                #noise = signal.resample_poly(noise, int(self.fe/np.gcd(self.fe, fs_noise)), int(fs_noise/ np.gcd(self.fe, fs_noise)))

            
            noise = norm(noise)
            if not self.norm:
                # match the noise statistics to the (unnormalized) signal
                noise = (noise * np.std(sig)) + np.mean(sig)

            # mix in the noise with a random gain drawn between -40 and +5 dB
            gain = 10**(np.random.uniform(low=-40., high=5.)/10)
            sig += gain * noise


        if self.mirror:
            sig = np.flip(sig) if np.random.random() > .5 else sig  # random time reversal
        if self.rnd_shift:
            # circular shift by a random offset within the clip
            shift = np.random.randint(len(sig))
            sig = np.concatenate([sig[shift:], sig[:shift]])
        
        if self.pink_noise:
            noise = pink_noise(len(sig), self.rng)  # Voss-McCartney pink noise (see pink_noise below)
        if self.brown_noise:
            noise = np.cumsum(np.random.normal(0, 1, len(sig)))  # brown noise: integrated white noise
        if self.white_noise:
            noise = np.random.normal(0, 1.41, len(sig))  # 3 dB white noise
        if self.brown_noise or self.white_noise or self.pink_noise:
            noise = noise / (2**15 if self.int16 else 2**31)
            noise = norm(noise)
            if not self.norm:
                # match the noise statistics to the (unnormalized) signal
                noise = (noise * np.std(sig)) + np.mean(sig)

            # mix in the noise with a random gain drawn between -40 and +5 dB
            gain = 10**(np.random.uniform(low=-40., high=5.)/10)
            sig += gain * noise
        
        # re-standardize after augmentation
        if self.norm:
            sig = norm(sig)
        
        # binary label: 1 if the row is annotated 'cachcach', 0 otherwise
        if self.filename:
            return tensor(sig).float(), int(row.annot=='cachcach'), row.wavpath
        else:
            return tensor(sig).float(), int(row.annot=='cachcach')
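
# Example usage (sketch): build a DataLoader from an annotation DataFrame. The
# column names `wavpath`, `time` and `annot` follow the accesses in __getitem__
# above; the CSV file name and loader settings are only assumptions.
#
#   df = pd.read_csv('annotations.csv')
#   trainset = Dataset(df, reef_noise=True, fe=50000, sampleDur=2.5)
#   loader = data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)
#   for sigs, labels in loader:
#       ...  # sigs: (batch, sampleDur*fe) float tensor, labels: 0/1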


class Updim(nn.Module):
    def __init__(self, *outshape):
        super(Updim, self).__init__()
        self.outshape = outshape
    def forward(self, x):
        return x.view(x.shape[0], *self.outshape)

def lossfun(pred, target):
    # cross-entropy with one-hot (or soft) targets
    return torch.mean(torch.sum(- target * nn.functional.log_softmax(pred, 1), 1))

def wlossfun(pred, target, cuda0):
    # class-balanced cross-entropy: each sample's loss is divided by the frequency of its class in the batch
    wpos = target.argmax(1).sum().float() / len(target)  # fraction of positives in the batch
    weights = torch.ones(len(target)) * wpos
    weights[~(target.argmax(1)==1)] = 1 - wpos
    preloss = torch.sum(-target * nn.functional.log_softmax(pred, 1), 1)
    return (preloss / weights.cuda(cuda0)).sum() / (2 * len(target))
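
# Example (sketch): both losses expect one-hot (or soft) targets of shape (batch, 2).
# The tensors below are made up for illustration; wlossfun moves its weights to
# GPU `cuda0`, so it needs a CUDA device.
#
#   pred = torch.randn(8, 2)                                         # raw logits
#   target = nn.functional.one_hot(torch.randint(2, (8,)), 2).float()
#   l = lossfun(pred, target)                                        # plain cross-entropy
#   l = wlossfun(pred.cuda(0), target.cuda(0), 0)                    # class-balanced, on cuda:0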

class PrintShape(nn.Module):
    def __init__(self, msg):
        super(PrintShape, self).__init__()
        self.msg = msg
    def forward(self, x):
        print(self.msg+' shape : ', x.shape)
        return x


def PrintModel(model, inlength=22050, indata=None):
    # forward a dummy (or provided) input through each layer of a sequential model,
    # printing the output shape whenever it changes
    x = tensor(np.arange(inlength)).view(1, 1, -1).float() if indata is None else indata
    print('in shape : ',x.shape, '\n')
    prevshape = x.shape
    for layer in model:
        print(layer)
        x = layer(x)
        if x.shape != prevshape:
            print('Outputs : ',x.shape)
            prevshape = x.shape
        print()
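
# Example (sketch): inspect how shapes evolve through a nn.Sequential model.
# The model below is made up for illustration.
#
#   model = nn.Sequential(nn.Conv1d(1, 8, 9, stride=4), nn.ReLU(), Flatten())
#   PrintModel(model, inlength=22050)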


class fftweightsModule(nn.Module):
    def __init__(self, weights):
        super(fftweightsModule, self).__init__()
        self.weights = nn.Parameter(tensor(weights).float(), requires_grad=False)
    def forward(self, x):
        return torch.matmul(self.weights, x)

class GammaSpec(nn.Module):
    # Gammatone-like spectrogram front-end. NOTE: the STFT and filterbank-weight
    # submodules below are currently commented out, so forward() is non-functional
    # until they (and their dependencies, e.g. gammatone.fftweight and an STFT layer)
    # are restored.
    def __init__(self, fs, winsize, hopsize, nfilts, fmin):
        super(GammaSpec, self).__init__()
        self.nfft = int(2 ** (np.ceil(np.log2(2 * winsize))))
#        self.add_module('weights', fftweightsModule(fftweight.fft_weights(self.nfft, fs, nfilts, 1, fmin, fs/2, self.nfft/2 +1)[0]))
#        nwin, nhop, _ = gtgram.gtgram_strides(fs, winsize/fs, hopsize/fs, 0)
        #acthalflen = int(np.floor(min(self.nfft//2, self.nwin//2)))
        #halfwin = 0.5 * ( 1 + np.cos(np.pi * np.arange(0, self.nwin//2)/self.nwin//2))
        #self.win = torch.Tensor(np.append(np.flip(halfwin), halfwin))
#        self.stft = STFT(n_fft=self.nfft, hop_length=hopsize, sr=fs)

    def forward(self, x):
#        torchaudio alternative: spec = spectrogram(x, 0, self.win, self.nfft, self.nhop, self.nwin, None, False)
        spec = self.stft(x)                     # requires self.stft to be defined in __init__
#        spec = (spec**2).sum(dim=-1)**.5 # absolute value of complex
        ret = self.weights(spec) / self.nfft    # requires the 'weights' submodule (fftweightsModule)
        return ret

def pink_noise(size, rng, ncols=16, axis=-1):
    """Generates pink noise using the Voss-McCartney algorithm.

    size: either a tuple of int or an int. If an int : number of sample to generate. If a tuple: shape of the return array.
    ncols: number of random sources to add. Should be high enough so that num_samples*0.5**(ncols-2) is near zero.
    axis: axis which contains the sound samples. Generate white noise otherwise.

    returns: NumPy array of shape size
    """
    if not isinstance(size, tuple):
        size = (size,)
    array = rng.rand(*size)
    assert -len(size) <= axis < len(size)
    axis %= len(size)
    axis += 1
    # for each sample, draw which of the ncols random sources changes (geometrically distributed)
    cols = rng.geometric(0.5, size)
    cols[cols >= ncols] = 0
    cols = (1.*(np.arange(1,ncols).reshape((-1,) + len(size)*(1,)) == cols)).swapaxes(axis,-1)
    cols[...,0] = 1.
    cols = np.cumsum(cols).reshape(cols.shape).astype(int).swapaxes(axis,-1)
    array = np.concatenate([array[np.newaxis],rng.rand(cols.max()+1)[cols]],axis=0).sum(0)
    return array
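
# Example (sketch): generate 2.5 s of pink noise at 50 kHz with a reproducible RNG.
#
#   rng = np.random.RandomState(0)
#   noise = pink_noise(int(2.5 * 50000), rng)
#   noise = norm(noise)   # standardize before mixing into a signal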


class depthwise_separable_conv2d(nn.Module):
    # depthwise convolution (one filter per input channel) followed by a 1x1 pointwise convolution
    def __init__(self, nin, nout, kernel, padding=0, stride=1):
        super(depthwise_separable_conv2d, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

class Quat_depthwise_separable_conv1d(nn.Module):
    # depthwise-separable 1D convolution with an optional quantized variant
    # (quat=True swaps in torch's quantized Conv1d / ConvReLU1d modules)
    def __init__(self, nin, nout, kernel, padding=0, stride=1, quat=False, BN=True):
        super(Quat_depthwise_separable_conv1d, self).__init__()
        self.quat = quat
        convtype = nn.Conv1d if not self.quat else torch.nn.quantized.modules.conv.Conv1d
        self.depthwise = convtype(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin)
        convtype = nn.Conv1d if not self.quat else nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d
        self.pointwise = convtype(nin, nout, kernel_size=1)
        self.bn = nn.BatchNorm1d(nout) if not quat and BN else nn.Identity()
        self.relu = nn.ReLU() if not quat else nn.Identity()
        # quant/dequant stubs; no-ops until addquantize() is called on a converted model
        self.depthwise_quantize, self.pointwise_quantize, self.DeQuantize = [nn.Identity()] * 3

    def addquantize(self):
        self.depthwise_quantize = nn.quantized.Quantize(self.depthwise.scale, self.depthwise.zero_point, dtype=torch.quint8)
        self.pointwise_quantize = nn.quantized.Quantize(self.pointwise.scale, self.pointwise.zero_point, dtype=torch.quint8)
        self.DeQuantize = nn.quantized.DeQuantize()

    def fuse_module(self):
        torch.quantization.fuse_modules(self, [['pointwise', 'bn', 'relu']], inplace=True)

    def forward(self, x):
        out = self.depthwise_quantize(x)
        out = self.depthwise(out)
        out = self.DeQuantize(out)
        out = self.pointwise_quantize(out)
        out = self.pointwise(out)
        out = self.DeQuantize(out)
        out = self.bn(out)
        out = self.relu(out)
        return out

class depthwise_separable_conv1d(nn.Module):
    # depthwise convolution (one filter per input channel) followed by a 1x1 pointwise convolution
    def __init__(self, nin, nout, kernel, padding=0, stride=1):
        super(depthwise_separable_conv1d, self).__init__()
        self.depthwise = nn.Conv1d(nin, nin, kernel_size=kernel, padding=padding, stride=stride, groups=nin)
        self.pointwise = nn.Conv1d(nin, nout, kernel_size=1)
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
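
# Example (sketch): a small raw-waveform classifier built from these blocks. The
# layer sizes are made up for illustration.
#
#   model = nn.Sequential(
#       nn.Conv1d(1, 32, 9, stride=4), nn.BatchNorm1d(32), nn.ReLU(),
#       depthwise_separable_conv1d(32, 64, 9, stride=4), nn.BatchNorm1d(64), nn.ReLU(),
#       nn.AdaptiveMaxPool1d(1), Flatten(), nn.Linear(64, 2))
#   out = model(torch.randn(8, 1, int(2.5 * 50000)))   # (batch, 2) logits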


def addquantize(model):
    # wrap the quantized submodules of a converted model with explicit Quantize/DeQuantize steps
    new = []
    for l in model:
        if isinstance(l, Quat_depthwise_separable_conv1d):
            l.addquantize()
            new.append(l)
        elif isinstance(l, (torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d, torch.nn.quantized.modules.conv.Conv1d)):
            new.extend([nn.quantized.Quantize(l.scale, l.zero_point, torch.quint8), l, nn.quantized.DeQuantize()])
        else:
            new.append(l)
    return nn.Sequential(*new)
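
# Example (sketch): one possible eager-mode post-training quantization flow around
# these helpers. The exact recipe is not shown in this file, so the steps below are
# an assumption, not the authors' documented workflow.
#
#   float_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
#   for m in float_model:
#       if isinstance(m, Quat_depthwise_separable_conv1d):
#           m.fuse_module()                       # fuse pointwise + bn + relu before preparing
#   prepared = torch.quantization.prepare(float_model)
#   ...                                           # run calibration batches through `prepared`
#   quantized = torch.quantization.convert(prepared)
#   quantized = addquantize(quantized)            # insert explicit Quantize/DeQuantize steps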


class Dropout1d(nn.Module):
    # channel-wise dropout for 1D feature maps: adds a dummy spatial dimension so
    # that nn.Dropout2d zeroes out entire channels
    def __init__(self, pdropout):
        super(Dropout1d, self).__init__()
        self.dropout = nn.Dropout2d(pdropout)
    def forward(self, x):
        x = x.unsqueeze(-1)
        x = self.dropout(x)
        return x.squeeze(-1)


class Reshape(nn.Module):
    def __init__(self, *shape):
        super(Reshape, self).__init__()
        self.shape = shape
    def forward(self, x):
        return x.view(x.shape[0], *self.shape)

class SpecNorm(nn.Module):
    def __init__(self):
        super(SpecNorm, self).__init__()
    def forward(self, x): # standardize each frequency bin of each sample (mean and std over the time dimension)
        return (x - x.mean(dim=2, keepdim=True)) / (x.std(dim=2, keepdim=True) + EPS)
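
# Example (sketch): per-bin standardization of a batch of spectrograms. The shape
# below is made up; SpecNorm expects (batch, freq_bins, time_frames).
#
#   spec = torch.rand(8, 64, 128)
#   spec = SpecNorm()(spec)   # zero mean / unit std along time for each frequency bin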
