from math import ceil
import os
import argparse
import numpy as np
import scipy.signal as sg
import matplotlib.pyplot as plt
import soundfile as sf
from model_both import Detector, TransformerModel, Both, Context_dil, Context_dil_2D, Context_rnn_1d, Context_ViT_1d
from torchelie.loss.bitempered import tempered_softmax
import torch
from torch.multiprocessing import set_start_method
from torch.utils.data import DataLoader
import librosa as lr
from umap import UMAP

import glob
import tqdm

_SR = 256_000

def norm(x, eps=1e-10, axis=None):
    """Zero-mean / unit-variance normalisation of `x` along `axis`.

    `eps` guards against division by zero for constant input.
    """
    mu = x.mean(axis, keepdims=True)
    sigma = x.std(axis, keepdims=True)
    return (x - mu) / (sigma + eps)


class FileSequence:
    """Dataset-like view of one audio file as overlapping frame sequences.

    Each item is a (n_frames, frame_size) float32 array: the signal around
    a sequence position, split into 50%-overlapping frames, each frame
    normalised independently.
    """

    def __init__(self, filename, nb_ctxt_frame=256, hp=False):
        self.filename = filename

        self.hp = hp
        # 3rd-order Butterworth high-pass at 1500 Hz (normalised frequency wrt _SR).
        self.hp_sos = sg.butter(3, 2 * 1500 / _SR, 'hp', output='sos')
        self.nb_ctxt_frame = nb_ctxt_frame  # number of frames before and after an annotated click
        self.frame_size = 2 * int(_SR * 0.001)  # frame duration in samples (2 ms at _SR)
        self.hop_size = self.frame_size // 2

        self.sequence_duration = (self.nb_ctxt_frame * 2 * self.hop_size) / _SR

        # Read the file header ONCE; the original re-opened the file on every
        # sf.info() call (3x here, 2x per __getitem__).
        info = sf.info(self.filename)
        self.samplerate = info.samplerate
        self.file_duration = info.duration
        self.file_size = int(info.samplerate * info.duration)

        self.nb_sequences = int((self.file_duration - (self.frame_size / _SR)) / (self.sequence_duration / 2))
        self.pos_sequences = np.arange(1, self.nb_sequences - 1) * self.sequence_duration / 2
        self.nb_sequences -= 2  # exclude first and last sequence

    def __len__(self):
        return self.nb_sequences

    def __getitem__(self, item):
        pos_seq = self.pos_sequences[item]
        # Sample offsets of the sequence in the file, in the FILE's rate.
        pos_seq_start = int((pos_seq - (self.sequence_duration / 2)) * self.samplerate)
        pos_seq_end = int((pos_seq + (self.sequence_duration / 2)) * self.samplerate)

        sig, sr = sf.read(self.filename, start=pos_seq_start, stop=(pos_seq_end + self.frame_size))
        if sig.ndim > 1:
            # Multi-channel recording: keep the last channel only.
            sig = sig[:, -1]

        if sr != _SR:
            # Rational-ratio polyphase resampling to the working rate.
            g = np.gcd(_SR, sr)
            sig = sg.resample_poly(sig, int(_SR / g), int(sr / g))
            sr = _SR

        # Optional high-pass (zero-phase), then global normalisation.
        if self.hp:
            sig = sg.sosfiltfilt(self.hp_sos, sig)

        sig = norm(sig)

        # Frame the signal with 50% overlap and normalise each frame.
        seq = lr.util.frame(sig, frame_length=self.frame_size, hop_length=self.hop_size, axis=0, writeable=True)
        seq = norm(seq, axis=1)

        return seq.astype(np.float32)

def one_file(args, filename):
    """Run the detector + context model over one audio file.

    Returns (pred, emb_fin, pos_sequences): per-sequence softmax scores
    (N, 2), context embeddings (N, 32), and sequence centre times (s).
    Exits the process if the checkpoint does not match the model.
    """
    # Build the detector and the context model, wrapped so the combined
    # checkpoint can be loaded in one go.
    detector = Detector()
    if args.vit:
        context_cnn = Context_ViT_1d(input_len=args.nb_frames * 2)
    else:
        context_cnn = Context_rnn_1d()

    both = Both(detector, context_cnn)

    state_dict = torch.load(args.weight)['model']
    plt.close('all')
    try:
        both.load_state_dict(state_dict)
        both.to('cuda:' + args.device)
        both.eval()
    except RuntimeError:
        print("Error Load context weights")
        exit()

    bs = args.bs
    pred_dst = FileSequence(filename, nb_ctxt_frame=args.nb_frames, hp=args.hp)
    # (Removed stray debug access `pred_dst[10]` — it only wasted a file read.)
    pred_dl = DataLoader(pred_dst, batch_size=bs, shuffle=False, drop_last=False, num_workers=1)

    pred = np.empty((len(pred_dst), 2))
    emb_fin = np.empty((len(pred_dst), 32))

    with torch.no_grad():
        for idx, batch in enumerate(tqdm.tqdm(pred_dl)):
            # Flatten (batch, frames) so the detector sees one frame per row.
            bs_detec = batch.shape[0] * batch.shape[1]
            batch = batch.to('cuda:' + args.device)

            pred_det, emb_det = detector(batch.reshape(bs_detec, 1, batch.shape[-1]))
            pred_det = pred_det.reshape(batch.shape[0], batch.shape[1], pred_det.shape[-1])
            emb_det = emb_det.reshape(batch.shape[0], batch.shape[1], emb_det.shape[-1])
            # RNN/ViT context model consumes the per-frame detector embeddings.
            pred_ctxt, emb_ctxt = context_cnn(emb_det[:, :(args.nb_frames * 2)])

            pred[idx * bs: (idx + 1) * bs] = torch.nn.functional.softmax(pred_ctxt, 1).cpu().numpy()
            emb_fin[idx * bs: (idx + 1) * bs] = emb_ctxt.cpu().numpy()

    return pred, emb_fin, pred_dst.pos_sequences


def one_folder_naive(args):
    """Sequentially forward every file matched by args.path and save results."""
    # Rebase any known data-root prefix onto the output directory
    # (same substitutions, applied in the same order as before).
    data_roots = (
        "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/",
        "/nfs/NAS3/SABIOD/SITE/PORTNOUVELLE/",
        "/nfs/NAS3/SABIOD/SITE/ANTILLES_CCS_2021/",
        "/nfs/NAS4/SABIOD/SITE/DCLDE/DCL_HF_Data_Dev/",
        "/nfs/NAS7/SABIOD/SITE/",
        "/short/CARIMAM/DATA/",
        "/nfs/NAS6/mahe/ceta-cnns/Forward/",
    )
    finale_path = args.path
    for root in data_roots:
        finale_path = finale_path.replace(root, args.path_out)

    os.makedirs(finale_path, exist_ok=True)

    for filename in tqdm.tqdm(glob.glob(args.path)):
        pred, emb, pos_seq = one_file(args, filename)

        base = os.path.basename(filename)
        # Keep only the positive-class score; embeddings and positions as-is.
        np.savez_compressed('%s/%s_pred.npz' % (finale_path, base), pred[:, 1])
        np.savez_compressed('%s/%s_emb.npz' % (finale_path, base), emb)
        np.savez_compressed('%s/%s_pos_sequence.npz' % (finale_path, base), pos_seq)


class MultiSequences:
    """Dataset over a folder of wav files; each item forwards one whole file.

    Inference runs inside __getitem__ so it can execute in DataLoader worker
    processes (the start method must be 'spawn' for CUDA to work there).
    """

    def __init__(self, args):
        # Match both upper- and lower-case extensions.
        self.wav = glob.glob(args.path + "/*.WAV") + glob.glob(args.path + "/*.wav")

        # (The original assigned self.device twice; once is enough.)
        self.device = args.device
        self.hp = args.hp
        self.nb_frames = args.nb_frames
        self.bs = args.bs
        self.nb_workers = args.nb_workers
        plt.close('all')

        # Build the detector and the context model, wrapped so the combined
        # checkpoint can be loaded in one go.
        self.detector = Detector()
        if args.vit:
            self.context_cnn = Context_ViT_1d(input_len=args.nb_frames * 2)
        else:
            self.context_cnn = Context_rnn_1d()
        self.both = Both(self.detector, self.context_cnn)

        state_dict = torch.load(args.weight)['model']
        plt.close('all')
        try:
            self.both.load_state_dict(state_dict)
        except RuntimeError:
            # NOTE(review): unlike one_file() this does not exit() on a bad
            # checkpoint — kept as-is to preserve current behaviour.
            print("Error Load context weights")
        self.both.to('cuda:' + args.device)
        self.both.eval()

    def __len__(self):
        return len(self.wav)

    def __getitem__(self, item):
        try:
            pred_dst = FileSequence(self.wav[item], nb_ctxt_frame=self.nb_frames, hp=self.hp)

            pred = np.empty((len(pred_dst), 2))
            emb = np.empty((len(pred_dst), 32))

            # Manual batching over the file's sequences (no nested DataLoader).
            batches = (
                torch.stack([torch.from_numpy(pred_dst[k])
                             for k in range(j * self.bs, (j + 1) * self.bs)
                             if k < len(pred_dst)], 0)
                for j in range(ceil(len(pred_dst) / self.bs))
            )
            with torch.no_grad():
                for idx, batch in enumerate(batches):
                    # Flatten (batch, frames) so the detector sees one frame per row.
                    bs_detec = batch.shape[0] * batch.shape[1]
                    batch = batch.to('cuda:' + self.device)
                    pred_det, emb_det = self.detector(batch.reshape(bs_detec, 1, batch.shape[-1]))
                    pred_det = pred_det.reshape(batch.shape[0], batch.shape[1], pred_det.shape[-1])
                    emb_det = emb_det.reshape(batch.shape[0], batch.shape[1], emb_det.shape[-1])

                    # RNN/ViT context model consumes the per-frame detector embeddings.
                    pred_ctxt, emb_ctxt = self.context_cnn(emb_det[:, :(self.nb_frames * 2)])

                    pred[idx * self.bs: (idx + 1) * self.bs] = torch.nn.functional.softmax(pred_ctxt, 1).cpu().numpy()
                    emb[idx * self.bs: (idx + 1) * self.bs] = emb_ctxt.cpu().numpy()

            return pred, emb, pred_dst.pos_sequences, self.wav[item]

        except Exception as e:
            # Deliberate best-effort: report the failing file and return empty
            # results so the folder-level loop keeps going.
            print(e)
            print(self.wav[item])
            return np.empty((0, 2)), np.empty((0, 32), np.float32), np.empty((0)), self.wav[item]

def one_folder(args):
    """Forward every wav in args.path via DataLoader workers and save results."""
    seqs_ds = MultiSequences(args)
    # (Removed stray debug access `seqs_ds[0]` — it ran a full-file forward
    # for nothing.)
    bs = 1  # one file per batch element; parallelism comes from the workers
    seqs_dl = DataLoader(seqs_ds, batch_size=bs, shuffle=False, drop_last=False,
                         num_workers=args.nb_workers, prefetch_factor=args.nb_workers * 2)

    # Rebase any known data-root prefix onto the output directory
    # (same substitutions, applied in the same order as before).
    data_roots = (
        "/nfs/NAS3/SABIOD/SITE/PORTNOUVELLE/",
        "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/",
        "/nfs/NAS3/SABIOD/SITE/ANTILLES_CCS_2021/",
        "/nfs/NAS4/SABIOD/SITE/DCLDE/DCL_HF_Data_Dev/",
        "/nfs/NAS7/SABIOD/SITE/",
        "/short/CARIMAM/DATA/",
        "/nfs/NAS6/mahe/ceta-cnns/Forward/",
    )

    with torch.no_grad():
        for idx_b, batch in enumerate(tqdm.tqdm(seqs_dl)):

            batch_pred, batch_emb, batch_pos, batch_filenames = batch
            batch_pred = batch_pred.numpy().astype(np.float32)
            batch_emb = batch_emb.numpy().astype(np.float32)
            # BUGFIX: positions are fractional seconds (0.256 s steps from
            # FileSequence.pos_sequences); the old int32 cast truncated them
            # to whole seconds. The naive path saves them as floats too.
            batch_pos = batch_pos.numpy().astype(np.float32)

            for idx_elem in np.arange(batch_pred.shape[0]):

                filename = batch_filenames[idx_elem]
                for root in data_roots:
                    filename = filename.replace(root, args.path_out)

                pred = batch_pred[idx_elem]
                emb = batch_emb[idx_elem]
                pos = batch_pos[idx_elem]

                os.makedirs(os.path.dirname(filename), exist_ok=True)
                np.savez_compressed('%s_pred.npz' % (filename), pred[:, 1])
                np.savez_compressed('%s_emb.npz' % (filename), emb)
                np.savez_compressed('%s_pos_sequence.npz' % (filename), pos)


def affichage(args, pred, emb, pos_sequences):
    """Plot score-vs-time, a score histogram, and a UMAP of the embeddings."""
    plt.figure()
    plt.plot(pos_sequences, pred[:, 1])

    plt.figure()
    plt.hist(pred[:, 1], 256)
    plt.yscale('log')  # scores are heavily concentrated near 0 and 1

    plt.figure()
    reducer = UMAP(verbose=10)
    # Renamed from `map`, which shadowed the builtin.
    embedding_2d = reducer.fit_transform(emb)
    plt.scatter(embedding_2d[:, 0], embedding_2d[:, 1], 3, c=pred[:, 1], cmap='jet')
    plt.show()


def main(args):
    """Dispatch to single-file or folder processing based on args.path."""
    all_filenames = glob.glob(args.path)

    # BUGFIX: an empty glob previously crashed with IndexError on
    # all_filenames[0] in the elif; raise the intended error instead.
    if not all_filenames:
        raise FileNotFoundError(args.path)

    if len(all_filenames) == 1 and os.path.isfile(all_filenames[0]):
        pred, emb, pos_seq = one_file(args, args.path)
        if args.disp:
            affichage(args, pred, emb, pos_seq)
        out_dir = os.path.dirname(args.path_out)
        base = os.path.basename(args.path)
        np.save("%s/out_%s_pred.npy" % (out_dir, base), pred)
        np.save("%s/out_%s_emb.npy" % (out_dir, base), emb)
        np.save("%s/out_%s_pos_sequence.npy" % (out_dir, base), pos_seq)
    elif len(all_filenames) > 1 or os.path.isdir(all_filenames[0]):
        one_folder(args)
    else:
        raise FileNotFoundError(args.path)


if __name__ == '__main__':
    # CUDA inference inside DataLoader workers requires the 'spawn' method.
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description="""Forward a click detector + context model over audio files and save predictions, embeddings and sequence positions.""")
    parser.add_argument("path", type=str, help="Path or glob pattern of input file(s)")
    parser.add_argument("--path_out", default='results', type=str, help="Path to output directory")
    parser.add_argument("--weight", type=str, default='PATH.pth', help="Model weight file")
    parser.add_argument("--device", type=str, default='1', help='GPU device index')
    parser.add_argument("--nb_workers", type=int, default=8, help="Number of workers during the forward")
    parser.add_argument("--nb_frames", type=int, default=256, help="Number of context frames around an annotated frame")
    parser.add_argument("--bs", type=int, default=16, help="Batch size")
    # `--hp` is a boolean switch; the old help text wrongly implied a
    # frequency value was expected.
    parser.add_argument("--hp", action='store_true', help="Apply the high-pass filter to the signal")
    parser.add_argument("--disp", action='store_true', help="Display information about prediction")
    parser.add_argument("--vit", action='store_true', help="Use ViT context detector")

    main(parser.parse_args())
