import argparse
import time
import warnings

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score
from torch import device, nn, optim, save, utils
from torch.utils import data
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import utils as u  # local utilities (Dataset, etc.)
from models import get

warnings.filterwarnings("ignore")
#from sklearn.model_selection import train_test_split
pDropout = .25   # dropout probability (defined here but unused in this script)
sampleDur = 1    # duration of each training sample, in seconds
norm = True      # passed to u.Dataset: normalise the waveforms
int16 = False    # passed to u.Dataset (int16 decoding flag)
# Tensor layout is (BATCH, CHANNEL, Y, X).
# The input size before the linear layers is independent of the model input size,
# thanks to the adaptive pooling layer.

parser = argparse.ArgumentParser(description='Train an STFT depthwise click classifier on the CARIMAM annotations')
parser.add_argument('-nfeat',  type=int, default=128, help='number of feature channels')
parser.add_argument('-repeat', type=str, default='', help='tag appended to the model name to distinguish repeated runs')
parser.add_argument('-kernel', type=int, default=7, help='convolution kernel size')
args = parser.parse_args()

nepoch = 10
batch_size = 8
modelname = 'stft_depthwise_carimam_'+str(args.nfeat)+'_k'+str(args.kernel)+'_r'+str(args.repeat)+'.stdc'
print('Go for model '+ modelname)
lr = 0.0005     # initial learning rate
wdL2 = 0.002    # L2 weight decay passed to Adam
fe = 128_000    # sampling rate in Hz


writer = SummaryWriter('runs/'+modelname)

df = pd.read_excel('/nfs/NAS6/mahe/src/transformer-carimam/Annotation_CARIMAM_apo_23_05_11.xlsx', 'annot_click_apo', usecols='A:E').dropna()
# drop annotations whose audio lives on NAS3/NAS4 (presumably not mounted here)
df = df[~df['File'].str.startswith('/nfs/NAS4/')]
df = df[~df['File'].str.startswith('/nfs/NAS3/')]

# Disabled: drop negative ('n') windows that overlap an annotated positive in the same file.
# df_pos = df[df.positif_negatif != 'n']
# df_save = df.copy()
# for idx, item in tqdm(df.iterrows()):
#     if item.positif_negatif == 'n':
#         pos_mid = (item.pos_start + item.pos_end)/2
#         pos_margin_left = pos_mid - sampleDur
#         pos_margin_right = pos_mid + sampleDur

#         df_pos_file = df_pos[df_pos["File"] == item["File"]]
#         if np.logical_and(df_pos_file.pos_start > pos_margin_left, df_pos_file.pos_end < pos_margin_right).sum() > 0 or \
#         np.logical_and(df_pos_file.pos_start < pos_margin_left, df_pos_file.pos_end > pos_margin_right).sum() > 0 :
#             df_save = df_save.drop(index=idx)
# df = df_save


df["annot"] = df.positif_negatif
df["wavpath"] = df.File
df["time"] = (df.pos_start + df.pos_end)/2
df.annot[df.annot != 'n'] = 'cachcach'
df.annot[df.annot == 'n'] = 'noise'

# Site-level split: hold out the LOT2 BERMUDE/GUYANNE/ANG deployments as the test set
# (an earlier random split: mask = np.random.rand(len(df)) < 0.1).
mask = (df['File'].str.contains('LOT2/BERMUDE')) | (df['File'].str.contains('LOT2/GUYANNE')) | (df['File'].str.contains('LOT2/ANG'))

traindf = df[~mask]
testdf = df[mask]
print('train size is '+str(len(traindf))+' with '+str((traindf.annot=='cachcach').sum())+' positives')
print('test size is '+str(len(testdf))+' with '+str((testdf.annot=='cachcach').sum())+' positives')

model = get['stft_depthwise_ksize_CARIMAM'](args.nfeat, args.kernel)
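# `get` (imported from models) is assumed to be a registry dict mapping model names
# to constructors taking (nfeat, kernel_size).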

print('nb param', sum(m.numel() for m in model.parameters() if m.requires_grad))
bad_batches = {}
#u.PrintModel(model, indata=torch.ones(8,1,125000))
#model = nn.DataParallel(model)
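# NB: the device is hardcoded to the second GPU; on a single-GPU machine use 'cuda:0'.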
gpu = device('cuda:1')
model = model.to(gpu)
#print('nb param', sum(m.numel() for m in model.module[:].parameters()))

optimizer = optim.Adam(model.parameters(), weight_decay=wdL2, lr=lr, betas=(0.8, 0.999))
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch : .98**epoch)
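# LambdaLR multiplies the base lr by 0.98**epoch, i.e. a 2% exponential decay per epoch.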

# Disabled: class-balancing via a weighted random sampler (kept for reference).
# samplingweights = torch.ones(len(traindf))
# samplingweights[(traindf.annot=='cachcach').to_numpy()] = 3
# samplingweights[(traindf.annot=='globi').to_numpy()] = 10
# samplingweights /= len(traindf) + (traindf.annot=='cachcach').sum()*2 + (traindf.annot=='globi').sum()*9
# sampler = utils.data.WeightedRandomSampler(weights = samplingweights, num_samples=len(traindf), replacement=True)

train_ds = u.Dataset(traindf, fe=fe, norm=norm, int16=int16, reef_noise=True, sampleDur=sampleDur)
loader = utils.data.DataLoader(train_ds, batch_size=batch_size, num_workers=8, prefetch_factor=4, shuffle=True, pin_memory=True, drop_last=True)#, sampler=sampler)
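# drop_last=True keeps every training batch at exactly batch_size samples, so the
# per-batch accuracy logged below always averages over a full batch.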

print('Started at', time.ctime())
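# BCEWithLogitsLoss fuses the sigmoid with the binary cross-entropy, which is more
# numerically stable than applying them separately; the model outputs raw logits,
# so `pred > 0` corresponds to a 0.5 probability threshold.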
loss_fun = torch.nn.BCEWithLogitsLoss()

# TRAINING LOOP
for epoch in range(nepoch):

    model.train()
    count = 0
    for batch in tqdm(loader, desc=str(epoch), leave=False):
        x, label = batch
        optimizer.zero_grad()
        pred = model(x.to(gpu)).squeeze()
        label = label.to(gpu).float()
        score = loss_fun(pred, label)
        score.backward()
        optimizer.step()

        step = epoch * len(loader) + count
        writer.add_scalar('train loss', score.item(), step)
        writer.add_scalar('train acc', ((pred > 0) == label.bool()).float().mean().item(), step)
        count += 1
    scheduler.step()  # step the lr schedule once per epoch, after the optimiser updates
    save(model.state_dict(), modelname)
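    # NB: the checkpoint is overwritten every epoch, so only the latest weights survive.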
    # evaluate loss and metrics on the held-out sites
    with torch.no_grad():
        model.eval()
        labels, preds, losses = [], [], []
        for batch in utils.data.DataLoader(u.Dataset(testdf, fe=fe, norm=norm, int16=int16, sampleDur=sampleDur), batch_size=32, num_workers=2, pin_memory=True):
            x, label = batch
            pred = model(x.to(gpu)).cpu().view(-1)
            preds.extend(pred)
            labels.extend(label.view(-1))
            losses.append(loss_fun(pred, label.view(-1).float()).item())
    
    validauc = roc_auc_score(labels, preds)
    labels = np.array(labels).astype(bool)
    preds = (np.array(preds)>0).astype(bool)
    validacc = (labels==preds).sum() / len(labels)
    validtpr = (labels&preds).sum() / labels.sum()
    validtnr = (~labels&~preds).sum() / (~labels).sum()
    writer.add_scalar('valid loss', np.mean(losses), epoch)
    writer.add_scalar('valid acc', validacc, epoch)
    writer.add_scalar('valid tpr', validtpr, epoch)
    writer.add_scalar('valid tnr', validtnr, epoch)
    writer.add_scalar('valid auc', validauc, epoch)
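
# To reload the trained weights later (sketch, assuming the same model configuration):
#   model = get['stft_depthwise_ksize_CARIMAM'](args.nfeat, args.kernel)
#   model.load_state_dict(torch.load(modelname, map_location='cpu'))
#   model.eval()
# Example invocation (script name hypothetical):
#   python train_stft_depthwise_carimam.py -nfeat 128 -kernel 7 -repeat 1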

