from sklearn.model_selection import KFold
import torch
import numpy as np
from torch import nn, tensor, utils, device, cuda, optim, long, save
from torch.utils import data
import pandas as pd
import utils as u
import time
from scipy.io import wavfile
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from visdom import Visdom
from models import get
# Full annotation table; each CV fold below slices it by row position.
df = pd.read_pickle('./annot_all.pkl')

# (BATCH, CHANNEL, Y, X)
# input size before the linear layers, independent of model input thanks to adaptive pool

nepoch = 80               # training epochs per fold
batch_size = 16           # training mini-batch size
modelname = 'kfold.mdl'   # NOTE(review): printed below but never used to save weights in this script
print('Go for model '+modelname)
lr = 0.0005               # Adam learning rate
wdL2 = 0.002              # L2 regularization (Adam weight_decay)

# All training and evaluation run on the first GPU.
cuda0 = device('cuda:0')

# 10-fold cross-validation; shuffle so folds are not dependent on row order.
kf = KFold(n_splits=10, shuffle=True)

print('Started at ',time.ctime(time.time()))

# Per-fold validation metrics, aggregated (mean) after the k-fold loop.
kfold_val_acc, kfold_val_auc, kfold_val_tpr, kfold_val_tnr = [], [], [], []

for idtrain, idtest in kf.split(df):
    # One cross-validation fold: train a fresh model on idtrain, evaluate on idtest.
    traindf = df.iloc[idtrain]
    testdf = df.iloc[idtest]

    # Fresh model per fold so folds do not leak trained state into each other.
    # '0.3' presumably selects a model variant from the project registry — TODO confirm.
    model = get['0.3']
    model = nn.DataParallel(model)
    model.to(cuda0)

    optimizer = optim.Adam(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
    # Exponential LR decay: lr * 0.98**epoch.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: .98**epoch)

    loss, tprs, tnrs, accs = [], [], [], []

    loader = utils.data.DataLoader(u.Dataset(traindf),
                                    batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    # Positive samples are up-weighted x3. pos_weight multiplies the positive-class
    # term, which is equivalent to the per-sample weight tensor (weights[label==1]=3)
    # the original rebuilt every batch — but it is constructed once per fold.
    loss_fun = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(3.0, device=cuda0))

    # TRAINING LOOP
    for epoch in tqdm(range(nepoch)):
        model.train()

        for batch in loader:
            x, label = batch
            x = x.to(cuda0, non_blocking=True)
            label = label.to(cuda0).float()

            optimizer.zero_grad()
            pred = model(x).squeeze()
            score = loss_fun(pred, label)

            score.backward()
            optimizer.step()

            # Batch-level metrics on thresholded logits (logit > 0 <=> sigmoid > 0.5).
            label_np = np.array(label.cpu().detach()).astype(bool)
            pred_np = np.array(pred.cpu().detach() > 0).astype(bool)
            accs.append((label_np == pred_np).sum()/len(label_np))
            tprs.append((label_np & pred_np).sum()/label_np.sum())
            tnrs.append((~label_np & ~pred_np).sum()/(~label_np).sum())

            loss.append(score.item())

        # Decay the LR once per epoch, *after* the epoch's optimizer steps
        # (PyTorch >= 1.1 convention). The original called optimizer.step() +
        # scheduler.step() at the start of each epoch, which applied a spurious
        # extra parameter update with the previous epoch's stale gradients.
        scheduler.step()

    # Report the mean over the LAST 20 batches (end-of-training state).
    # The original sliced [:-20], i.e. everything *except* the last 20.
    print('Training ended :')
    print('acc : '+str(np.mean(accs[-20:])))
    print('loss : '+str(np.mean(loss[-20:])))
    print('TPR : '+str(np.mean(tprs[-20:])))
    print('TNR : '+str(np.mean(tnrs[-20:])))


    # Evaluate on the held-out fold (no gradient tracking).
    with torch.no_grad():
        model.eval()
        labels, preds, losses = [], [], []
        val_loss_fun = torch.nn.BCEWithLogitsLoss()
        for batch in utils.data.DataLoader(u.Dataset(testdf), batch_size=32, shuffle=True, num_workers=2, pin_memory=True):
            x, label = batch
            # .to() is not in-place: the original discarded the result and fed
            # the CPU tensor to the model.
            x = x.to(cuda0, non_blocking=True)
            pred = model(x).cpu().detach().float()
            pred = pred.squeeze() if pred.ndimension() > 1 else pred

            preds.extend(pred.view(-1))
            labels.extend(label.view(-1))
            losses.append(val_loss_fun(pred, label.float()))

    # AUC on raw logits: logits are monotone in probability, so the ranking
    # (and hence the AUC) is identical to using sigmoid outputs.
    validauc = roc_auc_score(labels, preds)
    labels = np.array(labels).astype(bool)
    preds = (np.array(preds) > 0).astype(bool)
    validacc = (labels == preds).sum() / len(labels)
    validtpr = (labels & preds).sum() / labels.sum()
    validtnr = (~labels & ~preds).sum() / (~labels).sum()

    print('Valid ended :')
    print('acc : '+str(validacc))
    print('auc : '+str(validauc))
    print('TPR : '+str(validtpr))
    print('TNR : '+str(validtnr))
    kfold_val_tnr.append(validtnr)
    kfold_val_tpr.append(validtpr)
    kfold_val_auc.append(validauc)
    kfold_val_acc.append(validacc)


# Print the cross-fold mean of each validation metric.
print('global :')
for tag, fold_values in (('acc :', kfold_val_acc),
                         ('auc :', kfold_val_auc),
                         ('tpr :', kfold_val_tpr),
                         ('tnr :', kfold_val_tnr)):
    print(tag, np.mean(fold_values))
