import argparse
import sys
import time
import os
import glob
import numpy as np
import pandas as pd
from tqdm import tqdm

import git
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from sklearn.metrics import roc_auc_score

from models import DetectorErbs
from audio_dataset import AudioSignal

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('agg')

@torch.no_grad()
def precision_recall_accuracy(preds, labels):
    """Compute binary classification metrics from per-class scores.

    Args:
        preds: (N, 2) tensor of class scores; column 1 is the positive class.
        labels: (N,) tensor of 0/1 ground-truth labels.

    Returns:
        Tuple of (precision %, recall %, accuracy %, true positives,
        false positives). Precision and recall use a 1e-8 epsilon in the
        denominator to avoid division by zero.
    """
    predicted = torch.max(preds, dim=1).indices
    n_pred_pos = predicted.sum().item()
    n_true_pos = labels.sum().item()
    true_pos = ((predicted == 1) & (labels == 1)).sum().item()
    false_pos = n_pred_pos - true_pos
    precision = 100 * true_pos / (n_pred_pos + 1e-8)
    recall = 100 * true_pos / (n_true_pos + 1e-8)
    accuracy = 100 * (predicted == labels).sum().item() / len(labels)
    return precision, recall, accuracy, true_pos, false_pos

@torch.no_grad()
def auc(preds, labels):
    """Area under the ROC curve for binary predictions.

    Args:
        preds: (N, 2) tensor of class probabilities; column 1 is the positive
            class probability.
        labels: (N,) tensor of 0/1 ground-truth labels.

    Returns:
        ROC AUC as a float, or NaN when only one class is present in
        ``labels`` (AUC is undefined there and sklearn would raise).
    """
    prob = preds[:, 1]
    # Tensor-level sum instead of Python's built-in sum(): the latter iterates
    # element by element, forcing one GPU->CPU sync per element on CUDA tensors.
    n_pos = labels.sum().item()
    if n_pos == len(labels) or n_pos == 0:
        return float('NaN')
    return roc_auc_score(labels.cpu().numpy(), prob.cpu().numpy())

def main(args: argparse.Namespace) -> int:
    """Train the DetectorErbs click detector on annotated audio segments.

    Loads click annotations from an Excel sheet, splits them into a train set
    and a random ~10% test set, then trains for ``args.ne`` epochs, logging
    metrics to TensorBoard and saving checkpoints (tagged with git provenance)
    under ``runs/<session name>/``.

    Args:
        args: Parsed command-line options (see the argparse block at the
            bottom of this file).

    Returns:
        0 on normal completion.
    """
    hp, lr, n_epoch, wd, bs, aug, cstm_name = args.hp, args.lr, args.ne, args.wd, args.bs, args.aug, args.custom_name
    #sess_name = time.strftime('%x_%X').replace('/', '-') + ":cstm_name=%s:hp=%r:lr=%f:ne=%d:wd=%f:bs=%d:aug=%r"%(cstm_name, hp, lr, ne, wd, bs, aug)
    # The session name encodes every hyperparameter so a run directory is
    # self-describing.
    sess_name = "cstm_name=%s:hp=%r:lr=%f:ne=%d:wd=%f:bs=%d:aug=%r"%(cstm_name, hp, lr, n_epoch, wd, bs, aug)
    sig_dur = 5.  # duration (seconds) of each audio excerpt fed to the model
    os.makedirs('runs/' + sess_name + '/', exist_ok=True)

    # Git provenance (branch name + commit SHA) is embedded in every checkpoint.
    repo = git.Repo(
        os.path.dirname(os.path.abspath(__file__)),
        search_parent_directories=True
        )

    git_branch = repo.head.ref.name
    git_sha = repo.head.object.hexsha         

    # Load annotations, then drop rows referencing unavailable NAS mounts.
    df = pd.read_excel(args.df_path, 'annot_click_apo', usecols='A:E').dropna()
    df = df[~df['File'].str.contains('/nfs/NAS4/')]
    df = df[~df['File'].str.contains('/nfs/NAS3/')]
    
    # Random ~10% hold-out for testing.
    # NOTE(review): the RNG is not seeded, so the train/test split differs on
    # every run — including runs resumed via --reload_weigth; confirm intended.
    mask = np.random.rand(len(df)) < 0.1
    #mask = ((df['File'].str.startswith('LOT2/BERMUDE'))|(df.File.str.startswith('LOT2/GUYANNE')) | (df.File.str.startswith('LOT2/ANG')))

    #mask_test = np.logical_or(df['File'].str.startswith('LOT2/GUYA'), df['File'].str.startswith('LOT2/StMARTIN'))
    #mask_train = np.logical_and( ~df['File'].str.startswith('LOT2/GUYA'), ~df['File'].str.startswith('LOT2/StMARTIN'))
    #mask_test = df['File'].str.startswith('LOT2/ANG')
    #df_test = df[mask_test]
    #df_train = df[~mask_test]

    # Rows whose positif_negatif column is not 'n' count as positives.
    print("Train_set pos:%d, neg:%d"%((df[~mask].positif_negatif != 'n').sum(), (df[~mask].positif_negatif == 'n').sum()))
    print("Test_set pos:%d, neg:%d"%((df[mask].positif_negatif != 'n').sum(), (df[mask].positif_negatif == 'n').sum()))

    detector = DetectorErbs()
    device = ('cuda:' + args.device)
    detector.to(device)

    # Separate writers so train/test curves overlay in TensorBoard.
    writer_train = SummaryWriter('runs/tfb/'+sess_name+'_train')
    writer_test = SummaryWriter('runs/tfb/'+sess_name+'_test')

    train_dst = AudioSignal(True, df[~mask], args.data_path, hp=hp, noise_data_aug=aug, sig_dur=sig_dur)
    test_dst = AudioSignal(False, df[mask], args.data_path, hp=hp, sig_dur=sig_dur)
    train_dl = DataLoader(train_dst, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.nb_workers,  prefetch_factor=4, pin_memory=True)
    # NOTE(review): shuffle=True / drop_last=True on the *test* loader discards
    # up to bs-1 test samples every epoch; drop_last=False would evaluate all.
    test_dl = DataLoader(test_dst, batch_size=bs, shuffle=True, drop_last=True, num_workers=args.nb_workers, prefetch_factor=4)

    opt = torch.optim.AdamW(detector.parameters(), lr, weight_decay=wd)
    # Exponential decay: lr * 0.999**(number of scheduler steps so far).
    sch = torch.optim.lr_scheduler.LambdaLR(opt, lambda epoch : .999**epoch)        # TODO: add a warmup (slow start)
    loss_fn = nn.CrossEntropyLoss()

    def train(batch):
        """Forward/backward one training batch and return metrics.

        Gradient zeroing and opt.step() are handled by the caller; this only
        accumulates gradients via loss.backward().
        """
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        preds = detector(x)[0]
        loss = loss_fn(preds, y)
        loss.backward()
        # Softmax only for metric reporting; CrossEntropyLoss expects logits.
        preds = nn.functional.softmax(preds, dim=1)         

        precision, recall, accuracy, tp, fp = precision_recall_accuracy(preds, y)
        # NOTE(review): tpr/fpr are 0-d tensors and divide by the per-batch
        # positive/negative counts, which can be zero (-> inf/nan) — confirm.
        tpr =  tp / (y==1).sum()
        fpr = fp / (y==0).sum()
        return {'loss': loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 
            'auc': auc(preds, y), 'preds': preds, 'tpr': tpr , 'fpr': fpr}

    def test(batch):
        """Forward one evaluation batch; return loss, labels and softmax preds."""
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        preds = detector(x)[0]
        loss = loss_fn(preds, y)
        preds = nn.functional.softmax(preds, dim=1)

        return {'loss': loss, 'labels': y, 'preds': preds}
        #precision, recall, accuracy, tp, fp = precision_recall_accuracy(preds, y)
        #return {
        # 'loss': loss, 'accuracy': accuracy, 'auc': auc(pred, y),
        # 'precision': precision, 'recall': recall, 
        # 'label': y, 'pred': pred, 'tp': tp, 'fp': fp
        # }

    start_epoch = 0

    # Optionally resume from the newest (by ctime) checkpoint of this session.
    if os.path.exists('runs/' + sess_name) is True and args.reload_weigth is True:
        files = glob.glob('runs/' + sess_name + "/*.pth")
        latest_file = max(files, key=os.path.getctime)
        # NOTE(review): the checkpoint file is deserialized twice here; loading
        # once and indexing would avoid a redundant disk read.
        state_dict = torch.load(latest_file)["model"]
        start_epoch = torch.load(latest_file)["epoch"]
        try:
            detector.load_state_dict(state_dict)
            print("Reload model weight succesfully")
        except RuntimeError:
            print("Error Load detector weights")
            exit()

    save_every_nbatch = 32  # checkpoint + log every 32 training batches
    for idx_epoch in range(start_epoch, n_epoch):
        detector.train()
        # NOTE(review): this opt.step() runs with whatever gradients were left
        # from the previous epoch's last batch (no zero_grad since then), so
        # those gradients get applied a second time — looks unintended.
        opt.step()
        # NOTE(review): sch.step() is called here *and* once per batch below,
        # so the 0.999 decay is applied per batch rather than per epoch
        # (the lambda's 'epoch' argument is really a step counter) — confirm.
        sch.step()
        count = 0

        for batch in train_dl:
            sch.step()
            opt.zero_grad()

            res_train = train(batch)
            opt.step()

            if count % save_every_nbatch == 0:
                # Global step index used as the TensorBoard x-axis and in the
                # checkpoint filename.
                idx_elem =  len(train_dl) * (idx_epoch + 1) + count

                torch.save({
                    "model": detector.state_dict(),
                    "epoch": idx_epoch,
                    "params": {
                        "hp": hp,
                        "noise_aug" : aug,
                        "custom_name" : cstm_name,
                        "n_epoch" : idx_epoch,
                        "sig_dur": sig_dur,
                        },
                    "git_info": git_branch + '_' + git_sha
                    }, "runs/%s/ckpt_%06d.pth" %(sess_name, idx_elem))

                lr_0 = torch.Tensor(sch.get_last_lr())
                writer_train.add_scalar('lr', lr_0, idx_elem)
                writer_train.add_scalar('loss', res_train['loss'], idx_elem)
                writer_train.add_scalar('acc', res_train['accuracy'], idx_elem)
                writer_train.add_scalar('prec', res_train['precision'], idx_elem)
                writer_train.add_scalar('rec', res_train['recall'], idx_elem)
                writer_train.add_scalar('auc', res_train['auc'], idx_elem)
                writer_train.add_scalar('tpr', 100*res_train['tpr'], idx_elem)
                writer_train.add_scalar('fpr', 100*res_train['fpr'], idx_elem)

                print("epoch: %3d/%3d, iter: %6d/%6d, lr %f, loss: %f, acc: %f, prec: %f, rec: %f, auc: %f, tpr: %f, fpr: %f"%(
                    idx_epoch, n_epoch, count+1, len(train_dl)-1, lr_0, res_train['loss'], res_train['accuracy'], res_train['precision'], res_train['recall'], res_train['auc'], 100*res_train['tpr'], 100*res_train['fpr']))

            count += 1

        # End-of-epoch evaluation over the whole test loader.
        with torch.no_grad():
            detector.eval()
            labels, id_preds, losses, full_preds = [], [], [], []

            for test_batch in test_dl:
                # Save the results of each test batch.
                res_test = test(test_batch)

                labels.extend(res_test['labels'].cpu().detach().view(-1))
                id_preds.extend(torch.max(res_test['preds'], 1)[1].cpu().detach().view(-1))
                full_preds.extend(res_test['preds'][:, 1].cpu().detach())
                losses.append(res_test['loss'].cpu().detach())

            labels = np.array(labels)
            id_preds = np.array(id_preds)
            full_preds = np.array(full_preds)
            losses = np.array(losses)

            # compute metrics over the full test set (labels are 0/1)
            tp = ((id_preds + labels) == 2).sum()
            fp = ((id_preds==1).sum() - tp)
            precision = 100 * tp / (id_preds.sum() + 1e-8)
            recall = 100 * tp / (labels.sum() + 1e-8)
            accuracy = 100 * (id_preds == labels).sum() / len(labels)
            perc_tp = 100 * tp / (labels==1).sum()
            perc_fp = 100 * fp / (labels==0).sum()

            roc_auc = roc_auc_score(labels, full_preds)

            # NOTE(review): idx_elem is reused from the training loop above and
            # would be undefined if train_dl yielded no batches.
            writer_test.add_scalar('loss', losses.mean(), idx_elem)
            writer_test.add_scalar('acc', accuracy, idx_elem)
            writer_test.add_scalar('prec', precision, idx_elem)
            writer_test.add_scalar('rec', recall, idx_elem)
            writer_test.add_scalar('tpr', perc_tp, idx_elem)
            writer_test.add_scalar('fpr', perc_fp, idx_elem)
            writer_test.add_scalar('auc', roc_auc, idx_elem)

            print("Test epoch: %3d/%3d, loss: %f, acc: %f, prec: %f, rec: %f, tpr: %f, fpr: %f, auc: %f"%(
                idx_epoch, n_epoch, losses.mean(), accuracy, precision, recall, perc_tp, perc_fp, roc_auc))

        # Save an end-of-epoch checkpoint (same layout as the per-batch saves).
        torch.save({
            "model": detector.state_dict(),
            "epoch": idx_epoch,
            "params": {
                "hp": hp,
                "noise_aug": aug,
                "custom_name": cstm_name,
                "n_epoch": idx_epoch,
                "sig_dur": sig_dur,
                },
            "git_info": git_branch + '_' + git_sha
            }, "runs/%s/ckpt_%06d.pth" %(sess_name, idx_elem))

    return 0

if __name__ == '__main__':
    # Command-line entry point: parse hyperparameters/paths and launch training.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--data_path", type=str, default='DATA/', help="Path to the wav main folder")
    # Fixed copy-pasted help text: this argument is the annotation spreadsheet,
    # not the wav folder.
    parser.add_argument("--df_path", type=str, default='Annotation CARIMAM.xlsx', help="Path to the annotation Excel file")
    # --hp is a boolean switch (action='store_true'), not a frequency value;
    # help text fixed accordingly.
    parser.add_argument("--hp", action='store_true', help="Apply a highpass filter to the input signal")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--ne", type=int, default=30, help="Number of epochs")
    parser.add_argument("--wd", type=float, default=5e-2, help="Weight decay")
    parser.add_argument("--bs", type=int, default=32, help="Batch size")
    parser.add_argument("--nb_workers", type=int, default=12, help="Number of DataLoader workers")
    parser.add_argument("--device", type=str, default='0', help='Gpu device')
    parser.add_argument("--aug", action='store_true', help="Use noise file as data augment")
    # NOTE: flag keeps the historical 'weigth' spelling so existing launch
    # scripts remain compatible.
    parser.add_argument("--reload_weigth", action='store_true', help="Reload weights from a previous run")
    parser.add_argument("--custom_name", type=str, default='session_name', help="prefix name for weight file")

    sys.exit(main(parser.parse_args()))
