import numpy as np
import matplotlib.pyplot as plt



# Hyper-parameter grid of the stored runs: kernel sizes 3..8, three feature
# counts, and five repeats for every (feature, kernel) combination.
kernel = np.arange(3, 9)
feat = [32, 64, 128]
repeat = np.arange(5)

# AUC table indexed as [feat, kernel, repeat, split]; split 0 = train, 1 = test.
out = np.zeros((len(feat), len(kernel), len(repeat), 2))
for i, f in enumerate(feat):
    for j, k in enumerate(kernel):
        # Iterate over `repeat` itself (not a second hard-coded 5) so the loop
        # can never get out of step with the size of `out`.
        for r in repeat:
            report_path = f'models/testreport_stft_depthwise_ovs_{f}_k{k}_r{r}.npy'
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
            # load report files that come from a trusted source.
            # .item() unwraps the stored object, which is then read like a dict
            # (presumably a 0-d object array holding the report - confirm).
            a = np.load(report_path, allow_pickle=True).item()
            out[i, j, r, 0] = a['train_auc']
            out[i, j, r, 1] = a['test_auc']

def _styled_boxplot(axis, data, positions, color):
    """Draw white-filled boxplots outlined in ``color`` on ``axis``.

    ``data`` is handed to ``Axes.boxplot`` unchanged and ``positions`` gives
    the x location of each box. Returns the artist dict produced by
    ``boxplot`` so the caller can build a legend from it.
    """
    artists = axis.boxplot(data, widths=.3, positions=positions, patch_artist=True, manage_ticks=False)
    for element in ['boxes', 'whiskers', 'fliers', 'medians', 'caps']:
        plt.setp(artists[element], color=color)
    # patch_artist=True makes the boxes fillable; keep them white so only the
    # coloured outline distinguishes train from test.
    for patch in artists['boxes']:
        patch.set(facecolor='white')
    return artists


# One panel per feature count, all sharing the AUC axis.
fig, ax = plt.subplots(ncols=len(feat), sharey=True, figsize=(10, 5))
# plt.subplots returns a bare Axes (not an array) when ncols == 1.
ax = np.atleast_1d(ax)

slots = np.arange(len(kernel))  # one x slot per kernel size
for panel, n_feat, aucs in zip(ax, feat, out):
    # aucs is [kernel, repeat, split]; transposing gives one column (= one box)
    # per kernel size. Train is drawn right of the tick, test left of it.
    train_bp = _styled_boxplot(panel, aucs[:, :, 0].T, slots + .2, 'blue')
    test_bp = _styled_boxplot(panel, aucs[:, :, 1].T, slots - .2, 'orange')
    panel.legend([train_bp["boxes"][0], test_bp["boxes"][0]], ['train', 'test'], loc='lower right')

    # Optional zoom on the top of the range: panel.set_ylim(0.9, 1)
    panel.set_xticks(slots)
    panel.set_xticklabels(kernel)
    panel.set_xlabel('kernel size for '+str(n_feat)+' feats')
    panel.grid()
ax[0].set_ylabel('AUC')
plt.tight_layout()
# For interactive inspection call plt.show() here instead of saving.
plt.savefig('cach_aucs_vs_kernelsize&feats.pdf')

