import argparse
import sys
import time

import numpy as np
import pandas as pd
from model_both import Detector
import torchelie as tch
import torchelie.callbacks.callbacks as tcb
import torchelie.utils as tu
from torchelie.loss.bitempered import tempered_softmax
import torch
from torch.utils.data import DataLoader
from torchelie.recipes import TrainAndTest
from sklearn.metrics import roc_auc_score, RocCurveDisplay

from torch.utils.tensorboard import SummaryWriter
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

from Click_dataset_dclde import *

@torch.no_grad()
def precision_recall_accuracy(preds, labels):
    """Compute precision, recall and accuracy (in percent) for one batch.

    The predicted class of each row is the index of its largest score along
    dimension 1. A sample counts as a true positive when predicted class plus
    label equals 2, so the metric is meaningful for 0/1 labels.

    Args:
        preds: 2-D tensor of per-class scores, one row per sample.
        labels: tensor of labels aligned with the rows of ``preds``.

    Returns:
        Tuple ``(precision, recall, accuracy, tp, fp)``. The first three are
        percentages; ``tp`` and ``fp`` are plain Python integers.
    """
    predicted = torch.max(preds, 1).indices
    n_predicted_pos = predicted.sum().item()
    n_true_pos = (predicted + labels).eq(2).sum().item()
    n_false_pos = n_predicted_pos - n_true_pos
    n_correct = predicted.eq(labels).sum().item()

    # The 1e-8 keeps precision/recall finite when a batch has no positives.
    precision = 100 * n_true_pos / (n_predicted_pos + 1e-8)
    recall = 100 * n_true_pos / (labels.sum().item() + 1e-8)
    accuracy = 100 * n_correct / len(labels)
    return precision, recall, accuracy, n_true_pos, n_false_pos

@torch.no_grad()
def auc(preds, labels):
    """Return the ROC AUC of the class-1 scores for one batch.

    Args:
        preds: 2-D tensor of per-class scores; column 1 is used as the score
            of the positive class.
        labels: 1-D tensor of labels (the single-class check below compares
            their sum with 0 and with the batch size, i.e. it assumes 0/1).

    Returns:
        The value of ``roc_auc_score``, or ``float('NaN')`` when the batch
        holds a single class (sum of labels is 0 or equals the batch size),
        in which case the AUC is undefined.
    """
    prob = preds[:, 1]
    # Reduce once on the tensor: the builtin ``sum`` would walk the tensor
    # element by element in Python, and the original did so twice.
    n_positive = labels.sum().item()
    if n_positive == len(labels) or n_positive == 0:
        return float('NaN')
    return roc_auc_score(labels.cpu().numpy(), prob.cpu().numpy())


class ROC_curve(tu.AutoStateDict):
    """Callback that plots a ROC curve from the predictions of one epoch.

    Class-1 scores (``state['pred'][:, 1]``) and labels (``state['label']``)
    are accumulated batch after batch; at the end of the epoch the resulting
    matplotlib figure is stored in ``state['metrics'][name]``.

    Args:
        name: key under which the figure is published in ``state['metrics']``.
    """

    def __init__(self, name):
        super().__init__()
        self.name = name
        self.preds = None
        self.labels = None

    def on_epoch_start(self, state):
        """Reset the accumulators and retire the figure of the previous epoch."""
        self.preds = []
        self.labels = []

        if self.name in state['metrics']:
            stale = state['metrics'][self.name]
            del state['metrics'][self.name]
            # Figures created through pyplot stay registered until they are
            # closed explicitly; without this, one figure leaks per epoch.
            if isinstance(stale, plt.Figure):
                plt.close(stale)

    @torch.no_grad()
    def on_batch_end(self, state):
        """Append the class-1 scores and the labels of the current batch."""
        # Store plain Python numbers rather than one 0-dim tensor per sample.
        self.preds.extend(state['pred'].detach().cpu()[:, 1].tolist())
        self.labels.extend(state['label'].cpu().tolist())

    def on_epoch_end(self, state):
        """Build the ROC figure from everything accumulated this epoch."""
        display = RocCurveDisplay.from_predictions(self.labels, self.preds)
        state['metrics'][self.name] = display.figure_


def main(args):
    """Train the click detector and evaluate it on the annotated test set.

    The training set is read from a local pickle, the test set from the
    ``annot_click_apo`` sheet of the Excel file given by ``args.df_path``.
    After training, the test loader is run once more to print the false
    positive and true positive rates.

    Args:
        args: parsed command-line namespace providing ``hp``, ``lr``, ``ne``,
            ``wd``, ``bs``, ``balanced``, ``aug``, ``custom_name``,
            ``df_path``, ``data_path`` and ``device``.

    Returns:
        0, which the caller uses as process exit status.
    """
    # NOTE: these local names are part of the session name, because the
    # ``{name=}`` f-string specifier embeds the variable name itself.
    hp, lr, ne, wd, bs, blcd, aug, cstm_name = args.hp, args.lr, args.ne, args.wd, args.bs, args.balanced, args.aug, args.custom_name
    sess_name = time.strftime('%x_%X').replace('/', '-') + f':{cstm_name=}:{hp=}:{lr=}:{ne=}:{wd=}:{bs=}:{blcd=}:{aug=}'
    device = 'cuda:' + args.device

    # NOTE: unpickling executes arbitrary code; only load a trusted file here.
    df_train = pd.read_pickle('Dclde_click_extract_cured_pretty2.pkl')
    df_train['pos_start'] = df_train['Pos_in_file']
    df_train['pos_end'] = df_train['Pos_in_file']

    df_test = pd.read_excel(args.df_path, 'annot_click_apo', usecols='A:E').dropna()
    df_test = df_test[~df_test['File'].str.contains('/nfs/NAS4/')]
    df_test = df_test[~df_test['File'].str.contains('/nfs/NAS3/')]
    df_test['File'] = "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/" + df_test['File']

    print("Train_set pos:%d, neg:%d"%((df_train.positif_negatif != 'n').sum(), (df_train.positif_negatif == 'n').sum()))
    print("Test_set pos:%d, neg:%d"%((df_test.positif_negatif != 'n').sum(), (df_test.positif_negatif == 'n').sum()))

    detector = Detector()

    train_dst = Click(True, 0, df_train, args.data_path, hp=hp, balanced=blcd, sess_aug=aug)
    test_dst = Click(False, 0, df_test, args.data_path, hp=hp, balanced=blcd)
    train_dl = DataLoader(train_dst, batch_size=bs, shuffle=True, drop_last=True, num_workers=12,  prefetch_factor=4, pin_memory=True)
    test_dl = DataLoader(test_dst, batch_size=bs, num_workers=12, prefetch_factor=4, shuffle=True, drop_last=True)

    loss_fn = tch.loss.TemperedCrossEntropyLoss(0.6, 2.)

    def train(batch):
        """Run one training step and return the loss and batch metrics."""
        x, y = batch
        pred = detector(x)[0]
        loss = loss_fn(pred, y)
        loss.backward()
        # Metrics only need values: detach so the returned 'pred' does not
        # keep the autograd graph alive after the backward pass.
        pred = tempered_softmax(pred.detach(), 2.)
        precision, recall, accuracy, _, _ = precision_recall_accuracy(pred, y)
        return {'loss': loss, 'accuracy': accuracy, 'auc': auc(pred, y),
                'precision': precision, 'recall': recall, 'pred': pred}

    def test(batch):
        """Evaluate one batch and return the loss, metrics and raw counts."""
        x, y = batch
        pred = detector(x)[0]
        loss = loss_fn(pred, y)
        pred = tempered_softmax(pred, 2.)
        precision, recall, accuracy, tp, fp = precision_recall_accuracy(pred, y)
        return {'loss': loss, 'accuracy': accuracy, 'auc': auc(pred, y),
                'precision': precision, 'recall': recall, 'label': y, 'pred': pred, 'tp': tp, 'fp': fp}

    recipe = TrainAndTest(detector, train, test, train_dl, test_dl,
                          test_every=128, log_every=32, checkpoint='detector_carimam/' + sess_name, visdom_env=None)
    opt = torch.optim.AdamW(detector.parameters(), lr, weight_decay=wd)

    recipe.callbacks.add_callbacks([
                tcb.Optimizer(opt, log_lr=True, clip_grad_norm=5),
                tcb.LRSched(tch.lr_scheduler.LinearDecay(opt, total_iters=ne*len(train_dl)),
                             metric=None, step_each_batch=True),
                tcb.WindowedMetricAvg('loss'),
                tcb.WindowedMetricAvg('accuracy'),
                tcb.WindowedMetricAvg('precision'),
                tcb.WindowedMetricAvg('recall'),
                tcb.WindowedMetricAvg('auc')
                ])
    recipe.callbacks.add_epilogues([tcb.TensorboardLogger(log_dir='detector_carimam/tfb/'
                                                                  + sess_name + '/train', log_every=16)])
    recipe.test_loop.callbacks.add_callbacks([tcb.EpochMetricAvg('loss', False),
                                              tcb.EpochMetricAvg('accuracy', False),
                                              tcb.WindowedMetricAvg('precision', False),
                                              tcb.WindowedMetricAvg('recall', False),
                                              tcb.EpochMetricAvg('auc', False),
                                              ROC_curve('roc_curve'),
                                              ])
    recipe.test_loop.callbacks.add_epilogues([tcb.TensorboardLogger(log_dir='detector_carimam/tfb/'
                                                                            + sess_name + '/test', log_every=16)])
    recipe.to(device)
    recipe.run(ne)

    # Final pass over the test loader. This loop runs outside the recipe, so
    # evaluation mode and gradient-free execution must be requested here.
    detector.eval()
    tp_all = 0
    fp_all = 0
    with torch.no_grad():
        for x, y in test_dl:
            res = test((x.to(device), y.to(device)))
            tp_all += res['tp']
            fp_all += res['fp']

    # NOTE(review): test_dl uses drop_last=True, so the counts above may cover
    # slightly fewer samples than the dataframe totals used as denominators.
    print(fp_all / (df_test.positif_negatif == 'n').sum())
    print(tp_all / (df_test.positif_negatif != 'n').sum())

    return 0

if __name__ == '__main__':
    # Command-line entry point: parse the options, run the training session
    # and use the value returned by main() as process exit status.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--data_path", type=str, default='DATA/', help="Path to the wav main folder")
    parser.add_argument("--df_path", type=str, default='Annotation CARIMAM.xlsx', help="Path to the annotation Excel file used as test set")
    parser.add_argument("--hp", action='store_true', help="Enable the high-pass filter")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--ne", type=int, default=40, help="Number of epochs")
    parser.add_argument("--wd", type=float, default=5e-2, help="Weight decay")
    parser.add_argument("--bs", type=int, default=32, help="Batch size")
    parser.add_argument("--device", type=str, default='0', help="Index of the CUDA device to use")
    parser.add_argument("--balanced", action='store_true', help="Use balanced train dataset")
    parser.add_argument("--aug", action='store_true', help="Use noise file as data augment")
    parser.add_argument("--custom_name", type=str, default='session_name', help="Custom name embedded in the session name")

    sys.exit(main(parser.parse_args()))
