import numpy as np
import os, sys
from tqdm import tqdm
import pandas as pd
#sys.path.append('../call_classif_sup/')
#sys.path.append('../scripts/')
import models, utils
import torch


# Forward the call-type classifier over every detected call in the predictions
# table, storing the predicted label ('pred_label') and its softmax confidence
# ('pred_conf') on each row, then save the annotated table.

modelname = 'resnet18_frontend3b_29_04_resnet18stride2_dur_2_test_swop_augm.stdc'
df = pd.read_pickle('../predictions_sans_classification_04_05.pkl')
#df = pd.read_csv('../sparrow_whales_train1_1002_frontend_pcen_conv1d_noaugm_bs32_lr.005_.csv')
m = torch.nn.Sequential(models.get['frontend3b'], models.get['resnet18'])
# Load the checkpoint onto the CPU: load_state_dict copies the tensors into the
# (still CPU-resident) module parameters anyway, so this avoids a transient GPU
# copy and a failure if the checkpoint was saved from an unavailable device.
m.load_state_dict(torch.load(modelname, map_location='cpu'))
m.eval().to('cuda')

loader = torch.utils.data.DataLoader(utils.Dataset(df, sampleDur=2), batch_size=64, num_workers=4, prefetch_factor=4, collate_fn=utils.collate_fn)
# Pure inference: disable autograd so no graph / intermediate activations are
# retained per batch (detach() on the output alone does not prevent that).
with torch.no_grad():
    for x, idx in tqdm(loader):
        # idx is presumably the df index labels of the batch rows, as produced
        # by utils.collate_fn -- TODO confirm.
        logits = m(x.to('cuda')).cpu()
        conf, pred = torch.nn.functional.softmax(logits, dim=-1).max(-1)
        df.loc[idx, 'pred_label'] = [utils.idxtotype[p] for p in pred.tolist()]
        df.loc[idx, 'pred_conf'] = conf.numpy()

# Discard every row whose 'fn' value (presumably the source file) has at least
# one detection classified as 'corrupt'. A boolean mask is used rather than
# dropping by index label, so duplicate index labels cannot remove rows that
# belong to other, non-corrupt files.
corrupt_fns = df.loc[df.pred_label == 'corrupt', 'fn'].unique()
df = df[~df.fn.isin(corrupt_fns)]
df.to_pickle('classification_04_05_pred.pkl')
