import os
from math import ceil
import argparse
import sys
import time
import librosa as lr
import matplotlib.pyplot as plt

import numpy as np
from matplotlib.pyplot import close
import pandas as pd
import soundfile as sf
from sklearn.metrics import roc_auc_score, RocCurveDisplay
import scipy.signal as sg
from model_both import Detector, TransformerModel, Both, Context, Context2, Context_dil,  Context_dil_2D, Context_rnn_1d, Context_ViT_1d
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
import torchelie.utils as tu
from torchelie.loss.bitempered import tempered_softmax
import torch
from torch.utils.data import DataLoader
from torchelie.recipes import TrainAndTest
import glob
import tqdm

import scipy as scp
import scipy.fft

from joblib import Parallel, delayed
import multiprocessing


# Sample rate (Hz) to which every loaded signal is resampled before processing.
_SR = 256_000
# Small constant added to denominators to avoid division by zero.
EPS = 1e-16

def norm(x, eps=1e-10, axis=None):
    """Standardise ``x`` to zero mean and unit standard deviation.

    Args:
        x: NumPy array (any object whose ``mean``/``std`` accept ``axis``
            and ``keepdims``).
        eps: added to the standard deviation so constant input maps to 0
            instead of dividing by zero.
        axis: axis (or axes) over which the statistics are computed;
            ``None`` uses the whole array.

    Returns:
        Array of the same shape as ``x``.
    """
    centred = x - x.mean(axis, keepdims=True)
    scale = x.std(axis, keepdims=True) + eps
    return centred / scale


class Click:
    """Single-click dataset returning short windows around annotated positions.

    NOT USED by ``main`` in this file (``Sequence`` is used instead).

    Items are ``(sample, label)``: ``sample`` is a float32 array of shape
    ``(1, n)`` with ``n <= 512``; ``label`` is 0 for rows whose
    ``positif_negatif`` is ``'n'`` and 1 otherwise.
    """

    def __init__(self, train, rank, df, data_path, hp=False):
        self.train = train
        self.rank = rank  # stored but not used inside this class
        self.df = df
        self.data_path = data_path
        self.label = np.array(df.positif_negatif != 'n').astype(int)
        self.hp = hp
        # 6th-order 2.5 kHz high-pass, applied when hp=True
        self.hp_sos = sg.butter(6, 2500, 'hp', output='sos', fs=_SR)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        row = self.df.iloc[item]
        s, e = row[['pos_start', 'pos_end']]
        # position in seconds: fixed, or drawn uniformly inside [pos_start, pos_end]
        p = s if s == e else np.random.uniform(s, e)
        # NOTE(review): this column name has a trailing space ('File '), unlike
        # Sequence's 'File' -- confirm against the table this class was used with.
        path = row['File ']
        path = path if path.startswith('/nfs') else os.path.join(self.data_path, path)
        p = int(sf.info(path).samplerate * p)
        # Clamp the start: a negative `start` makes soundfile read from the end
        # of the file. Missing leading samples are zero-padded below so the
        # annotated position stays 512 samples into the window.
        pad_left = max(0, 512 - p)
        sample, sr = sf.read(path, start=max(0, p - 512), stop=p + 512)
        label = self.label[item]
        if sample.ndim > 1:
            sample = sample[:, 0]
        if pad_left:
            sample = np.pad(sample, (pad_left, 0), 'constant')
        if sr != _SR:
            g = np.gcd(_SR, sr)
            sample = sg.resample_poly(sample, _SR // g, sr // g)
            sr = _SR
        if self.hp:
            sample = sg.sosfiltfilt(self.hp_sos, sample)
        sample = norm(sample)
        if self.train:
            # random time stretch of +/-10 %
            sample = sg.resample(sample, max(512, int(len(sample) * np.random.uniform(0.9, 1.1))))
            # white noise and random-walk noise at random levels
            sample += np.random.normal(0, 1, len(sample)) * 10**np.random.uniform(-4, -0.5)
            sample += norm(np.random.normal(0, 1, len(sample)).cumsum()) * 10**np.random.uniform(-4, -0.5)
            if np.random.rand() < 0.5:
                # random low-pass or band-stop filter around a random frequency
                cutoff = np.random.uniform(1500, sr/2)
                if cutoff*1.1 > sr/2:
                    sos = sg.butter(3, 2*cutoff*0.9/sr, 'lowpass', output='sos')
                else:
                    sos = sg.butter(3, [2*cutoff*0.9/sr, 2*cutoff*1.1/sr], 'bandstop', output='sos')
                sample = sg.sosfiltfilt(sos, sample)
            if np.random.rand() < 0.2:
                # echo delayed by 192-319 samples
                sample += np.roll(sample, np.random.randint(256-64, 256+64))
            sample = norm(sample)
        if len(sample) > 512:
            # random 512-sample crop when training, centred crop otherwise
            p = np.random.randint(len(sample) - 512) if self.train else (len(sample) - 512)//2
            sample = sample[p:p+512]
        return sample.astype(np.float32)[None], label


class Sequence:
    """Dataset of frame sequences centred on annotated clicks, for the context model.

    Each item is ``(seq, label, click_frame)``:

    * ``seq``: float32 array of ``2 * nb_ctxt_frame`` overlapping frames of
      ``frame_size`` samples (hop ``frame_size // 2``), each frame normalised;
    * ``label``: 0 when ``positif_negatif == 'n'``, 1 otherwise;
    * ``click_frame``: index in ``seq`` of the frame holding the annotated
      position (the window is randomly shifted around it).

    Args:
        train: enable random data augmentation.
        df: annotation table with columns ``File``, ``positif_negatif``,
            ``pos_start`` and ``pos_end`` (positions in seconds).
        data_path: prefix for relative paths (paths starting with ``/nfs``
            are used as is).
        nb_ctxt_frame: number of frames kept before and after the annotated one.
        hp: apply a 6th-order 2.5 kHz high-pass filter.
        balanced: oversample the minority class (with replacement) so both
            classes have the same number of rows.
        sess_aug: during training, mix random excerpts of recordings from
            other sessions into the signal as background noise.

    Raises:
        FileNotFoundError: if ``sess_aug`` is set but no noise file is found.
    """

    # Noise recordings mixed in when sess_aug=True.
    _NOISE_GLOB = "/short/CARIMAM/DATA/DATA_CLEAN/*/*.WAV"
    # Maximum number of noise files tried before giving up on an item.
    _MAX_NOISE_TRIES = 100

    def __init__(self, train, df, data_path, nb_ctxt_frame=512, hp=False, balanced=False, sess_aug=False):
        self.train = train
        self.df = df

        self.data_path = data_path
        self.hp = hp
        self.hp_sos = sg.butter(6, 2500, 'hp', output='sos', fs=_SR)
        self.nb_ctxt_frame = nb_ctxt_frame  # frames kept before and after the annotated click
        self.frame_size = 2*int(_SR * 0.001)  # frame duration in samples (2 ms at _SR)
        self.hop_size = self.frame_size//2

        if balanced:
            is_neg = df.positif_negatif == 'n'
            df_pos, df_neg = df[~is_neg], df[is_neg]
            diff_pos_neg = len(df_pos) - len(df_neg)
            if diff_pos_neg > 0:
                self.df = pd.concat([self.df, df_neg.sample(diff_pos_neg, replace=True)])
            elif diff_pos_neg < 0:
                self.df = pd.concat([self.df, df_pos.sample(-diff_pos_neg, replace=True)])

        self.list_noise_files = None
        if sess_aug:
            self.list_noise_files = glob.glob(self._NOISE_GLOB)
            self.nb_noise_files = len(self.list_noise_files)
            if self.nb_noise_files == 0:
                # Without this check every training item would fail to find noise.
                raise FileNotFoundError(f"sess_aug=True but no noise file matches {self._NOISE_GLOB}")

        # Hann window; not used by the methods of this class.
        self.window = sg.windows.hann(512)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        row = self.df.iloc[item]
        label = int(row['positif_negatif'] != 'n')  # 0 -> negative, 1 -> positive

        sig = self._load_centered_window(row)
        if self.hp:
            sig = sg.sosfiltfilt(self.hp_sos, sig)
        sig = norm(sig)

        if self.train:
            sig = self._augment(sig)

        # Sub-hop random jitter of the whole signal (applied in test mode too).
        sig = np.roll(sig, np.random.randint(self.hop_size) - (self.hop_size//2))
        seq = lr.util.frame(sig, frame_length=self.frame_size, hop_length=self.hop_size, axis=0, writeable=True)
        seq = norm(seq, axis=1)

        # Keep 2 * nb_ctxt_frame frames around the centre, randomly shifted by
        # up to +/- nb_ctxt_frame / 2 frames (applied in test mode too).
        center = seq.shape[0] // 2
        shift = np.random.randint(self.nb_ctxt_frame) - (self.nb_ctxt_frame//2)
        start = center - self.nb_ctxt_frame + shift
        seq = seq[start:start + 2*self.nb_ctxt_frame]
        click_frame = self.nb_ctxt_frame - shift  # where the centre frame ended up

        return seq.astype(np.float32), label, click_frame

    def _load_centered_window(self, row):
        """Read the window around the annotated position, resampled to _SR.

        When the window runs past the start or end of the file, the missing
        samples are zero-padded on that side. This keeps the annotated
        position at the centre of the returned signal.
        """
        file = row['File']
        file = file if file.startswith('/nfs') else os.path.join(self.data_path, file)
        click_start, click_end = row[['pos_start', 'pos_end']]
        pos_click = click_start if click_start == click_end else np.random.uniform(click_start, click_end)

        file_sr = sf.info(file).samplerate
        pos_click = int(file_sr * pos_click)
        click_margin = self.nb_ctxt_frame * 2 * int(file_sr * 0.001)

        start = pos_click - click_margin
        sig, sr = sf.read(file, start=max(0, start), stop=pos_click + click_margin)
        if sig.ndim > 1:
            sig = sig[:, 0]

        # Pad on the side(s) where the file was too short.
        pad_left = max(0, -start)
        pad_right = max(0, 2*click_margin - pad_left - len(sig))
        if pad_left or pad_right:
            sig = np.pad(sig, (pad_left, pad_right), 'constant', constant_values=0)

        if sr != _SR:
            sig = sg.resample(sig, int(len(sig) * _SR / sr))
            click_margin = self.nb_ctxt_frame * 2 * self.hop_size
            # Length rounding in resampling can leave the signal a few samples short.
            if len(sig) < 2*click_margin:
                diff_len = 2*click_margin - len(sig)
                sig = np.pad(sig, (diff_len//2, diff_len - diff_len//2), 'constant', constant_values=0)
        return sig

    def _augment(self, sig):
        """Apply training-time augmentation to the normalised ``sig`` and re-normalise it."""
        if self.list_noise_files is None:
            # additive white noise and random-walk noise at random levels
            sig += np.random.normal(0, 1, len(sig)) * 10**np.random.uniform(-2, -0.5)
            sig += norm(np.random.normal(0, 1, len(sig)).cumsum()) * 10**np.random.uniform(-2, -0.5)
        else:
            noise = self._load_noise(len(sig))
            if self.hp:
                noise = sg.sosfiltfilt(self.hp_sos, noise)
            noise = norm(noise)
            # TODO: express the gain on a dB scale; min/max values still to be tuned.
            noise_gain = np.random.uniform(low=2., high=5.)
            sig = sig + (noise * noise_gain)

        if np.random.rand() < 0.5:
            # random low-pass or band-stop filter around a random frequency
            f = np.random.uniform(1500, _SR/2)
            if f*1.1 > _SR/2:
                sos = sg.butter(3, 2*f*0.9/_SR, 'lowpass', output='sos')
            else:
                sos = sg.butter(3, [2*f*0.9/_SR, 2*f*1.1/_SR], 'bandstop', output='sos')
            sig = sg.sosfiltfilt(sos, sig)
        if np.random.rand() < 0.05:
            # faint echo delayed by 192-319 samples
            sig += 0.1*np.roll(sig, np.random.randint(256-64, 256+64))
        return norm(sig)

    def _load_noise(self, length):
        """Return ``length`` samples of a random excerpt of a random noise file at _SR.

        Unreadable or empty files are skipped. A RuntimeError is raised after
        ``_MAX_NOISE_TRIES`` consecutive failures instead of retrying forever.
        """
        last_error = None
        for _ in range(self._MAX_NOISE_TRIES):
            noise_file = self.list_noise_files[np.random.randint(self.nb_noise_files)]
            try:
                info = sf.info(noise_file)
                len_noise_file = int(info.samplerate * info.duration)
                rand_pos = np.random.randint(max(1, len_noise_file - length))
                noise, sr = sf.read(noise_file, start=rand_pos, stop=rand_pos + length)
                if noise.ndim > 1:
                    noise = noise[:, 0]
                if sr != _SR:
                    noise = sg.resample(noise, int(len(noise) * _SR / sr))
            except Exception as err:  # corrupted/unreadable file: try another one
                last_error = err
                continue
            if len(noise) == 0:
                continue

            noise = noise[:length]
            if len(noise) < length:
                len_diff = length - len(noise)
                noise = np.pad(noise, (len_diff//2, len_diff - len_diff//2), 'wrap')
            return noise
        raise RuntimeError(f"could not read a usable noise file after {self._MAX_NOISE_TRIES} attempts") from last_error


@torch.no_grad()
def precision_recall_accuracy(preds, labels):
    """Compute binary classification statistics from per-class scores.

    Args:
        preds: tensor of shape (batch, classes); the predicted class is the
            arg-max over dim 1.
        labels: tensor of shape (batch,) with 0/1 targets.

    Returns:
        ``(precision %, recall %, accuracy %, tp, fp)`` where ``tp``/``fp``
        are Python ints. A tiny epsilon in the precision/recall denominators
        avoids division by zero.
    """
    _, predicted = preds.max(dim=1)
    n_pred_pos = predicted.sum().item()
    tp = ((predicted + labels) == 2).sum().item()
    fp = n_pred_pos - tp

    precision = 100 * tp / (n_pred_pos + 1e-8)
    recall = 100 * tp / (labels.sum().item() + 1e-8)
    accuracy = 100 * (predicted == labels).sum().item() / len(labels)
    return precision, recall, accuracy, tp, fp


class ROC_curve(tu.AutoStateDict):
    """Callback storing the epoch's ROC-curve figure in ``state['metrics'][name]``.

    Collects the positive-class score ``state['pred'][:, 1]`` and
    ``state['label']`` at every batch end, then builds the figure with
    sklearn's ``RocCurveDisplay`` at epoch end.
    NOTE(review): a new matplotlib figure is created every epoch and is not
    closed by this callback.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
        self.preds = None
        self.labels = None

    def on_epoch_start(self, state):
        self.preds, self.labels = [], []
        # drop the previous epoch's figure from the metrics
        state['metrics'].pop(self.name, None)

    @torch.no_grad()
    def on_batch_end(self, state):
        positive_scores = state['pred'].detach().cpu()[:, 1]
        self.preds.extend(positive_scores)
        self.labels.extend(state['label'].cpu())

    def on_epoch_end(self, state):
        display = RocCurveDisplay.from_predictions(self.labels, self.preds)
        state['metrics'][self.name] = display.figure_


class export_datasample(tu.AutoStateDict):
    """Callback writing the best- and worst-predicted example of the last batch as WAV.

    At each epoch end it reads the last batch seen, ``state['batch']``. This is
    assumed to be ``(seq, label, ...)`` with ``seq`` of shape
    (batch, nb_frames, frame_size), as produced by ``Sequence``. It also reads
    the positive-class scores ``state['pred'][:, 1]``. For each example it
    rebuilds a waveform from the first ``hop_size`` samples of every frame,
    peak-normalises it, and writes two files to ``dest_path``: the example
    with the smallest ``|score - label|`` and the one with the largest.

    Args:
        name: metric key removed from ``state['metrics']`` at each epoch start.
        dest_path: output directory (created if missing).
        frame_size: frame length in samples; frames are assumed to overlap by half.
    """

    def __init__(self, name, dest_path, frame_size):
        super(export_datasample, self).__init__()
        self.name = name
        self.hop_size = frame_size // 2
        self.dest_path = dest_path + '/'
        os.makedirs(self.dest_path, exist_ok=True)

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` (tensor on any device, or array-like) as a NumPy array."""
        return torch.as_tensor(x).detach().cpu().numpy()

    @torch.no_grad()
    def on_epoch_start(self, state):
        if self.name in state['metrics']:
            del state['metrics'][self.name]

    @torch.no_grad()
    def on_batch_end(self, state):
        pass

    @torch.no_grad()
    def on_epoch_end(self, state):
        pred = self._to_numpy(state['pred'])[:, 1]
        label = self._to_numpy(state['batch'][1])
        seq = self._to_numpy(state['batch'][0])

        # Frames overlap by hop_size, so their first halves laid end to end
        # approximate the signal. Sequence normalises each frame separately,
        # so this is only an approximation.
        waves = seq[:, :, :self.hop_size].reshape(seq.shape[0], -1)
        waves = waves / (np.abs(waves).max(axis=1, keepdims=True) + EPS)

        dist_pred = np.abs(pred - label)
        idx_best = int(np.argmin(dist_pred))
        idx_worst = int(np.argmax(dist_pred))
        for tag, idx in (('best', idx_best), ('worst', idx_worst)):
            file_name = f"{tag}-{int(label[idx])}-{time.time()}.wav"
            sf.write(os.path.join(self.dest_path, file_name), waves[idx], _SR)


class ROC_AUC(tu.AutoStateDict):
    """Callback computing the ROC AUC of the positive-class score over an epoch.

    Collects ``state['pred'][:, 1]`` and ``state['label']`` at each batch end
    and stores the AUC in ``state['metrics'][name]`` at epoch end. The value
    is NaN when the epoch contained a single class.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
        self.preds = None
        self.labels = None

    @torch.no_grad()
    def on_epoch_start(self, state):
        self.preds, self.labels = [], []
        state['metrics'].pop(self.name, None)

    @torch.no_grad()
    def on_batch_end(self, state):
        positive_scores = state['pred'].detach().cpu()[:, 1]
        self.preds.extend(positive_scores)
        self.labels.extend(state['label'].cpu())

    @torch.no_grad()
    def on_epoch_end(self, state):
        self.preds = np.array(self.preds)
        self.labels = np.array(self.labels)
        state['metrics'][self.name] = self.auc(self.labels, self.preds)

    @torch.no_grad()
    def auc(self, labels, preds):
        """Return ``roc_auc_score(labels, preds)``, or NaN if only one class is present."""
        n_pos = sum(labels)
        single_class = n_pos == 0 or n_pos == len(labels)
        return float('NaN') if single_class else roc_auc_score(labels, preds)


class FpRate(tu.AutoStateDict):
    """Callback storing the share of false positives among predicted positives.

    A sample counts as predicted positive when ``state['pred'][:, 1] > 0.5``.
    The stored value is ``FP / (TP + FP)``, i.e. the false discovery rate,
    rather than the classical ``FP / (FP + TN)``.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
        self.preds = None
        self.labels = None

    @torch.no_grad()
    def on_epoch_start(self, state):
        self.preds, self.labels = [], []
        state['metrics'].pop(self.name, None)

    @torch.no_grad()
    def on_batch_end(self, state):
        positive_scores = state['pred'].detach().cpu()[:, 1]
        self.preds.extend(positive_scores)
        self.labels.extend(state['label'].cpu())

    @torch.no_grad()
    def on_epoch_end(self, state):
        self.preds = np.array(self.preds)
        self.labels = np.array(self.labels)
        predicted_pos = self.preds > 0.5
        n_predicted_pos = predicted_pos.sum()
        true_pos = ((predicted_pos + self.labels) == 2).sum()
        state['metrics'][self.name] = (n_predicted_pos - true_pos) / (n_predicted_pos + EPS)


def main(args):
    """Train the detector + context model, then report final test-set rates.

    The training set is read from ``Dclde_click_extract_cured_pretty2_filtered.pkl``.
    The test set is the ``annot_click_apo`` sheet of ``args.df_path``.
    Training runs through torchelie's ``TrainAndTest``; checkpoints and
    TensorBoard logs go under ``context_carimam/<session name>``.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Returns:
        Process exit status: 0 on success, 1 if ``args.weight`` could not be loaded.
    """
    # `lr` is the learning rate here; it shadows the librosa alias only inside main().
    hp, lr, ne, wd, bs, lmbd, blcd, aug, frame, cstm_name = (
        args.hp, args.lr, args.ne, args.wd, args.bs, args.lmbda,
        args.balanced, args.aug, args.nb_frames, args.custom_name)
    sess_name = time.strftime('%x_%X').replace('/', '-') + f':{cstm_name=}:{hp=}:{lr=}:{ne=}:{wd=}:{bs=}:{lmbd=}:{blcd=}:{aug=}:{frame=}'
    device = 'cuda:' + args.device

    df_train = pd.read_pickle('Dclde_click_extract_cured_pretty2_filtered.pkl')
    # This table has a single position per click: use it for both window bounds.
    df_train['pos_start'] = df_train['Pos_in_file']
    df_train['pos_end'] = df_train['Pos_in_file']

    df_test = pd.read_excel(args.df_path, 'annot_click_apo', usecols='A:E').dropna()
    # Annotations whose path contains /nfs/NAS4/ or /nfs/NAS3/ are excluded.
    keep = ~df_test['File'].str.contains('/nfs/NAS4/') & ~df_test['File'].str.contains('/nfs/NAS3/')
    df_test = df_test[keep].copy()
    df_test['File'] = "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/" + df_test['File']

    print("Train_set pos:%d, neg:%d"%((df_train.positif_negatif != 'n').sum(), (df_train.positif_negatif == 'n').sum()))
    print("Test_set pos:%d, neg:%d"%((df_test.positif_negatif != 'n').sum(), (df_test.positif_negatif == 'n').sum()))

    margin_frames = args.nb_frames

    detector = Detector()
    if args.vit:
        context_cnn = Context_ViT_1d(input_len=margin_frames*2)
    else:
        context_cnn = Context_rnn_1d()
    both = Both(detector, context_cnn)

    state_dict = torch.load(args.weight, map_location='cpu')['model']
    if not args.reload_weight:
        # --weight holds detector weights only: initialise the detector from them.
        target, error_msg = detector, "Error Load detector weights"
    else:
        # --weight holds a full detector + context checkpoint.
        target, error_msg = both, "Error Load context weights"
    try:
        target.load_state_dict(state_dict)
    except RuntimeError as err:
        print(error_msg)
        print(err)
        return 1

    train_dst = Sequence(True, df_train, args.data_path, nb_ctxt_frame=margin_frames, hp=hp, balanced=blcd, sess_aug=aug)
    test_dst = Sequence(False, df_test, args.data_path, nb_ctxt_frame=margin_frames, hp=hp, balanced=blcd)

    loader_kw = dict(num_workers=12, prefetch_factor=4, pin_memory=True)
    train_dl = DataLoader(train_dst, batch_size=bs, shuffle=True, drop_last=True, **loader_kw)
    test_dl = DataLoader(test_dst, batch_size=bs, shuffle=True, drop_last=True, **loader_kw)

    loss_dfn = torch.nn.CrossEntropyLoss()  # detector loss
    loss_tfn = torch.nn.CrossEntropyLoss()  # context-model loss

    def run_model(seq, idx_frame, finetune_detector=True):
        """Run the detector on every frame and the context model on the frame embeddings.

        Returns the detector logits of each sequence's annotated frame
        (``idx_frame``) and the context-model logits.
        """
        batch_size, nb_frames = seq.shape[0], seq.shape[1]
        with torch.set_grad_enabled(finetune_detector and torch.is_grad_enabled()):
            pred_detec, emb = detector(seq.reshape(batch_size * nb_frames, 1, seq.shape[-1]))
            pred_detec = pred_detec.reshape(batch_size, nb_frames, pred_detec.shape[-1])
            emb = emb.reshape(batch_size, nb_frames, emb.shape[-1])
        rows = torch.arange(batch_size, device=pred_detec.device)
        central_pred_detec = pred_detec[rows, idx_frame.to(pred_detec.device)]
        pred_cnn, _ = context_cnn(emb)
        return central_pred_detec, pred_cnn

    def combined_loss(central_pred_detec, pred_cnn, lab):
        """Return ``(lmbda * context loss + (1 - lmbda) * detector loss, context loss, detector loss)``."""
        loss_cnn = loss_tfn(pred_cnn, lab)
        loss_detec = 0 if args.no_loss_detec else loss_dfn(central_pred_detec, lab)
        return args.lmbda * loss_cnn + (1 - args.lmbda) * loss_detec, loss_cnn, loss_detec

    def train(batch):
        seq, lab, idx_frame = batch
        central_pred_detec, pred_cnn = run_model(seq, idx_frame, finetune_detector=not args.no_detec_finetune)
        loss, loss_cnn, loss_detec = combined_loss(central_pred_detec, pred_cnn, lab)
        loss.backward()

        precision_cnn, recall_cnn, accuracy_cnn, _, _ = precision_recall_accuracy(pred_cnn, lab)
        return {'loss': loss, 'loss_cnn': loss_cnn, 'loss_d': loss_detec, 'label': lab,
                # probabilities, as in test(), so ROC_AUC / FpRate (threshold 0.5) use the same scale
                'pred': torch.nn.functional.softmax(pred_cnn.detach(), 1),
                'accuracy': accuracy_cnn, 'precision': precision_cnn, 'recall': recall_cnn}

    def test(batch):
        seq, lab, idx_frame = batch
        central_pred_detec, pred_cnn = run_model(seq, idx_frame)
        loss, _, _ = combined_loss(central_pred_detec, pred_cnn, lab)
        pred_cnn = torch.nn.functional.softmax(pred_cnn, 1)

        # tp / fp are those of the detector's central-frame prediction.
        _, _, _, tp, fp = precision_recall_accuracy(central_pred_detec, lab)
        precision_cnn, recall_cnn, accuracy_cnn, _, _ = precision_recall_accuracy(pred_cnn, lab)
        return {'loss': loss, 'accuracy': accuracy_cnn, 'label': lab, 'pred': pred_cnn, 'tp': tp, 'fp': fp,
                'precision': precision_cnn, 'recall': recall_cnn}

    recipe = TrainAndTest(both, train, test, train_dl, test_dl,
                          test_every=256, log_every=64, checkpoint='context_carimam/' + sess_name, visdom_env=None)
    # Optimise every parameter of `both` (detector AND context model): both
    # appear in the loss. Parameters without gradients (frozen detector) are skipped by AdamW.
    opt = torch.optim.AdamW(both.parameters(), lr, weight_decay=wd)

    recipe.callbacks.add_callbacks([
                tcb.Optimizer(opt, log_lr=True, clip_grad_norm=5),
                tcb.LRSched(tch.lr_scheduler.LinearDecay(opt, total_iters=ne*len(train_dl)),
                            metric=None, step_each_batch=True),
                tcb.WindowedMetricAvg('loss'),
                tcb.WindowedMetricAvg('loss_cnn'),
                tcb.WindowedMetricAvg('loss_d'),
                tcb.WindowedMetricAvg('accuracy'),
                tcb.WindowedMetricAvg('precision'),
                tcb.WindowedMetricAvg('recall'),
                ROC_AUC('roc_auc'),
                FpRate('fp'),
                ])
    recipe.callbacks.add_epilogues([tcb.TensorboardLogger(log_dir='context_carimam/tfb/'
                                                                  + sess_name + '/train', log_every=64)])
    recipe.test_loop.callbacks.add_callbacks([tcb.EpochMetricAvg('loss', False),
                                              tcb.EpochMetricAvg('accuracy', False),
                                              tcb.WindowedMetricAvg('precision', False),
                                              tcb.WindowedMetricAvg('recall', False),
                                              ROC_AUC('roc_auc'),
                                              export_datasample('export_data', 'context_carimam/' + sess_name + "/monitor/", frame_size=2*int(_SR * 0.001)),
                                              FpRate('fp'),
                                              ])
    recipe.test_loop.callbacks.add_epilogues([tcb.TensorboardLogger(log_dir='context_carimam/tfb/'
                                                                            + sess_name + '/test', log_every=-1)])
    recipe.to(device)

    recipe.run(ne)

    # Final evaluation on the whole test set: no shuffling, no dropped last batch,
    # and denominators taken from the dataset actually iterated.
    eval_dl = DataLoader(test_dst, batch_size=bs, shuffle=False, drop_last=False, **loader_kw)
    nb_neg = int((test_dst.df.positif_negatif == 'n').sum())
    nb_pos = len(test_dst.df) - nb_neg
    counts = {'detector': [0, 0], 'context': [0, 0]}  # name -> [tp, fp]
    both.eval()
    with torch.no_grad():
        for seq, lab, idx_frame in eval_dl:
            seq, lab = seq.to(device), lab.to(device)
            central_pred_detec, pred_cnn = run_model(seq, idx_frame)
            for name, pred in (('detector', central_pred_detec), ('context', pred_cnn)):
                *_, tp, fp = precision_recall_accuracy(pred, lab)
                counts[name][0] += tp
                counts[name][1] += fp

    for name, (tp, fp) in counts.items():
        print(f"{name}: false positive rate {fp / max(nb_neg, 1):.4f}, recall {tp / max(nb_pos, 1):.4f}")

    return 0


if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean ('1'/'0', 'true'/'false', 'yes'/'no')."""
        lowered = str(value).strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--data_path", type=str, default='DATA/', help="Path to the wav main folder")
    parser.add_argument("--weight", type=str, default='detector_carimam/03-06-23_16:02:46:hp=True:lr=0.0005:ne=40:wd=0.05:bs=32:blcd=False:aug=True/ckpt_2624.pth',
                        help="Checkpoint to load: detector weights, or detector + context weights with --reload_weight")
    parser.add_argument("--df_path", type=str, default='Annotation CARIMAM.xlsx', help="Path to the Excel annotation file used as test set")
    parser.add_argument("--hp", action='store_true', help="Apply a 2.5 kHz high-pass filter to the signals")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--ne", type=int, default=50, help="Number of epoch")
    parser.add_argument("--wd", type=float, default=5e-2, help="Weight decay")
    parser.add_argument("--bs", type=int, default=16, help="Batch size")
    parser.add_argument("--device", type=str, default='0', help='Gpu device')
    # Accepts a bare flag (`--no_loss_detec`) or an explicit value (`--no_loss_detec false`).
    parser.add_argument("--no_loss_detec", type=_str2bool, nargs='?', const=True, default=False,
                        help="Disable the detector loss (train with the context loss only)")
    parser.add_argument("--lambda", dest="lmbda", type=float, default=0.8,
                        help="Weight of the context loss; the detector loss is weighted by 1 - lambda")
    parser.add_argument("--nb_frames", type=int, default=256, help="Number of context frames kept before and after the annotated frame")
    parser.add_argument("--balanced", action='store_true', help="Use balanced train dataset")
    parser.add_argument("--no_detec_finetune", action='store_true', default=False, help="Freeze the detector weights (do not finetune it)")
    parser.add_argument("--aug", action='store_true', help="Use noise file as data augment")
    parser.add_argument("--custom_name", type=str, default='session_name', help="prefix name for weight file")
    parser.add_argument("--vit", action='store_true', help="Use ViT context detector")
    parser.add_argument("--reload_weight", action='store_true', help="Reload a full detector + context checkpoint from --weight")

    sys.exit(main(parser.parse_args()))
