import nics_fix_pt.nn_fix as nnf
from nics_fix_pt import register_fix_module
import numpy as np
from torch import load, utils, device
import torch
import pandas as pd
import utils as u
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve
from models import get
import argparse
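# Evaluate a trained stft_depthwise checkpoint: round the later layers'
# parameters to integers (a quick fixed-point check), then report per-split
# accuracy, error rates, and AUC, collecting the misclassified filenames.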

parser = argparse.ArgumentParser(description='Evaluate a trained stft_depthwise checkpoint.')
parser.add_argument('-nfeat',  type=int, default=512)
parser.add_argument('-repeat', type=str, default='')
args = parser.parse_args()


def _test(df, name):
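    """Evaluate the global `model` on `df`: print sample counts, accuracy,
    error rates, and AUC, and record misclassified filenames, the ROC curve,
    and the AUC in the global `fails` dict under keys prefixed with `name`."""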
    with torch.no_grad():
        loader = utils.data.DataLoader(u.Dataset(df, filename=True), shuffle=True, batch_size=batch_size, num_workers=2, pin_memory=True)
        labels, preds, filenames = [], [], []
        for batch in tqdm(loader):
            x, label, filename = batch
            x = x.to('cuda', non_blocking=True)
            pred = model(x).cpu().detach()
            preds.extend(pred)
            labels.extend(label)
            filenames.extend(filename)

        auc = roc_auc_score(labels, preds)
        roc = roc_curve(labels, preds, drop_intermediate=False)
        preds = (np.array(preds) > 21.111765).astype(bool)  # fixed decision threshold for binarising scores
        labels = np.array(labels).astype(bool)
        filenames = np.array(filenames)
        print(name+' : '+str(len(df))+' samples')
        print(str(labels.sum())+' positive samples')
        print(str((~labels).sum())+' negative samples')
        print('Accuracy : '+str((labels==preds).sum() / len(labels)))
        print('False positive rate : '+str(((preds)&(~labels)).sum()/(~labels).sum()))
        print('False negative rate : '+str(((~preds)&(labels)).sum()/labels.sum()))

        print('True positive rate : '+str((preds & labels).sum() / labels.sum()))
        print('True negative rate : '+str(((~preds) & (~labels)).sum() / (~labels).sum()))
        print('AUC : '+str(auc))

        fails[name+'_false_neg'] = filenames[((~preds)&(labels))]
        fails[name+'_false_pos'] = filenames[((preds)&(~labels))]
        fails[name+'_roc'] = roc
        fails[name+'_auc'] = auc

print('load...', end='', flush=True)
df = pd.read_pickle('./annot_all.pkl')
print('done')
# (BATCH, CHANNEL, Y, X)
batch_size = 256

#modelname = 'stft_depthwise_'+str(args.nfeat)+'_r'+args.repeat+'.stdc'
modelname = 'stft_depthwise_64_r0.stdc'

cuda0 = device('cuda:0')
model = get['stft_depthwise'](64) #(args.nfeat)
model.load_state_dict(load('models/'+modelname))
model.eval()
model = model.to(cuda0)
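# Misclassified filenames, ROC curve, and AUC for each split, filled by _test()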
fails = {}

print('Performance of : '+modelname)

# Round every parameter of the layers from index 3 onward to integer values
# (simulated fixed-point weights) before evaluating. The attribute is
# re-fetched after setattr() so the second print shows the rounded values;
# printing the old reference would show identical output twice.
for k in model[3:].state_dict():
    idx, sub, name = k.split('.')
    module = getattr(model[int(idx)], sub)
    obj = getattr(module, name)
    print(obj[0] * 1e9)
    setattr(module, name, torch.nn.Parameter(torch.round(obj.detach())))
    print(getattr(module, name)[0] * 1e9)
#_test(df, 'overall')
_test(df[~df.wavpath.str.startswith('/BOMBYX2017')], 'train')
_test(df[df.wavpath.str.startswith('/BOMBYX2017')], 'test')



np.save('testreport_'+modelname[:-5], fails)

