import numpy as np
import matplotlib.pyplot as plt


# Sweep configuration: the kernel sizes whose saved test reports are
# aggregated, and how many independent runs were saved per kernel size.
# An earlier sweep (over a different hyper-parameter) used:
#   np.concatenate([np.arange(16, 64, 16), np.arange(64, 513, 64)])
feats = np.arange(7, 13)
N_RUNS = 10
REPORT_PATH = 'models/testreport_stft_depthwise_DO2d_64_kernel{feat}_r{run}_ovs.npy'


def _load_aucs(feat, n_runs=N_RUNS):
    """Load the train/test AUCs of every run for one kernel size.

    Args:
        feat: kernel size substituted into ``REPORT_PATH``.
        n_runs: number of runs to read; run indices are ``0 .. n_runs - 1``.

    Returns:
        A ``(train_aucs, test_aucs)`` pair of lists, one entry per run.

    Raises:
        FileNotFoundError: if any run's report file is missing.
    """
    train_aucs, test_aucs = [], []
    for run in range(n_runs):
        path = REPORT_PATH.format(feat=feat, run=run)
        # NOTE: allow_pickle=True can execute arbitrary code stored in the
        # file; only load reports produced by our own training runs.
        # .item() unwraps the stored object; it is presumably a dict with
        # 'train_auc' / 'test_auc' keys (TODO confirm against the saver).
        report = np.load(path, allow_pickle=True).item()
        train_aucs.append(report['train_auc'])
        test_aucs.append(report['test_auc'])
    return train_aucs, test_aucs


# Per-kernel-size mean and (population, ddof=0) standard deviation of the
# train and test AUCs across runs, in the same order as ``feats``.
avgtr, stdtr, avgte, stdte = [], [], [], []
for feat in feats:
    train_aucs, test_aucs = _load_aucs(feat)
    avgtr.append(np.mean(train_aucs))
    stdtr.append(np.std(train_aucs))
    avgte.append(np.mean(test_aucs))
    stdte.append(np.std(test_aucs))

# Convert to arrays so the plotting code can do element-wise mean +/- std.
avgtr, avgte, stdtr, stdte = np.array(avgtr), np.array(avgte), np.array(stdtr), np.array(stdte)


# Plot the mean train/test AUC per kernel size, each with a shaded band of
# +/- one standard deviation across runs.
plt.figure()
train_line, = plt.plot(feats, avgtr, label='Train AUC')
test_line, = plt.plot(feats, avgte, label='Test AUC')

# Tie each band's colour to its curve explicitly. Lines and filled patches
# advance separate colour cycles in matplotlib, so the default colours only
# line up while both sequences happen to stay in lockstep. facecolor (rather
# than color) leaves the band's edge styling at its default.
plt.fill_between(feats, avgtr - stdtr, avgtr + stdtr,
                 facecolor=train_line.get_color(), alpha=.5)
plt.fill_between(feats, avgte - stdte, avgte + stdte,
                 facecolor=test_line.get_color(), alpha=.5)

plt.legend()
plt.xlabel('# kernel size')
plt.ylabel('AUC score')

# No extension given: matplotlib saves with rcParams['savefig.format']
# (png by default) and appends that extension to the file name.
plt.savefig('AUC_vs_feat')

