import multiprocessing
import matplotlib.pyplot as plt
import pandas as pd
import soundfile as sf
import numpy as np
from tqdm import tqdm
from scipy import signal
import os
from scipy.special import expit

# Clear (and create if needed) the output folder for the spectrogram images.
# Same effect as `rm predpngs/*` (regular, non-hidden files only), but it does
# not spawn a shell, does not complain when the folder is empty, and makes sure
# the folder exists before the workers try to save into it.
os.makedirs('predpngs', exist_ok=True)
for _name in os.listdir('predpngs'):
    _path = os.path.join('predpngs', _name)
    if not _name.startswith('.') and os.path.isfile(_path):
        os.remove(_path)

predfn = '../results/KM3Net_stft_depthwise_ovs_128_k7_r1.preds'
#predfn = './BOMBYX5stft_depthwise_64_r0.preds'
folder = '/nfs/NAS6/SABIOD/SITE/KM3Net/DATA_WAV/'
# Pickled DataFrame; below it is used with columns 'fn' (audio path relative
# to `folder`) and 'pred' (per-frame model outputs, passed through expit).
df = pd.read_pickle(folder+predfn)
# Time in seconds of each prediction frame: frame k -> (k*2048 + 11264/2)/50000.
# NOTE(review): presumably a 2048-sample hop and an 11264-sample window at the
# model's 50 kHz input rate -- confirm against the model configuration.
predTime = (np.arange(len(df.iloc[0].pred))*256*8 + 11264/2)/50000

# Best frame score per file, mapped to [0, 1] with the logistic sigmoid.
df['maxpred'] = [expit(max(p)) for p in df.pred]
# Keep confident files only. `.copy()` makes the result an independent frame,
# so the column assignment below does not trigger SettingWithCopyWarning.
df = df[df.maxpred>0.6].copy()
# Time (s) of the highest-scoring frame in each kept file.
df['offset'] = [predTime[np.argmax(p)] for p in df.pred]

print('found '+str(len(df))+' files')
if df.empty:
    # Nothing to plot; also avoids an IndexError on df.iloc below.
    raise SystemExit('no prediction above threshold, nothing to plot')

# Audio parameters are read from the first kept file and applied to all files.
# NOTE(review): assumes every file shares this sample rate and duration.
_info = sf.info(folder+df.iloc[0].fn)
ffs = _info.samplerate
dur = _info.duration
# 3rd-order Butterworth low-pass. Without `fs`, scipy normalises Wn to the
# Nyquist frequency, so the cutoff is (32000/ffs) * (ffs/2) = 16 kHz. That is
# the Nyquist of the decimated signal when ffs == 192 kHz (assumed) and `f`
# keeps every 6th sample.
sos = signal.butter(3, 32000/ffs, 'lp', output='sos')

def f(r):
    """Save the spectrogram of a 10 s window centred on a file's best detection.

    Used as the ``pool.map`` worker. It reads the module-level ``folder``,
    ``ffs``, ``dur`` and ``sos`` that are set up before the pool is started.

    Args:
        r: one detection record (a dict built from a row of ``df``) with keys
            ``'i'`` (row label, printed as progress), ``'fn'`` (audio path
            relative to ``folder``), ``'offset'`` (centre time in seconds) and
            ``'maxpred'`` (detection score shown in the title).

    Returns:
        None. The figure is written to ``./predpngs/<file name without
        extension>`` in matplotlib's default format. Files that cannot be read
        are reported and skipped.
    """
    decim = 6  # decimation factor; 192 kHz -> 32 kHz when ffs == 192 kHz (assumed)
    print(r['i'], end='\r')
    # 5 s on either side of the detection, clamped to [0, file length].
    start = max(0, int((r['offset'] - 5) * ffs))
    stop = min(int(dur * ffs), int((r['offset'] + 5) * ffs))
    try:
        sig, _ = sf.read(folder + r['fn'], start=start, stop=stop)
    except Exception as err:
        # Best-effort per file: report and skip rather than aborting the pool.
        # (The previous handler printed an unbound `fs`, and even without that
        # it would have fallen through to use an unbound `sig`.)
        print('failed with', r, ffs, dur, err)
        return
    sig = sig[:, 0] if sig.ndim > 1 else sig  # keep the first channel only
    # Zero-phase low-pass (anti-aliasing) before keeping every `decim`-th sample.
    sig = signal.sosfiltfilt(sos, sig)[::decim]
    fig = plt.figure()
    try:
        # Sample rate after decimation, so the frequency axis is in Hz.
        plt.specgram(sig, NFFT=256, Fs=ffs / decim, noverlap=128)
        plt.ylim(3000, 16000)
        plt.title('pred : '+str(round(r['maxpred'],2)) + ' at '+str(round(r['offset']))+'sec')
        # Strip any extension, not just a 3-character one.
        stem = os.path.splitext(os.path.basename(r['fn']))[0]
        plt.savefig('./predpngs/' + stem)
    finally:
        plt.close(fig)  # always release the figure, even if plotting fails


if __name__ == '__main__':
    # One dict per kept row, plus its row label under 'i' for progress output.
    # Iterating df.index directly avoids building df.to_dict('index'), which
    # also fails on a non-unique index.
    ll = [dict(r, i=ii) for r, ii in zip(df.to_dict('records'), df.index)]
    # Pin 'fork': the workers rely on the globals built above (folder, ffs,
    # dur, sos) and must not re-run this script's side effects (clearing
    # predpngs, starting another pool). Spawn/forkserver would re-import the
    # script, and forkserver is the POSIX default from Python 3.14.
    # The context manager cleans up the workers once map() has returned.
    with multiprocessing.get_context('fork').Pool(processes=4) as pool:
        pool.map(f, ll)
