from sklearn.metrics import roc_auc_score
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
from torch import nn, tensor, utils, device, cuda, optim, long, save
from torch.utils import data
import pandas as pd
import utils as u
import time
from tqdm import tqdm
from sklearn.metrics import *
#from sklearn.model_selection import train_test_split
from models import get
import argparse
# Annotation table: one row per labelled audio sample (REF: './annot_all.pkl').
df = pd.read_pickle('./new_annot_all.pkl')
pDropout = .25
norm = False   # REF False
int16 = True   # REF True
# Tensors are laid out as (BATCH, CHANNEL, Y, X).
# The input size before the linear layers is independent of the model input,
# thanks to the adaptive pooling layer.

# Command-line hyper-parameters for the model architecture.
parser = argparse.ArgumentParser(description='')
parser.add_argument('-nfeat',  type=int, default=128)
parser.add_argument('-repeat', type=str, default='')
parser.add_argument('-kernel', type=int, default=7)
args = parser.parse_args()

nepoch = 30
batch_size = 8
modelname = f'stft_depthwise_ovs_{args.nfeat}_k{args.kernel}_r{args.repeat}.stdc'
print(f'Go for model {modelname}')
lr = 0.0005
wdL2 = 0.002       # Adam L2 weight decay
fe = 64000         # sample rate passed to u.Dataset (model name suggests 64 kHz)

writer = SummaryWriter('runs/'+modelname)

# Train on the full table; the two Bombyx deployments serve as test sets.
# NOTE(review): traindf = df, so the test rows are also seen during training
# (the filters excluding them are commented out in the history) — confirm intended.
traindf = df
testdf_bx1 = df[df.wavpath.str.startswith('/nfs/NAS5/SABIOD/SITE/BOMBYX/BOMBYX2017')]
testdf_bx2 = df[df.wavpath.str.startswith('/nfs/NAS3/SABIOD/SITE/BOMBYX_MONACO_2022-07')]
for splitname, split in (('train', traindf), ('test bx1', testdf_bx1), ('test bx2', testdf_bx2)):
    print(splitname+' size is '+str(len(split))+' with '+str((split.annot=='cachcach').sum())+' positives')

# Depth-wise separable STFT model (64 kHz front-end, spectrogram batch-norm),
# parameterised by the number of features and the kernel size.
model = get['stft_depthwise_64kHz_ksize_specBN'](args.nfeat, args.kernel)

print('nb param', sum(p.numel() for p in model.parameters() if p.requires_grad))
bad_batches = {}   # kept for compatibility; not filled in this script
gpu = device('cuda:1')
model = model.to(gpu)

optimizer = optim.Adam(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
# Exponential LR decay: base LR multiplied by 0.98**epoch.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: .98**epoch)

# Oversampling weights: cachcach x3, globi x10, the Bombyx2 deployment x8.
samplingweights = torch.ones(len(traindf))
samplingweights[(traindf.annot=='cachcach').to_numpy()] = 3
samplingweights[(traindf.annot=='globi').to_numpy()] = 10
# FIX: build this mask from traindf, not df. The two only coincide while
# traindf = df; with the (commented-out) train filters re-enabled, a df-based
# mask would misalign with samplingweights (which has len(traindf) entries).
samplingweights[(traindf.wavpath.str.startswith('/nfs/NAS3/SABIOD/SITE/BOMBYX_MONACO_2022-07')).to_numpy()] = 8
# Normalisation by the total unnormalised weight. WeightedRandomSampler only
# uses relative weights, so this constant rescaling is cosmetic.
samplingweights /= len(traindf) + (traindf.annot=='cachcach').sum()*2 + (traindf.annot=='globi').sum()*9
sampler = utils.data.WeightedRandomSampler(weights=samplingweights, num_samples=len(traindf), replacement=True)

# Training set with reef-noise augmentation; 5 s samples at fe Hz.
train_ds = u.Dataset(traindf, test_bool=False, fe=fe, norm=norm, int16=int16, reef_noise=True, brown_noise=False, sampleDur=5)
loader = utils.data.DataLoader(train_ds, batch_size=batch_size, num_workers=8, prefetch_factor=4, pin_memory=True, sampler=sampler)

print('Started at ',time.ctime(time.time()))
loss_fun = torch.nn.BCEWithLogitsLoss()

# TRAINING LOOP
steps_per_epoch = len(traindf)//batch_size + 1  # for a monotonic tensorboard x-axis
for epoch in range(nepoch):

    model.train()
    for count, batch in enumerate(tqdm(loader, desc=str(epoch), leave=False)):
        x, label = batch
        optimizer.zero_grad()
        pred = model(x.to(gpu)).view(-1)
        label = label.to(gpu).float()
        score = loss_fun(pred, label)
        score.backward()
        optimizer.step()

        step = epoch * steps_per_epoch + count
        writer.add_scalar('train loss', score.item(), step)
        # FIX: average over the actual batch length — the last batch of an epoch
        # can be smaller than batch_size, which inflated/deflated the old metric.
        writer.add_scalar('train acc', ((pred>0)==label.bool()).float().mean().item(), step)

    # FIX: step the scheduler once per epoch AFTER the optimizer steps (the
    # PyTorch >= 1.1 calling order). The previous code called optimizer.step()
    # and scheduler.step() at the TOP of each epoch, which (a) re-applied the
    # previous epoch's last gradients a second time (they are not cleared
    # between epochs) and (b) decayed the LR before epoch 0 even ran.
    scheduler.step()
    save(model.state_dict(), modelname)

    # Per-epoch validation on the Bombyx1 test set.
    with torch.no_grad():
        model.eval()
        labels, preds, losses = [], [], []
        for batch in utils.data.DataLoader(u.Dataset(testdf_bx1, test_bool=True, fe=fe, norm=norm, int16=int16, sampleDur=5), batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True):
            x, label = batch
            # view(-1) (not squeeze) keeps a 1-D tensor even when the final
            # batch holds a single sample, so BCEWithLogitsLoss shapes match.
            pred = model(x.to(gpu)).view(-1).cpu().detach()
            preds.extend(pred)
            labels.extend(label.view(-1))
            losses.append(loss_fun(pred, label.float()).item())

    validauc = roc_auc_score(labels, preds)
    labels = np.array(labels).astype(bool)
    preds = (np.array(preds)>0).astype(bool)   # threshold logits at 0 (i.e. p=0.5)
    validacc = (labels==preds).sum() / len(labels)
    validtpr = (labels&preds).sum() / labels.sum()        # recall on positives
    validtnr = (~labels&~preds).sum() / (~labels).sum()   # recall on negatives
    writer.add_scalar('valid loss', np.mean(losses), epoch)
    writer.add_scalar('valid acc', validacc, epoch)
    writer.add_scalar('valid tpr', validtpr, epoch)
    writer.add_scalar('valid tnr', validtnr, epoch)
    writer.add_scalar('valid auc', validauc, epoch)


# Final evaluation on the Bombyx1 test set.
print('Eval for Bombyx1')
with torch.no_grad():
    model.eval()
    # FIX: dropped the `losses` accumulator (it was filled but never read here)
    # and the large commented-out fine-tuning block.
    labels, preds = [], []
    for batch in utils.data.DataLoader(u.Dataset(testdf_bx1, test_bool=True, fe=fe, norm=norm, int16=int16, sampleDur=5), batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True):
        x, label = batch
        # view(-1) (not squeeze) keeps a 1-D tensor even for a final batch of one sample
        pred = model(x.to(gpu)).view(-1).cpu().detach()
        preds.extend(pred)
        labels.extend(label.view(-1))

    preds = np.array(preds)
    labels = np.array(labels)

    print('AUC ROC %f'%(roc_auc_score(labels, preds)))
    # mAP averaged over both classes: negatives are scored with the negated logit.
    print('mAP %f'%(np.mean([average_precision_score((labels+1)%2, preds*-1), average_precision_score(labels, preds)])))
    preds = preds > 0   # threshold logits at 0 (p=0.5)
    print('precision : %f, recall : %f, accuracy : %f'%(precision_score(labels, preds), recall_score(labels, preds), accuracy_score(labels, preds)))

# Same evaluation on Bombyx1 but with reef-noise augmentation applied.
print('Eval for Bombyx1 + noise')
with torch.no_grad():
    model.eval()
    # FIX: dropped the `losses` accumulator (it was filled but never read here)
    # and the large commented-out fine-tuning block.
    labels, preds = [], []
    for batch in utils.data.DataLoader(u.Dataset(testdf_bx1, test_bool=True, fe=fe, norm=norm, reef_noise=True, int16=int16, sampleDur=5), batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True):
        x, label = batch
        # view(-1) (not squeeze) keeps a 1-D tensor even for a final batch of one sample
        pred = model(x.to(gpu)).view(-1).cpu().detach()
        preds.extend(pred)
        labels.extend(label.view(-1))

    preds = np.array(preds)
    labels = np.array(labels)

    print('AUC ROC %f'%(roc_auc_score(labels, preds)))
    # mAP averaged over both classes: negatives are scored with the negated logit.
    print('mAP %f'%(np.mean([average_precision_score((labels+1)%2, preds*-1), average_precision_score(labels, preds)])))
    preds = preds > 0   # threshold logits at 0 (p=0.5)
    print('precision : %f, recall : %f, accuracy : %f'%(precision_score(labels, preds), recall_score(labels, preds), accuracy_score(labels, preds)))

# Final evaluation on the Bombyx2 (Monaco 2022-07) test set.
print('Eval for Bombyx2')
with torch.no_grad():
    model.eval()
    # FIX: dropped the `losses` accumulator (it was filled but never read here)
    # and the large commented-out fine-tuning block.
    labels, preds = [], []
    for batch in utils.data.DataLoader(u.Dataset(testdf_bx2, test_bool=True, fe=fe, norm=norm, int16=int16, sampleDur=5), batch_size=32, shuffle=True, num_workers=4, prefetch_factor=2, pin_memory=True):
        x, label = batch
        # view(-1) (not squeeze) keeps a 1-D tensor even for a final batch of one sample
        pred = model(x.to(gpu)).view(-1).cpu().detach()
        preds.extend(pred)
        labels.extend(label.view(-1))

    preds = np.array(preds)
    labels = np.array(labels)

    print('AUC ROC %f'%(roc_auc_score(labels, preds)))
    # mAP averaged over both classes: negatives are scored with the negated logit.
    print('mAP %f'%(np.mean([average_precision_score((labels+1)%2, preds*-1), average_precision_score(labels, preds)])))
    preds = preds > 0   # threshold logits at 0 (p=0.5)
    print('precision : %f, recall : %f, accuracy : %f'%(precision_score(labels, preds), recall_score(labels, preds), accuracy_score(labels, preds)))