import numpy as np
from tqdm import tqdm
from scipy.special import expit
import pandas as pd
import matplotlib.pyplot as plt



# Produce one 2D histogram per calendar day of model prediction (after a
# sigmoid) versus time of day, titled with any confirmed "cachalot" passages
# from `passages.pkl` that overlap that day. Each figure is saved under the
# day's date string.
#
# NOTE(review): both inputs are loaded with pandas.read_pickle, which can
# execute arbitrary code — only point these paths at trusted files.
preds = pd.read_pickle('../manip_2021/stft_depthwise_ovs_128_k7_r1.preds')
cachas = pd.read_pickle('../manip_2021/passages.pkl')

# Histogram edges. Both are built as integer / constant so they include the
# closing edge (24 h and 1.0). The previous np.arange(0, 24, 1/6) and
# np.arange(0, 1, .01) stopped one step short, so samples after 23:50 and
# predictions above 0.99 were silently dropped from the plot.
# Time edges are minutes-of-day / 60 -- the same arithmetic used for the
# x-values below -- so a sample exactly on a 10-minute boundary yields a
# bit-identical float and lands in the correct bin.
TIME_BIN_EDGES = np.arange(0, 24 * 60 + 1, 10) / 60
PRED_BIN_EDGES = np.arange(0, 101) / 100

for d, grp in tqdm(preds.groupby(preds.date.dt.date)):
    day_start = pd.to_datetime(d)
    day_end = day_start + pd.Timedelta('1d')

    plt.figure(figsize=(10, 5))

    # Every passage overlapping [day_start, day_end) at all.
    passage = cachas[(cachas.dateS < day_end) & (cachas.dateE > day_start)]
    if len(passage) > 0:
        # List every overlapping passage, not just the first one.
        plt.title('\n'.join(
            f'Cachalot confirmed from {start} to {end}'
            for start, end in zip(passage.dateS, passage.dateE)
        ))

    # Time of day in fractional hours (minute resolution; seconds are ignored,
    # which cannot move a sample across a 10-minute bin boundary).
    minutes_of_day = grp.date.dt.hour * 60 + grp.date.dt.minute
    plt.hist2d(
        minutes_of_day / 60,
        expit(grp.pred),
        bins=(TIME_BIN_EDGES, PRED_BIN_EDGES),
    )
    plt.xlabel('Time (hours)')
    plt.ylabel('Model prediction')
    plt.savefig(str(d))
    plt.close()
