from sklearn.model_selection import KFold
import torch
import numpy as np
from torch import nn, tensor, utils, device, cuda, optim, long, save
from torch.utils import data
import pandas as pd
import utils as u
import time
from scipy.io import wavfile
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from visdom import Visdom
from models import get
# Annotations for every labelled audio window (one row per window).
df = pd.read_pickle('./annot_all.pkl')

# Input layout is (BATCH, CHANNEL, Y, X); the size before the linear layers is
# independent of the model input size thanks to the adaptive pooling layer.

# Hyper-parameters
nepoch = 200
batch_size = 32
lr = 0.0005
wdL2 = 0.002
modelname = 'quattrained_stft.mdl'
print(f'Go for model {modelname}')

# Hold out every BOMBYX2017 recording as the validation set; train on the rest.
is_valid = df.wavpath.str.startswith('/BOMBYX2017')
traindf = df[~is_valid]
testdf = df[is_valid]
# Optional class balancing (drop negatives until positives/negatives match):
# traindf.drop(traindf[traindf.annot!='cachcach'].sample(len(traindf[traindf.annot!='cachcach']) - len(traindf[traindf.annot=='cachcach'])).index, inplace=True)
print(f"train size is {len(traindf)} with {(traindf.annot=='cachcach').sum()} positives")

# Build the float model and restore pretrained weights before QAT preparation.
model = get['stft']
model.load_state_dict(torch.load('stft.stdc'))
model.eval()
# Fuse conv/bn/relu-style triplets by positional child names; these indices are
# tied to the layout of the 'stft' model in models.py — presumably two
# conv+bn+activation stacks at children 3-5 and 7-9 (TODO confirm against models.py).
# Fusion must happen in eval() mode, before prepare_qat.
torch.quantization.fuse_modules(model, [['3', '4', '5'], ['7', '8', '9']], inplace=True)
#for l in model:
#    if type(l)==u.depthwise_separable_conv1d:
#        l.fuse_module()

# Attach the default quantization-aware-training config and insert fake-quant
# observers; training below then simulates int8 arithmetic.
model.qconfig = torch.quantization.default_qat_qconfig
torch.quantization.prepare_qat(model, inplace=True)

#print('nb param', sum(m.numel() for m in model.parameters() if m.requires_grad))
bad_batches = {}  # NOTE(review): never written to in this file — possibly a leftover
model = model.to(device('cuda:0'))

optimizer = optim.Adam(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
# Exponential decay: lr * 0.98**epoch.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch : .98**epoch)
loader = utils.data.DataLoader(u.Dataset(traindf), batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
# Weight positives 2x to counter class imbalance (see the positives count printed above).
loss_fun = nn.BCEWithLogitsLoss(pos_weight=tensor(2).cuda())

# Running per-batch training metrics (loss, true-positive/true-negative rates, accuracy).
loss = []
tprs = []
tnrs = []
accs = []
print('Started at ', time.ctime(time.time()))
# Visdom dashboard; one environment per model name (extension stripped).
vis = Visdom(port=8097, server='http://10.2.92.202', env=modelname[:-4])
# TRAINING LOOP
# Each epoch: one QAT training pass on GPU, then validation of the int8-converted
# model (the one actually deployed) on the held-out BOMBYX2017 recordings.
for epoch in range(nepoch):
    model.to('cuda')  # validation below moves the model to CPU, so move it back each epoch
    model.train()
    # NOTE: the former epoch-start optimizer.step()/scheduler.step() pair applied a
    # duplicate parameter update with the previous epoch's last gradients; the
    # scheduler now steps once per epoch AFTER the optimizer updates, as PyTorch requires.
    for batch in tqdm(loader):
        x, label = batch
        optimizer.zero_grad()
        pred = model(x.cuda()).squeeze()
        label = label.cuda().float()

        score = loss_fun(pred, label)

        # Per-batch metrics on detached, 0-thresholded logits.
        label = np.array(label.cpu().detach()).astype(bool)
        pred = np.array(pred.cpu().detach() > 0).astype(bool)
        accs.append((label == pred).sum() / len(label))
        tprs.append((label & pred).sum() / label.sum())       # NaN for a batch with no positives
        tnrs.append((~label & ~pred).sum() / (~label).sum())  # NaN for a batch with no negatives

        score.backward()
        optimizer.step()
        loss.append(score.item())
        vis.line(X=[len(loss)-1], Y=[loss[-1]], win='loss', update='append', opts={'title':'train loss'})
        vis.line(X=[len(loss)-1], Y=[np.mean(loss[-20:])], win='avgloss', update='append', opts={'title':'train avgloss'})
        vis.line(X=[len(loss)-1], Y=[np.mean(accs[-20:])], win='train_acc', update='append', opts={'title':'train acc'})
        vis.line(X=[len(loss)-1], Y=[np.mean(tnrs[-20:])], win='train_tnr', update='append', opts={'title':'train TNR'})
        vis.line(X=[len(loss)-1], Y=[np.mean(tprs[-20:])], win='train_tpr', update='append', opts={'title':'train TPR'})
    scheduler.step()  # decay the learning rate once per epoch, after the optimizer steps
    save(model.state_dict(), modelname[:-3]+'stdc')
    vis.line(X=[epoch], Y=[scheduler.get_last_lr()[0]], win='lr', update='append', opts={'title':'learning rate'})
    # Validation loss/metrics on the quantized (CPU, int8) model.
    with torch.no_grad():
        qmodel = torch.quantization.convert(model.cpu().eval(), inplace=False)
        qmodel = u.addquantize(qmodel)
        qmodel.eval()
        # Unweighted loss for reporting; must NOT rebind the pos_weight-ed training loss_fun.
        valid_loss_fun = torch.nn.BCEWithLogitsLoss()
        labels, preds, losses = [], [], []
        for batch in utils.data.DataLoader(u.Dataset(testdf), batch_size=32, shuffle=True, num_workers=2, pin_memory=True):
            x, label = batch
            pred = qmodel(x).detach().float()
            # Keep a 1-D logit vector even when the last batch has a single sample.
            pred = pred.squeeze() if pred.ndimension() > 1 else pred

            preds.extend(pred.view(-1))
            labels.extend(label.view(-1))
            losses.append(valid_loss_fun(pred, label.float()).item())
    validauc = roc_auc_score(labels, preds)
    labels = np.array(labels).astype(bool)
    preds = (np.array(preds) > 0).astype(bool)
    validacc = (labels == preds).sum() / len(labels)
    validtpr = (labels & preds).sum() / labels.sum()
    validtnr = (~labels & ~preds).sum() / (~labels).sum()
    vis.line(X=[epoch], Y=[np.mean(losses)], win='valid_loss', update='append', opts={'title':'valid loss'})
    vis.line(X=[epoch], Y=[validacc], win='valid_acc', update='append', opts={'title':'valid acc'})
    vis.line(X=[epoch], Y=[validtpr], win='valid_tpr', update='append', opts={'title':'valid TPR'})
    vis.line(X=[epoch], Y=[validtnr], win='valid_tnr', update='append', opts={'title':'valid TNR'})
    vis.line(X=[epoch], Y=[validauc], win='valid_auc', update='append', opts={'title':'valid AUC'})
    save(qmodel.state_dict(), 'quat_'+modelname[:-3]+'stdc')
