import torch
import numpy as np
from torch import utils, device
from torch.utils import data
import pandas as pd
import utils as u  # local utilities (Dataset wrapper, custom layers such as depthwise_separable_conv1d)
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import os
from models import get
import argparse
df = pd.read_pickle('./annot_all.pkl')  # annotation dataframe (needs at least a 'wavpath' column)

parser = argparse.ArgumentParser(description='Post-training static quantization of a trained stft model')
parser.add_argument('-repeat', type=str, default='')
parser.add_argument('-nfeat',  type=int, default=128)
parser.add_argument('-kernel',  type=int, default=11)
args = parser.parse_args()


modelname = f'stft_{args.nfeat}_r{args.repeat}.stdc'
print('Go for model '+modelname)

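# rebuild the architecture and load the trained float checkpoint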
model = get['stft'](args.nfeat)
model.load_state_dict(torch.load('models/'+modelname))
#print('original size : (MB) ', os.path.getsize('models/'+modelname)/1e6)
#model = model.to(device('cuda:0'))
#print(model)
model = model.eval()
#for l in model:
#    if type(l)==u.depthwise_separable_conv1d:
#        l.fuse_module()
#torch.quantization.fuse_modules(model, [['3','4', '5'],['7','8','9']], inplace=True)
#print(model)


def test_model(model, df):
    """Run the model over the samples described by df and return the ROC AUC."""
    preds, labels = [], []
    loader = utils.data.DataLoader(u.Dataset(df), batch_size=32, shuffle=True, num_workers=8, prefetch_factor=4)
    with torch.no_grad():
        model.eval()
        for batch in tqdm(loader, total=len(loader), leave=False):
            x, label = batch
            # flatten to one score per sample and move to numpy so roc_auc_score gets
            # plain arrays (and a final batch of size 1 cannot break list.extend)
            preds.extend(model(x).reshape(-1).cpu().numpy())
            labels.extend(label.reshape(-1).cpu().numpy())
    return roc_auc_score(labels, preds)

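# reference AUC of the float model on the BOMBYX2017 recordings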
AUC = test_model(model, df[df.wavpath.str.startswith('/BOMBYX2017')])
print('original perf', AUC)
out = {'before':AUC}

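# wrap the float model for eager-mode static quantization: model[:3] stays in float as one
# nested child ("0"), a QuantStub/DeQuantStub pair brackets the layers to be quantized, and
# the last layer stays in float. Child indices therefore shift by one relative to the
# original model, which is why the fuse lists below use "2","3","4" instead of "3","4","5".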
model = torch.nn.Sequential(
    model[:3],
    torch.quantization.QuantStub(),
    *model[3:-1],
    torch.quantization.DeQuantStub(),
    model[-1],
)
#print(model)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # int8 defaults for the x86/fbgemm backend
# fuse the two layer triplets (children "2"-"4" and "6"-"8" of the wrapped model) into single modules
torch.quantization.fuse_modules(model, [["2", "3", "4"], ["6", "7", "8"]], inplace=True)
torch.quantization.prepare(model, inplace=True)  # attach observers that record activation ranges

# calibrate: feed 256 random annotated samples through the prepared model so the
# observers collect activation statistics, then convert the observed modules to int8
test_model(model, df.sample(256))
torch.quantization.convert(model, inplace=True)

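# AUC of the quantized (int8) model on the same BOMBYX2017 subset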
AUC = test_model(model, df[df.wavpath.str.startswith('/BOMBYX2017')])
print('new perf', AUC)
out['after'] = AUC
#print('new size (MB)', os.path.getsize('models/postquat_'+modelname)/1e6)
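# note: nothing in this script writes 'models/postquat_'+modelname; to use the size check
# above, the quantized model would first have to be saved, e.g. (an assumption, not part
# of the original run) torch.save(model.state_dict(), 'models/postquat_'+modelname)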

# drop the '.stdc' extension (np.save appends '.npy' itself)
np.save(os.path.splitext(modelname)[0], out)
