import matplotlib.pyplot as plt
import numpy as np


# Box-plot of test AUC against parameter count for the unpruned models
# (green, filled boxes) and their pruned variants (default, unfilled boxes),
# written to 'pruning.pdf'.
#
# NOTE(review): the report files are read with allow_pickle=True, which
# unpickles arbitrary objects; only run this on report files you trust.

N_RUNS = 5  # report files are indexed r0 .. r4 for every configuration

feats = [16, 32, 64, 128]
pruned_feats = [32, 64, 128]  # widths for which pruned reports are loaded
prunes = [0.4, .3, .2, .1]


def _param_count(f):
    """Return the parameter count used as x-position for feature width *f*.

    NOTE(review): the formula is copied verbatim from the original script;
    presumably it mirrors the model architecture -- confirm against the model.
    """
    return 7 * (64 + f * 2) + f * (64 + f + 1)


def _load_test_aucs(f, prefix=''):
    """Return the 'test_auc' entry of each run's report for width *f*.

    *prefix* is inserted right after 'testreport_' in the file name
    (e.g. 'prune0.4_' for the pruned models, '' for the unpruned ones).
    """
    return [
        np.load(
            f'models/testreport_{prefix}stft_depthwise_ovs_{f}_k7_r{r}.npy',
            allow_pickle=True,
        ).item()['test_auc']
        for r in range(N_RUNS)
    ]


plt.figure()
plt.xscale('log')

# --- unpruned models: one green box per feature width -----------------------
dense_nparams = [_param_count(f) for f in feats]
dense_aucs = np.array([_load_test_aucs(f) for f in feats])
# Box widths scale with the x-position so that the boxes look equally wide
# on the logarithmic axis.
bo = plt.boxplot(
    dense_aucs.T,
    positions=dense_nparams,
    widths=np.array(dense_nparams) / 20,
    manage_ticks=False,
    patch_artist=True,
)
for b in bo['boxes']:
    b.set_facecolor('green')

# --- pruned models: one box per (feature width, pruning ratio) pair ---------
# The x-position is the parameter count scaled by the kept fraction (1 - p).
pruned_nparams = [
    _param_count(f) * (1 - p) for f in pruned_feats for p in prunes
]
pruned_aucs = np.array([
    _load_test_aucs(f, prefix=f'prune{p}_')
    for f in pruned_feats
    for p in prunes
])
plt.boxplot(
    pruned_aucs.T,
    positions=pruned_nparams,
    manage_ticks=False,
    widths=np.array(pruned_nparams) / 20,
)

# --- axes and output --------------------------------------------------------
plt.xlabel('Number of non-zero parameters')
plt.ylabel('AUC')
plt.ylim(.8, .95)
# A single layout pass after all labels and limits are set.
plt.tight_layout()
plt.savefig('pruning.pdf')
