from sklearn.metrics import roc_auc_score
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
from torch import nn, utils, device, optim, save
from torch.utils import data
import pandas as pd
import utils as u
import time
from tqdm import tqdm
#from sklearn.model_selection import train_test_split
from models import get
import argparse
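# The annotation table is expected to provide at least a 'wavpath' column (path
# to the source recording) and an 'annot' column holding class labels such as
# 'cachcach' and 'globi' (inferred from their use below).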
df = pd.read_pickle('./annot_all.pkl') #REF ('./annot_all.pkl')
pDropout = .25
norm = True #REF True
int16 = False #True #REF False
# (BATCH, CHANNEL, Y, X)
# input size before the linear layers is independent of the model input size thanks to the adaptive pooling

parser = argparse.ArgumentParser(description='')
parser.add_argument('-nfeat',  type=int, default=128)
parser.add_argument('-repeat', type=str, default='')
parser.add_argument('-kernel', type=int, default=7)
args = parser.parse_args()

nepoch = 30
batch_size = 8
modelname = 'stft_depthwise_ovs_'+str(args.nfeat)+'_k'+str(args.kernel)+'_r'+str(args.repeat)+'.stdc'
print('Go for model '+modelname)
lr = 0.0005
wdL2 = 0.002
fe = 50000 # sample rate (Hz) passed to u.Dataset

writer = SummaryWriter('runs/'+modelname)
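# One TensorBoard run per model name; follow training with `tensorboard --logdir runs`.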

traindf = df[~df.wavpath.str.startswith('/nfs/NAS5/SABIOD/SITE/BOMBYX/BOMBYX2017')] # TODO: replace startswith with a 'contains' test
testdf = df[df.wavpath.str.startswith('/nfs/NAS5/SABIOD/SITE/BOMBYX/BOMBYX2017')] # TODO: replace startswith with a 'contains' test
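# Site-based split: every recording under the BOMBYX2017 prefix is held out, so
# validation measures generalisation to an unseen deployment.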
print('train size is '+str(len(traindf))+' with '+str((traindf.annot=='cachcach').sum())+' positives')


model = get['stft_depthwise_ksize'](args.nfeat, args.kernel)
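# 'get' (from the project-local models module) is used above as a registry
# mapping a model name to its constructor, assumed here to take the feature
# count and kernel size as positional arguments.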

print('nb param', sum(m.numel() for m in model.parameters() if m.requires_grad))
bad_batches = {}
#u.PrintModel(model, indata=torch.ones(8,1,125000))
#model = nn.DataParallel(model)
gpu = device('cuda:1') # hard-coded to the second GPU; adjust for your machine
model = model.to(gpu)
#print('nb param', sum(m.numel() for m in model.module[:].parameters()))

optimizer = optim.Adam(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
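# Adam with a lowered beta1 (0.8 instead of the default 0.9); weight_decay in
# torch.optim.Adam acts as an L2 penalty added to the gradients.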
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: .98**epoch)
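# LambdaLR multiplies the base LR by the lambda's return value, so the LR at
# epoch e is lr * 0.98**e, i.e. roughly 0.0005 * 0.56 ≈ 2.8e-4 by epoch 29.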

samplingweights = torch.ones(len(traindf))
samplingweights[(traindf.annot=='cachcach').to_numpy()] = 3
samplingweights[(traindf.annot=='globi').to_numpy()] = 10
samplingweights /= len(traindf) + (traindf.annot=='cachcach').sum()*2 + (traindf.annot=='globi').sum()*9
sampler = utils.data.WeightedRandomSampler(weights = samplingweights, num_samples=len(traindf), replacement=True)
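# WeightedRandomSampler does not require normalised weights (the division above
# only rescales them); with replacement=True each epoch draws len(traindf)
# samples, each 'cachcach' row being ~3x and each 'globi' row ~10x more likely
# to be picked than a default row.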

train_ds = u.Dataset(traindf, fe=fe, norm=norm, int16=int16, reef_noise=True, sampleDur=5)
loader = utils.data.DataLoader(train_ds, batch_size=batch_size, num_workers=8, prefetch_factor=4, pin_memory=True, sampler=sampler)
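# u.Dataset is project-local; from its call sites it appears to yield
# (waveform, binary label) pairs of sampleDur seconds at fe Hz, with
# reef_noise=True presumably mixing in reef background noise as augmentation
# (an assumption based on the argument name).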

print('Started at ',time.ctime(time.time()))
loss_fun = torch.nn.BCEWithLogitsLoss()
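# BCEWithLogitsLoss applies the sigmoid internally, so the model must output
# raw logits; this is also why predictions are thresholded at 0 below
# (sigmoid(0) = 0.5) rather than at a probability of 0.5.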

# TRAINING LOOP
for epoch in range(nepoch):

    model.train()
    count = 0
    for batch in tqdm(loader, desc=str(epoch), leave=False):
        x, label = batch
        optimizer.zero_grad()
        pred = model(x.to(gpu)).squeeze()
        label = label.to(gpu).float()
        score = loss_fun(pred, label)
        score.backward()
        optimizer.step()

        # global step index for TensorBoard: batches per epoch times epoch, plus the batch counter
        step = epoch * (len(traindf)//batch_size + 1) + count
        writer.add_scalar('train loss', score.item(), step)
        # accuracy at the 0-logit threshold, averaged over the actual batch
        # (the last batch of an epoch can be smaller than batch_size)
        writer.add_scalar('train acc', ((pred>0)==label.bool()).float().mean(), step)
        count += 1
    scheduler.step() # decay the LR once per epoch, after the epoch's optimizer updates
    save(model.state_dict(), modelname)
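    # Note: the checkpoint file is overwritten every epoch, so it always holds
    # the latest model rather than the best-validation one.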

    # get loss on valid
    with torch.no_grad():
        model.eval()
        labels, preds, losses = [], [], []
        for batch in utils.data.DataLoader(u.Dataset(testdf, fe=fe, norm=norm, int16=int16, sampleDur=5), batch_size=32, shuffle=True, num_workers=2, pin_memory=True):
            x, label = batch
            pred = model(x.to(gpu)).squeeze().cpu().detach()
            preds.extend(pred.view(-1))
            labels.extend(label.view(-1))
            losses.append(loss_fun(pred, label.float()).item()) # store python floats so np.mean below works on plain numbers

    # Metrics on the held-out site: AUC from the raw logits, then accuracy,
    # true positive rate and true negative rate at the 0-logit threshold.
    validauc = roc_auc_score(labels, preds)
    labels = np.array(labels).astype(bool)
    preds = (np.array(preds)>0).astype(bool)
    validacc = (labels==preds).sum() / len(labels)
    validtpr = (labels&preds).sum() / labels.sum()
    validtnr = (~labels&~preds).sum() / (~labels).sum()
    writer.add_scalar('valid loss', np.mean(losses), epoch)
    writer.add_scalar('valid acc', validacc, epoch)
    writer.add_scalar('valid tpr', validtpr, epoch)
    writer.add_scalar('valid tnr', validtnr, epoch)
    writer.add_scalar('valid auc', validauc, epoch)

