import os
from math import ceil
import argparse
import sys
import time

import numpy as np
from matplotlib.pyplot import close
import pandas as pd
import soundfile as sf
import scipy.signal as sg
from model_both import Detector, TransformerModel, Both
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
from torchelie.loss.bitempered import tempered_softmax
import torch
from torch.utils.data import DataLoader
from torchelie.recipes import TrainAndTest

# Common sample rate (Hz): recordings at any other rate are resampled to it,
# and the 1500 Hz high-pass filters below are designed for it.
_FS = 256_000

def norm(x, eps=1e-10, axis=None):
    """Standardize ``x`` to zero mean and unit standard deviation.

    Args:
        x: numpy array to normalize.
        eps: added to the standard deviation to avoid dividing by zero.
        axis: axis (or axes) along which mean/std are computed; ``None``
            uses the whole array.

    Returns:
        Array of the same shape as ``x``.
    """
    centered = x - x.mean(axis, keepdims=True)
    scale = x.std(axis, keepdims=True) + eps
    return centered / scale


class Click:
    """Dataset yielding one short audio window around each annotated click.

    Each row of ``df`` provides the audio file (column ``'File '``, joined to
    ``data_path`` unless it starts with ``/nfs``), an interval
    ``pos_start``/``pos_end`` in seconds and a ``positif_negatif`` tag; rows
    tagged anything other than ``'n'`` get label 1, the others label 0.
    """

    def __init__(self, train, rank, df, data_path, hp=False):
        self.train = train  # True enables augmentation and random cropping
        self.rank = rank  # stored only; not used by this class
        self.df = df
        self.data_path = data_path
        self.label = np.array(df.positif_negatif != 'n').astype(int)
        self.hp = hp
        # 3rd-order 1500 Hz high-pass, designed for audio at the _FS rate.
        self.hp_sos = sg.butter(3, 2 * 1500 / _FS, 'hp', output='sos')

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        """Return ``(window, label)``; ``window`` is float32 of shape (1, <=512)."""
        row = self.df.iloc[item]
        s, e = row[['pos_start', 'pos_end']]
        # A random instant inside the annotated interval is used (in every mode).
        p = s if s == e else np.random.uniform(s, e)
        f = row['File ']
        f = f if f.startswith('/nfs') else os.path.join(self.data_path, f)
        p = int(sf.info(f).samplerate * p)
        # Clamp the start: soundfile treats a negative start as an offset from
        # the end of the file, which would corrupt clicks in the first 512 samples.
        sample, sr = sf.read(f, start=max(p - 512, 0), stop=p + 512)
        label = self.label[item]
        if sample.ndim > 1:
            sample = sample[:, 0]  # keep the first channel only
        if sr != _FS:
            # Resample to the common _FS rate with an exact rational factor.
            g = np.gcd(_FS, sr)
            sample = sg.resample_poly(sample, int(_FS // g), int(sr // g))
            sr = _FS
        if self.hp:
            sample = sg.sosfiltfilt(self.hp_sos, sample)
        sample = norm(sample)
        if self.train:
            # Random time-stretch of +/-10 %, never shorter than 512 samples.
            sample = sg.resample(sample, max(512, int(len(sample) * np.random.uniform(0.9, 1.1))))
            # Additive white noise and random-walk (cumulative-sum) noise.
            sample += np.random.normal(0, 1, len(sample)) * 10**np.random.uniform(-4, -0.5)
            sample += norm(np.random.normal(0, 1, len(sample)).cumsum()) * 10**np.random.uniform(-4, -0.5)
            if np.random.rand() < 0.5:
                # Remove a random band, or low-pass when the band would exceed Nyquist.
                freq = np.random.uniform(1500, sr / 2)
                if freq * 1.1 > sr / 2:
                    sos = sg.butter(3, 2 * freq * 0.9 / sr, 'lowpass', output='sos')
                else:
                    sos = sg.butter(3, [2 * freq * 0.9 / sr, 2 * freq * 1.1 / sr], 'bandstop', output='sos')
                sample = sg.sosfiltfilt(sos, sample)
            if np.random.rand() < 0.2:
                # Add a copy circularly shifted by 192-319 samples (echo-like).
                sample += np.roll(sample, np.random.randint(256 - 64, 256 + 64))
            sample = norm(sample)
        if len(sample) > 512:
            # Random 512-sample crop while training, centred crop otherwise.
            p = np.random.randint(len(sample) - 512) if self.train else (len(sample) - 512) // 2
            sample = sample[p:p + 512]
        return sample.astype(np.float32)[None], label


class Sequence:
    """Dataset of whole recordings, each cut into 57344 frames of 512 samples.

    Rows of ``df`` are grouped per audio file (column ``'File '``). A file's
    label is 1 if any of its rows has ``positif_negatif != 'n'``, else 0.
    ``__getitem__`` returns ``(seq, label, seq_lab, pos)``:

    * ``seq`` -- float32 array (57344, 1, 512), each frame standardized;
    * ``label`` -- file-level 0/1 label;
    * ``seq_lab`` -- per-frame int label: -1 not annotated, 0 negative, 1 positive;
    * ``pos`` -- float32 frame start times in tenths of a second (shifted by a
      random offset in [-50, 50) during training).
    """

    def __init__(self, train, rank, df, data_path, hp=False, balance=False):
        self.train = train  # True enables random augmentation
        self.rank = rank  # stored only; not used by this class
        self.df = df
        self.grp = tuple(df.groupby('File '))  # ((file, rows_of_that_file), ...)
        self.label = [int((grp.positif_negatif != 'n').any()) for _, grp in self.grp]
        if balance:
            self._balance()
        self.data_path = data_path
        self.hp = hp
        # 3rd-order 1500 Hz high-pass, designed for audio at the _FS rate.
        self.hp_sos = sg.butter(3, 2 * 1500 / _FS, 'hp', output='sos')

    def _balance(self):
        """Oversample the minority class by repetition so both classes have equal counts.

        Does nothing when the classes are already balanced or when one class
        has no file at all (there is nothing to repeat).
        """
        n_pos = sum(self.label)
        n_neg = len(self.label) - n_pos
        if n_pos == n_neg or min(n_pos, n_neg) == 0:
            return
        minority = int(n_pos < n_neg)
        deficit = abs(n_pos - n_neg)
        minority_grp = tuple(g for g, lab in zip(self.grp, self.label) if lab == minority)
        reps = ceil(deficit / len(minority_grp))
        self.grp = self.grp + (reps * minority_grp)[:deficit]
        self.label = self.label + [minority] * deficit

    def __len__(self):
        return len(self.grp)

    def __getitem__(self, item):
        f, grp = self.grp[item]
        f = f if f.startswith('/nfs') else os.path.join(self.data_path, f)
        song, sr = sf.read(f)
        label = self.label[item]
        if song.ndim > 1:
            song = song[:, 0]  # keep the first channel only
        if sr != _FS:
            # Resample to the common _FS rate with an exact rational factor.
            g = np.gcd(_FS, sr)
            song = sg.resample_poly(song, int(_FS // g), int(sr // g))
            sr = _FS
        if self.hp:
            song = sg.sosfiltfilt(self.hp_sos, song)
        song = norm(song)
        sr_new = sr  # effective sample rate used to convert annotation times
        if self.train:
            # Random time-stretch of +/-10 %; annotation times follow via sr_new.
            fact = np.random.uniform(0.9, 1.1)
            sr_new = fact * sr
            song = sg.resample(song, int(len(song) * fact))
            # Additive white noise and random-walk (cumulative-sum) noise.
            song += np.random.normal(0, 1, len(song)) * 10**np.random.uniform(-4, -0.5)
            song += norm(np.random.normal(0, 1, len(song)).cumsum()) * 10**np.random.uniform(-4, -0.5)
            if np.random.rand() < 0.5:
                # Remove a random band, or low-pass when the band would exceed Nyquist.
                freq = np.random.uniform(1500, sr / 2)
                if freq * 1.1 > sr / 2:
                    sos = sg.butter(3, 2 * freq * 0.9 / sr, 'lowpass', output='sos')
                else:
                    sos = sg.butter(3, [2 * freq * 0.9 / sr, 2 * freq * 1.1 / sr], 'bandstop', output='sos')
                song = sg.sosfiltfilt(sos, song)
            song = norm(song)

        # Constant number of evenly spaced frames, whatever the file length.
        pos = np.linspace(0, len(song) - 512, 57344).astype(int)
        seq = norm(song[pos[:, None] + np.arange(512)], axis=-1)[:, None].astype(np.float32)
        seq_lab = np.full_like(pos, -1)
        # Clamp to 0: an interval starting at t=0 would otherwise give index -1,
        # and seq_lab[-1:e] slices from the end, silently dropping the annotation.
        starts = np.maximum(np.searchsorted(pos, np.array(grp.pos_start * sr_new)) - 1, 0)
        ends = np.searchsorted(pos, np.array(grp.pos_end * sr_new))
        for s, e, v in zip(starts, ends, grp.positif_negatif):
            seq_lab[s:e] = v != 'n'
        pos = (pos / sr_new).astype(np.float32) * 10
        if self.train:
            pos += np.random.uniform(-50, 50)
        return seq, label, seq_lab, pos


@torch.no_grad()
def precision_recall_accuracy(preds, labels):
    """Return ``(precision, recall, accuracy)`` in percent.

    The predicted class is the argmax of ``preds`` along dim 1. True
    positives are counted as positions where prediction + label == 2, which
    assumes 0/1 predictions and labels. A tiny epsilon keeps precision and
    recall finite (0) when there are no predicted or no actual positives.
    """
    eps = 1e-8
    predicted = torch.max(preds, 1).indices
    n_true_pos = ((predicted + labels) == 2).sum().item()
    n_pred_pos = predicted.sum().item()
    n_label_pos = labels.sum().item()
    n_correct = (predicted == labels).sum().item()
    precision = 100 * n_true_pos / (n_pred_pos + eps)
    recall = 100 * n_true_pos / (n_label_pos + eps)
    accuracy = 100 * n_correct / len(labels)
    return precision, recall, accuracy


def main(args):
    """Jointly fine-tune the click detector and the transformer classifier.

    Reads the ``annot_click`` sheet of ``args.df_path``, holds out files under
    ``LOT2/BON`` as the test set, loads pretrained detector weights from
    ``args.weight`` and trains ``Both(detector, transformer)`` with torchelie's
    ``TrainAndTest`` recipe on ``cuda:<args.device>``. Returns 0.
    """
    hp, lr, ne, wd, bs = args.hp, args.lr, args.ne, args.wd, args.bs
    sess_name = time.strftime('%x_%X').replace('/', '-') + f':{hp=}:{lr=}:{ne=}:{wd=}:{bs=}'

    df = pd.read_excel(args.df_path, 'annot_click', usecols='A:E').dropna()
    df = df[~df['File '].str.startswith('/nfs/NAS4/')]
    mask = df['File '].str.startswith('LOT2/BON')  # held-out test files

    detector = Detector()

    # Load on CPU: load_state_dict copies into the model, which recipe.to()
    # later moves to the requested GPU (avoids allocating on the saving device).
    state_dict = torch.load(args.weight, map_location='cpu')['model']

    try:
        detector.load_state_dict(state_dict)
    except RuntimeError:
        print("Error Load detector weights")
        from collections import OrderedDict
        # Keys may carry a 'module.' prefix (presumably from a wrapped model);
        # strip it only where it is actually present.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items())
        detector.load_state_dict(new_state_dict)

    transformer = TransformerModel()
    both = Both(detector, transformer)

    train_dst = Sequence(True, 0, df[~mask], args.data_path, hp=hp, balance=True)
    test_dst = tch.datasets.CachedDataset(Sequence(False, 0, df[mask], args.data_path, hp=hp))
    train_dl = DataLoader(train_dst, batch_size=bs, shuffle=True, drop_last=True, num_workers=8,
                          prefetch_factor=4, pin_memory=True)
    # Evaluate on every test file in a fixed order: no shuffling, no dropped tail batch.
    test_dl = DataLoader(test_dst, batch_size=bs, num_workers=8, prefetch_factor=4,
                         shuffle=False, drop_last=False)

    loss_dfn = tch.loss.TemperedCrossEntropyLoss(0.6, 2.)
    loss_tfn = torch.nn.CrossEntropyLoss()

    @torch.no_grad()
    def scores(seq):
        """Per-frame class-1 detector score, zeroed except at local maxima in time."""
        scr = np.zeros(seq.shape[:2])
        n_batch = len(seq)
        nb_trames = seq.shape[1]
        for i in range(ceil(nb_trames / 4096)):
            sli = slice(i * 4096, (i + 1) * 4096)
            detec = tempered_softmax(detector(seq[:, sli].reshape(-1, 1, 512))[0].reshape(n_batch, -1, 2), 2.)[..., 1]
            # detec is (batch, frames): pool over the frame axis so the peak
            # picking compares neighbouring frames, not neighbouring batch items.
            peaks = torch.nn.functional.max_pool1d(detec[:, None], 3, stride=1, padding=1)[:, 0]
            detec = (detec >= peaks) * detec
            scr[:, sli] = detec.cpu().numpy()
        return scr

    def classify(seq, pos, scr):
        """Run the transformer on the 500 best-scoring frames, sorted by time."""
        rows = np.arange(len(seq))[:, None]
        top500 = np.argsort(scr)[:, -500:]
        token = detector(seq[rows, top500].reshape(-1, 1, 512))[1].reshape(len(seq), -1, 32)
        pos = pos[rows, top500]
        srt = np.argsort(pos.cpu().numpy())
        token = token[rows, srt]
        pos = pos[rows, srt]
        return transformer(token, pos)

    def train(batch):
        seq, lab, seq_lab, pos = batch
        # Frame selection uses the detector in eval mode; the rest trains it.
        detector.eval()
        scr = scores(seq)
        detector.train()
        tpred = classify(seq, pos, scr)
        loss_t = loss_tfn(tpred, lab)

        # Auxiliary detector loss on annotated frames; above 32 annotated
        # frames, keep at most 32 negatives and 32 positives.
        seq_lab_np = seq_lab.cpu().numpy()
        arg_f = np.where(seq_lab_np == 0)
        arg_t = np.where(seq_lab_np == 1)
        n_annotated = len(arg_f[0]) + len(arg_t[0])
        if n_annotated > 32:
            idx_t = np.random.choice(len(arg_t[0]), min(len(arg_t[0]), 32), False)
            idx_f = np.random.choice(len(arg_f[0]), min(len(arg_f[0]), 32), False)
            arg_t = (arg_t[0][idx_t], arg_t[1][idx_t])
            arg_f = (arg_f[0][idx_f], arg_f[1][idx_f])
        if n_annotated:
            arg = (np.r_[arg_f[0], arg_t[0]], np.r_[arg_f[1], arg_t[1]])
            dpred = detector(seq[arg])[0]
            loss_d = loss_dfn(dpred, seq_lab[arg])
            loss = loss_t + loss_d
        else:
            loss_d = None
            loss = loss_t
        loss.backward()

        precision, recall, accuracy = precision_recall_accuracy(tpred, lab)
        return {'loss': loss, 'loss_t': loss_t, 'loss_d': loss_d,
                'accuracy': accuracy, 'precision': precision, 'recall': recall}

    def test(batch):
        seq, lab, seq_lab, pos = batch
        scr = scores(seq)
        tpred = classify(seq, pos, scr)
        loss = loss_tfn(tpred, lab)

        precision, recall, accuracy = precision_recall_accuracy(tpred, lab)
        return {'loss_t': loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall}

    recipe = TrainAndTest(both, train, test, train_dl, test_dl,
                          test_every=64, log_every=16, checkpoint='transformer_carimam/' + sess_name, visdom_env=None,
                          key_best=(lambda x: x['test_loop']['callbacks']['state']['metrics']['accuracy']))
    opt = tch.optim.RAdamW(both.parameters(), lr, weight_decay=wd)
    recipe.callbacks.add_callbacks([
        tcb.Optimizer(opt, log_lr=True, clip_grad_norm=5),
        tcb.LRSched(torch.optim.lr_scheduler.LinearLR(opt, start_factor=1, end_factor=0, total_iters=ne + 1),
                    metric=None),
        tcb.WindowedMetricAvg('loss'),
        tcb.WindowedMetricAvg('loss_t'),
        tcb.WindowedMetricAvg('loss_d'),
        tcb.WindowedMetricAvg('accuracy'),
        tcb.WindowedMetricAvg('precision'),
        tcb.WindowedMetricAvg('recall')])
    recipe.callbacks.add_epilogues([tcb.TensorboardLogger(log_dir='transformer_carimam/tfb/'
                                                                  + sess_name + '/train', log_every=64)])
    # All test metrics are averaged over the whole test epoch.
    recipe.test_loop.callbacks.add_callbacks([tcb.EpochMetricAvg('loss_t', False),
                                              tcb.EpochMetricAvg('accuracy', False),
                                              tcb.EpochMetricAvg('precision', False),
                                              tcb.EpochMetricAvg('recall', False),
                                              ])
    recipe.test_loop.callbacks.add_epilogues([tcb.TensorboardLogger(log_dir='transformer_carimam/tfb/'
                                                                            + sess_name + '/test', log_every=-1)])
    recipe.to('cuda:' + args.device)
    recipe.run(ne)
    return 0


if __name__ == '__main__':
    # Command-line entry point; exits with main()'s return value.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--data_path", type=str, default='DATA/', help="Path to the wav main folder")
    parser.add_argument("--weight", type=str, default='detector_carimam/'
                                                      '04-06-22_11:48:04:hp=False:lr=0.001:ne=80:wd=0.05:bs=32/'
                                                      'ckpt_3968.pth',
                        help="Path to the detector checkpoint (must contain a 'model' state dict)")
    parser.add_argument("--df_path", type=str, default='Annotation CARIMAM.xlsx',
                        help="Path to the annotation Excel file (sheet 'annot_click')")
    parser.add_argument("--hp", action='store_true', help="Apply a 1500 Hz high-pass filter to the audio")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--ne", type=int, default=150, help="Number of epochs")
    parser.add_argument("--wd", type=float, default=5e-2, help="Weight decay")
    parser.add_argument("--bs", type=int, default=4, help="Batch size")
    parser.add_argument("--device", type=str, default='0', help='Index of the CUDA device to use')
    sys.exit(main(parser.parse_args()))
