from math import ceil
import os
import argparse
import numpy as np
import scipy.signal as sg
import matplotlib.pyplot as plt
import soundfile as sf
from tqdm import tqdm
from model_both import Detector
from torch.multiprocessing import set_start_method
from torchelie.loss.bitempered import tempered_softmax
import torch
from umap import UMAP
import glob


# Working sample rate in Hz: audio read at any other rate is resampled to this
# value before being cut into windows.
_FS = 256_000


def norm(x, eps=1e-10, axis=None):
    """Standardise ``x`` to zero mean and unit standard deviation.

    Statistics are computed along ``axis`` (over all elements when ``axis`` is
    None) with ``keepdims=True`` so they broadcast back onto ``x``. ``eps`` is
    added to the standard deviation to avoid dividing by zero on constant input.
    """
    centred = x - x.mean(axis, keepdims=True)
    spread = x.std(axis, keepdims=True) + eps
    return centred / spread


class Click:
    """Indexable set of fixed-length, normalised windows cut from one audio file.

    The file is read with soundfile, reduced to its first channel, resampled to
    ``_FS`` when its rate differs, and optionally high-pass filtered. Window
    start offsets (in samples at ``_FS``) are stored in ``self.pos``; indexing
    returns a ``(1, WIN)`` float32 array normalised with :func:`norm`.

    Args:
        file: path of the audio file to read.
        hp: when True, apply the 2.5 kHz Butterworth high-pass (``self.hp_sos``)
            with zero-phase filtering.

    Raises:
        AssertionError: if the file contains no samples.
    """

    WIN = 512  # window length in samples
    HOP = 256  # nominal spacing used to choose how many windows to cut

    def __init__(self, file, hp=False):
        self.file = file
        self.hp = hp
        # 3rd-order Butterworth high-pass at 2500 Hz, designed for the _FS rate.
        self.hp_sos = sg.butter(3, 2500, 'hp', output='sos', fs=_FS)

        sample, sr = sf.read(file)

        if sample.ndim > 1:
            sample = sample[:, 0]  # keep the first channel only
        if len(sample) == 0:
            raise AssertionError("Error 0 sample audio file")
        if sr != _FS:
            # Rational resampling by the reduced up/down factors _FS/g and sr/g.
            g = np.gcd(_FS, sr)
            sample = sg.resample_poly(sample, int(_FS // g), int(sr // g))
        if len(sample) < self.WIN:
            # Zero-pad very short files to one full window. Otherwise the offsets
            # below would go negative, yielding windows of unequal length that
            # cannot be batched, and sosfiltfilt may reject too-short input.
            sample = np.pad(sample, (0, self.WIN - len(sample)))
        if self.hp:
            sample = sg.sosfiltfilt(self.hp_sos, sample)
        self.sample = sample
        # ceil(len / HOP) evenly spaced starts from 0 to the last full window.
        self.pos = np.linspace(0, len(sample) - self.WIN,
                               ceil(len(sample) / self.HOP)).astype(int)

    def __len__(self):
        return len(self.pos)

    def __getitem__(self, item):
        start = self.pos[item]
        window = self.sample[start:start + self.WIN]
        # Leading axis of size 1 so each item is a (1, WIN) array.
        return norm(window).astype(np.float32)[None]


def _load_checkpoint(model, weight_path):
    """Load the ``'model'`` entry of the checkpoint at ``weight_path`` into ``model``.

    The checkpoint is mapped to CPU first; the caller moves the model to its
    device afterwards. If the direct load raises ``RuntimeError``, retry with a
    leading ``module.`` removed from every key that has it. Such keys presumably
    come from a wrapped (e.g. DataParallel) model. Keys without the prefix are
    left intact, so a genuine mismatch still surfaces with the real key names.
    """
    state_dict = torch.load(weight_path, map_location='cpu')['model']
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        prefix = 'module.'
        stripped = {(k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()}
        model.load_state_dict(stripped)


def one_file(args):
    """Run the detector over every window of the single audio file ``args.path``.

    The file is high-pass filtered (``Click(..., hp=True)``). When ``args.trans``
    is set, the weights in ``args.weight`` are loaded into a ``Both`` wrapper
    around the detector; otherwise they are loaded into the detector directly.
    Inference runs on ``'cuda:' + args.device``.

    Returns:
        out: ``(n_windows, 2)`` tempered-softmax outputs (second argument 2.).
        emb: ``(n_windows, 32)`` detector embeddings.
        pos: window start offsets in samples, from ``Click.pos``.
    """
    clicks = Click(args.path, hp=True)
    detector = Detector()
    if args.trans:
        from model_both import TransformerModel, Both
        model = Both(detector, TransformerModel())
    else:
        model = detector
    plt.close('all')
    _load_checkpoint(model, args.weight)
    device = 'cuda:' + args.device
    # Moving / switching the wrapper also moves / switches the detector inside it.
    model.to(device)
    model.eval()

    out = np.empty((len(clicks), 2))
    emb = np.empty((len(clicks), 32))
    bs = 4096
    loader = torch.utils.data.DataLoader(clicks, batch_size=bs, num_workers=8,
                                         prefetch_factor=32,
                                         shuffle=False, drop_last=False)
    with torch.no_grad():
        for i, batch in enumerate(tqdm(loader)):
            # Only the detector head is evaluated, even in --trans mode.
            p, e = detector(batch.to(device))
            # NOTE(review): unlike Pred.__getitem__, no local-maximum masking is
            # applied here; raw tempered-softmax values are returned.
            out[i * bs:(i + 1) * bs] = tempered_softmax(p, 2.).cpu().numpy()
            emb[i * bs:(i + 1) * bs] = e.cpu().numpy()
    return out, emb, clicks.pos

def affichage(args, out, emb,  pos_frame):
    """Show diagnostic plots for a single-file run, then block on ``plt.show()``.

    Args:
        args: parsed CLI namespace; only ``args.IIId`` is read (3-D vs 2-D UMAP).
        out: ``(n, 2)`` per-window detector outputs; column 1 is plotted.
        emb: ``(n, d)`` per-window embeddings.
        pos_frame: window start offsets in samples at ``_FS``.
    """
    click_prob = out[:, 1]

    # Detector output against time in seconds.
    plt.plot(pos_frame / _FS, click_prob)

    plt.figure()
    plt.hist(click_prob, 256)
    plt.yscale('log')

    # Embedding norms: log-spaced bins, and linear bins up to the 95th percentile.
    fig, ax = plt.subplots(1, 2)
    nrm = np.linalg.norm(emb, axis=-1)
    ax[0].hist(nrm, np.geomspace(0.01, nrm.max(), 257))
    ax[0].set_xscale('log')
    ax[1].hist(nrm, np.linspace(0, np.percentile(nrm, 95), 257))

    # Number of high-confidence windows inside a sliding ~0.1 s window. The
    # signal has one entry per analysis window, so the kernel length is 0.1 s
    # divided by the window spacing (not 0.1 s worth of raw samples).
    plt.figure()
    bool_click = click_prob > 0.98
    hop = float(np.median(np.diff(pos_frame))) if len(pos_frame) > 1 else 0.0
    kernel_len = max(1, int(round(0.1 * _FS / hop))) if hop > 0 else 1
    plt.plot(np.convolve(bool_click, np.ones(kernel_len)))

    if args.IIId:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(*UMAP(n_components=3, verbose=10).fit_transform(emb).T, s=3, c=out[:, 1], cmap='jet')
    else:
        plt.figure()
        reducer = UMAP(verbose=10)
        # Every 10th window only, to keep the 2-D projection fast.
        proj = reducer.fit_transform(emb[::10])
        plt.scatter(proj[:, 0], proj[:, 1], 3, c=out[::10, 1], cmap='jet')

        # Compare results vs. annotations
        # import pandas as pd
        # clicks = Click(args.path)
        # df = pd.read_excel('Annotation_CARIMAM_apo_22_07_11.xlsx', 'annot_click_apo', usecols='A:E').dropna()
        # df = df[df['File'].str.startswith('LOT2/BON_20210425_20210521/20210515_075153')]
        # df_positif = df[df['positif_negatif'] != 'n']
        # v_pos = df_positif.pos_start
        # v_pos = ((v_pos*_FS) / 256).astype(int)
        # emb_pos = emb[v_pos]

        # # Get positions, convert them to window indices
        # # and look at where those points fall in the projected space
        # map_pos = reducer.transform(emb_pos)
        # plt.scatter(map_pos[:,0], map_pos[:,1], 3, c='green')

    plt.show()



class Pred:
    """Indexable set over the ``.wav`` files of a folder; each item is one file's results.

    The detector is built and loaded in ``__init__`` and evaluated inside
    ``__getitem__``. When this object is wrapped in a multi-worker DataLoader,
    inference therefore happens inside the worker processes.

    Args:
        args: parsed CLI namespace; only ``args.trans`` is read here.
        folder: directory whose ``*.wav`` / ``*.WAV`` files are processed, sorted.
        weight: checkpoint path; its ``'model'`` entry holds the state dict.
        device: CUDA device index as a string (used as ``'cuda:' + device``).
        hp: forwarded to ``Click`` to enable the high-pass filter.
    """

    def __init__(self, args, folder, weight, device, hp=False):
        self.hp = hp
        self.folder = folder
        self.wav = sorted(f for f in glob.glob(folder + '/*') if f.lower().endswith('.wav'))
        self.detector = Detector()
        self.device = device
        plt.close('all')

        if args.trans:
            from model_both import TransformerModel, Both
            model = Both(self.detector, TransformerModel())
        else:
            model = self.detector
        self._load_weights(model, weight)
        # Moving / switching the wrapper also moves / switches self.detector.
        model.to('cuda:' + device)
        model.eval()

    @staticmethod
    def _load_weights(model, weight):
        """Load ``torch.load(weight)['model']`` into ``model``, retrying without ``module.`` prefixes."""
        state_dict = torch.load(weight, map_location='cpu')['model']
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            prefix = 'module.'
            model.load_state_dict({(k[len(prefix):] if k.startswith(prefix) else k): v
                                   for k, v in state_dict.items()})

    def __len__(self):
        return len(self.wav)

    def __getitem__(self, item):
        """Return ``(out, emb, pos, path)`` for the ``item``-th file.

        ``out`` is ``(n, 2)`` float32 tempered-softmax output, zeroed wherever a
        value is not a local maximum over neighbouring windows within a batch.
        ``emb`` is ``(n, 32)`` float32. ``pos`` holds the window offsets. On any
        error the message and path are printed and empty arrays are returned,
        so one bad file does not stop a folder run.
        """
        path = self.wav[item]
        try:
            clicks = Click(path, hp=self.hp)
            n = len(clicks)
            out = np.empty((n, 2), dtype=np.float32)
            emb = np.empty((n, 32), dtype=np.float32)
            bs = 4096
            peak_pool = torch.nn.MaxPool1d(3, stride=1, padding=1)
            with torch.no_grad():
                for start in range(0, n, bs):
                    stop = min(start + bs, n)
                    batch = torch.stack([torch.from_numpy(clicks[k]) for k in range(start, stop)], 0)
                    p, emb_batch = self.detector(batch.to('cuda:' + self.device))
                    p = tempered_softmax(p, 2.)
                    # Transposed to (2, batch) so the pool runs along the window axis.
                    mask = (p.T >= peak_pool(p.T)).T
                    p = mask * p
                    out[start:stop] = p.cpu().numpy()
                    emb[start:stop] = emb_batch.cpu().numpy()
            return out, emb, clicks.pos, path
        except Exception as e:
            print(e)
            print(path)
            return np.empty((0, 2)), np.empty((0, 32), np.float32), np.empty((0)), path


def one_folder(args):
    """Run the detector on every ``.wav`` file of the folder ``args.path``.

    Results go to a directory that mirrors ``args.path`` with the known CARIMAM
    data root replaced by ``args.out``. For each input file ``<name>`` three
    compressed archives are written there:

    * ``<name>_pred.npz``: column 1 of the per-window output from ``Pred``;
    * ``<name>_emb.npz``: embeddings of the windows whose value is >= 0.5;
    * ``<name>_pos_frame.npz``: window start offsets.
    """
    if (not args.erase) and (os.path.isfile(args.out + '_pred.npy') or os.path.isfile(args.out + '_hist.npy')):
        # NOTE(review): these legacy ``<out>_pred.npy`` / ``<out>_hist.npy`` names
        # are not what this function writes (per-file ``.npz`` below), so this
        # guard does not detect a previous folder run — confirm intent.
        print("Output already exist, skipping")
        return

    # os.path.join(out, '') guarantees a trailing separator, so "results"
    # yields "results/LOT2/..." instead of "resultsLOT2/...".
    out_root = os.path.join(args.out, '')
    path = args.path.replace("/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/", out_root)
    path = path.replace("/short/CARIMAM/DATA/", out_root)
    # NOTE(review): if args.path lies under neither root, ``path`` is args.path
    # itself and results are written next to the input audio.
    os.makedirs(path, exist_ok=True)

    pred_ds = Pred(args, args.path, args.weight, args.device, hp=True)
    loader_kwargs = dict(batch_size=1, num_workers=args.nb_workers,
                         shuffle=False, drop_last=False)
    if args.nb_workers > 0:
        # DataLoader rejects prefetch_factor when it runs without worker processes.
        loader_kwargs['prefetch_factor'] = args.nb_workers * 4
    loader = torch.utils.data.DataLoader(pred_ds, **loader_kwargs)

    with torch.no_grad():
        for batch_pred, batch_emb, batch_pos, batch_filenames in tqdm(loader):
            batch_pred = batch_pred.numpy().astype(np.float32)
            batch_emb = batch_emb.numpy().astype(np.float32)
            batch_pos = batch_pos.numpy().astype(np.int32)

            for pred, emb, pos, src in zip(batch_pred, batch_emb, batch_pos, batch_filenames):
                stem = os.path.join(path, os.path.basename(src))
                np.savez_compressed(stem + '_pred.npz', pred[:, 1])
                np.savez_compressed(stem + '_emb.npz', emb[pred[:, 1] >= 0.5])
                np.savez_compressed(stem + '_pos_frame.npz', pos)

def main(args):
    """Dispatch on ``args.path``: a single file or a folder of ``.wav`` files.

    ``args.path`` is expanded with ``glob``. If it resolves to exactly one
    regular file, that file is processed with ``one_file``, optionally plotted,
    and the outputs are saved as ``.npy`` files in ``args.out``. If the first
    match is a directory, ``one_folder`` handles it.

    Raises:
        FileNotFoundError: when ``args.path`` matches neither case.
    """
    set_start_method('spawn')

    list_files = glob.glob(args.path)
    if len(list_files) == 1 and os.path.isfile(list_files[0]):
        out, emb, pos_frame = one_file(args)
        if args.disp:
            affichage(args, out, emb, pos_frame)
        if args.out:
            # np.save does not create missing directories.
            os.makedirs(args.out, exist_ok=True)
        stem = os.path.basename(args.path)
        np.save("%s/out_%s_pred.npy" % (args.out, stem), out)
        np.save("%s/out_%s_emb.npy" % (args.out, stem), emb)
        np.save("%s/out_%s_pos_frame.npy" % (args.out, stem), pos_frame)

    elif list_files and os.path.isdir(list_files[0]):
        one_folder(args)
    else:
        raise FileNotFoundError(args.path)


def build_parser():
    """Return the command-line parser of this script (defaults shown in --help)."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                     description="Run the click detector on one audio file or on "
                                                 "every .wav file of a folder, and save per-window "
                                                 "predictions, embeddings and positions.")
    parser.add_argument("path", type=str, help="Path to an audio file or to a folder of .wav files")
    parser.add_argument("--out", default='results', type=str, help="Output directory")
    parser.add_argument("--weight", type=str, default='best_acc_updimv2_3dlong.pth', help="Model weight")
    parser.add_argument("--trans", action='store_true', help='Load weight from transformer')
    parser.add_argument("--device", type=str, default='0', help="CUDA device index (used as 'cuda:<device>')")
    parser.add_argument("--disp", action='store_true', help='Enable / disable plot')
    parser.add_argument("--IIId", action='store_true', help='Plot embs in 3D')
    parser.add_argument("--erase", action='store_true', help="If out file exist and this option is not given,"
                                                             " the computation will be halted")
    parser.add_argument("--nb_workers", type=int, default=8, help="Number of workers during the forward")
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())
