import argparse
import os
import subprocess
import sys
import time

import librosa as lr
import matplotlib.pyplot as plt
import numpy as np
import scipy.signal as sg
import sounddevice as sd
import soundfile as sf
from matplotlib.widgets import Button

# Root prefix under which the data and model outputs are read (by default a
# GVFS SFTP mount of the remote host). Set CARIMAM_ROOT_DIR=/ to use the local
# file system root instead.
ROOT_DIR = os.environ.get(
    "CARIMAM_ROOT_DIR",
    "/run/user/1003/gvfs/sftp:host=10.2.248.22,user=pierre.mahe/",
)

parser = argparse.ArgumentParser(
    description="Review candidate clicks of a recording session."
)
parser.add_argument(
    "session_path",
    type=str,
    help="directory containing the session's .wav files",
)
args = parser.parse_args()

# Sample rate (Hz) assumed for every recording of the session.
SR = 256000

# Init filter: 6th-order Butterworth high-pass at 10 kHz, as second-order sections.
sos = sg.butter(6, 10000, 'hp', output='sos', fs=SR)

# TODO check if session is already done

# Load the session statistics: one histogram row per .wav file of the session
# (the boolean mask below is applied to both, so their lengths must match).
sess_path = args.session_path
# normpath() drops a trailing separator so basename("dir/") is "dir", not "".
histo = np.load(
    ROOT_DIR + '/short/CARIMAM/detec_click/result_transformer_V1/'
    + os.path.basename(os.path.normpath(sess_path)) + '_hist.npy'
)
# 97th percentile of each row; rows whose percentile is 0 are discarded.
pers_histo = np.percentile(histo, 97, axis=1)
none_zeros = pers_histo != 0
pers_histo = pers_histo[none_zeros]

# Residual against a 9-point running median: files lying far below the local
# trend become candidates.
med_pers_histo = sg.medfilt(pers_histo, 9)
residual = pers_histo - med_pers_histo

# medfilt zero-pads the borders, which biases the residual there, so the
# spread is estimated without 9 values at each end. Short sessions leave
# nothing after trimming; use every value rather than getting a NaN threshold.
trimmed_residual = residual[9:-9]
std_pers = (trimmed_residual if trimmed_residual.size else residual).std()
thd_pers = -3 * std_pers

idx_candidate_file = np.where(residual < thd_pers)[0]

# NOTE(review): relies on the sorted file names following the same order as
# the histogram rows — confirm against the detector that wrote *_hist.npy.
v_files = np.array(sorted([f for f in os.listdir(sess_path) if f.lower().endswith('.wav')]))
if len(v_files) != len(none_zeros):
    sys.exit(
        "Found %d .wav files in %s but the session histogram has %d rows"
        % (len(v_files), sess_path, len(none_zeros))
    )
v_files = v_files[none_zeros]
nb_candidate_file = len(idx_candidate_file)


# Shared GUI state: module-level globals read and updated by the ppp callbacks.

# (position, file name) pairs of the candidate files still to review.
iter_files = enumerate(v_files[idx_candidate_file])

# Name of the file on screen, and index of the displayed peak in v_peaks.
file, idx_peak = '', None

# Per-file arrays, (re)computed whenever a new file is loaded.
v_peaks = v_filt_sig = v_pos_frame = candidate_clicks = None

# Annotations gathered so far: [file path, peak time (s), label] entries.
l_checked_file = []


# Two stacked panels: filtered waveform on top, spectrogram below.
fig, ax = plt.subplots(2, figsize=(15, 10))
ax[0].set(xlabel="Sample", ylabel="Amplitude", ylim=[-1.1, 1.1])
ax[1].set(xlabel="Time (seconds)", ylabel="Frequency")

plt.tight_layout()

# One row of button axes along the bottom edge, in figure coordinates
# [left, bottom, width, height]; only the left offset differs per button.
(axyes, axno, axund,
 axnext_click, axprev_click,
 axstart, axaud, axlisten) = [
    plt.axes([left, 0.05, 0.075, 0.075])
    for left in (0.10, 0.175, 0.25, 0.40, 0.475, 0.625, 0.7, 0.775)
]

# Reserve the bottom 20% of the figure for the buttons.
fig.subplots_adjust(top=0.94, bottom=0.2, left=0.06, right=0.94, hspace=0.15, wspace=0.2)

class ppp(object):
    """Matplotlib button callbacks for reviewing candidate clicks file by file.

    All browsing state lives in module-level globals (``file``, ``idx_peak``,
    ``v_peaks``, ``v_filt_sig``, ``v_pos_frame``, ``candidate_clicks``,
    ``iter_files``, ``l_checked_file``), which these methods read and update.
    """

    # Half-width, in seconds, of the window shown and played around a peak.
    WINDOW_S = 3

    def __init__(self):
        # Load and display the first candidate file right away.
        self.next(None)

    def onclick(self, event=None):
        """Handle 'pick_event'; currently a no-op.

        Matplotlib passes the event object to the handler, so it must be accepted.
        """
        pass

    def _has_peaks(self):
        """Return True when a file is loaded and at least one peak was found in it."""
        return v_peaks is not None and len(v_peaks) > 0

    def _window(self):
        """Return the ``(beg, end)`` sample bounds of the window around the current peak.

        The window spans ``WINDOW_S`` seconds on each side of the peak and is
        clipped to the bounds of the filtered signal.
        """
        center = int(v_pos_frame[v_peaks[idx_peak]])
        half = self.WINDOW_S * SR
        return max(0, center - half), min(len(v_filt_sig), center + half)

    def _label(self, event, tag):
        """Append ``[file path, peak time (s), tag]`` to l_checked_file, then go to the next peak."""
        if not self._has_peaks():
            print("No click to label")
            return
        l_checked_file.append([sess_path + '/' + file, v_pos_frame[v_peaks[idx_peak]]/SR, tag])
        self.next_click(event)

    def yes(self, event):
        """'Click' button: label the current peak 'd'."""
        self._label(event, 'd')

    def no(self, event):
        """'Noise' button: label the current peak 'n'."""
        self._label(event, 'n')

    def undefined(self, event):
        """'?' button: label the current peak '?'."""
        self._label(event, '?')

    def start(self, event):
        """'Next file' button: move on to the next candidate file."""
        self.next(event)

    def next(self, event):
        """Load the next candidate file, locate its most probable clicks and redraw.

        When every candidate file has already been shown, print a message and
        keep the current file on screen.
        """
        global file
        global idx_peak
        global v_peaks
        global candidate_clicks
        global v_filt_sig
        global v_pos_frame

        try:
            _, next_file = next(iter_files)
        except StopIteration:
            print("No more candidate files")
            return

        file = next_file
        idx_peak = 0
        print("NEXT FILE")
        print(sess_path + '/' + file)
        print("File Loading ...")

        v_sig, fs = sf.read(sess_path + '/' + file)
        if fs != SR:
            print("WARNING: %s is sampled at %d Hz, expected %d Hz" % (file, fs, SR))
        v_filt_sig = sg.sosfiltfilt(sos, v_sig)
        # Scale so that 50 standard deviations map to an amplitude of 1.
        v_filt_sig = v_filt_sig / (np.std(v_filt_sig)*50)

        # NOTE(review): the detector call below is commented out, so the two
        # prediction files loaded next are NOT regenerated for `file`; they hold
        # whatever the last detector run wrote. Confirm they match this file.
        #os.system("python3 " + ROOT_DIR + "/home/pierre.mahe/src/transformer-carimam/forward_detector.py " + sess_path + '/' + file + " --weight /short/CARIMAM/detec_click/transformer_carimam/04-08-22_11\:40\:18\:hp\=False\:lr\=0.0005\:ne\=250\:wd\=0.05\:bs\=4/ckpt_1600_best.pth  --trans")

        m_pred = np.load(ROOT_DIR+"/home/pierre.mahe/src/transformer-carimam/out_file_pred.npy")
        v_pos_frame = np.load(ROOT_DIR+"/home/pierre.mahe/src/transformer-carimam/out_file_pos_frame.npy")

        # A frame is a candidate when its column-1 score exceeds 0.97.
        candidate_clicks = m_pred[:, 1] > 0.97
        # Count candidate frames in a sliding window of SR/256 frames; peaks of
        # that count mark the regions densest in candidates.
        win = int(SR/256)
        most_prob_mom = np.convolve(candidate_clicks, np.ones(win), mode="same").astype(int)
        peaks_data = sg.find_peaks(most_prob_mom, distance=win, prominence=5)
        # Order peaks by decreasing count so the most probable click comes first.
        v_peaks = peaks_data[0][np.argsort(most_prob_mom[peaks_data[0]])[::-1]]

        # Expand the per-frame mask to one value per sample
        # (assumes 256 samples per prediction frame — TODO confirm).
        candidate_clicks = np.repeat(candidate_clicks, 256)

        self.display()

    def next_click(self, event):
        """Advance to the next peak (wrapping around) and redraw."""
        global idx_peak
        if not self._has_peaks():
            return
        print("Click Loading ...")
        idx_peak = (idx_peak + 1) % len(v_peaks)
        self.display()

    def prev_click(self, event):
        """Step back to the previous peak (wrapping around) and redraw."""
        global idx_peak
        if not self._has_peaks():
            return
        idx_peak = (idx_peak - 1) % len(v_peaks)
        self.display()

    def display(self):
        """Redraw the waveform (with candidate mask) and spectrogram around the current peak."""
        ax[0].cla()
        ax[1].cla()
        # cla() also clears limits and axis labels, so restore them.
        ax[0].set_ylim([-1.1, 1.1])
        ax[0].set_xlabel("Time (seconds)")
        ax[0].set_ylabel("Amplitude")
        ax[1].set_xlabel("Time (seconds)")
        ax[1].set_ylabel("Frequency (Hz)")

        if not self._has_peaks():
            ax[0].set_title("%s\nNo click detected" % os.path.basename(file))
            plt.draw()
            print("Done")
            return

        ax[0].set_title("%s\nClicks : %02d / %02d" % (os.path.basename(file), idx_peak+1, len(v_peaks)))
        beg, end = self._window()
        segment = v_filt_sig[beg:end]
        ax[0].plot(np.arange(beg, end) / SR, segment)
        # The per-sample candidate mask comes from the frame predictions and may
        # be shorter than the signal, so give it a time axis of its own length.
        mask = candidate_clicks[beg:end]
        ax[0].plot(np.arange(beg, beg + len(mask)) / SR, mask*0.5 - 1)
        ax[1].specgram(segment, NFFT=2**10, noverlap=2**9, Fs=SR, window=np.blackman(2**10))
        # Show the band kept by the 10 kHz high-pass filter, up to Nyquist.
        ax[1].set_ylim(10000, SR / 2)
        plt.draw()
        print("Done")

    def audacity(self, event):
        """Open the current file in Audacity; returns once Audacity exits."""
        if not file:
            return
        try:
            # Argument list (no shell): paths with spaces or shell metacharacters are safe.
            subprocess.run(["audacity", sess_path + '/' + file], stderr=subprocess.DEVNULL)
        except FileNotFoundError:
            print("Audacity executable not found")

    def listen(self, event):
        """Play the window around the current peak at SR/5, i.e. slowed down 5 times."""
        if not self._has_peaks():
            return
        beg, end = self._window()
        print("Start playing...")
        sd.play(v_filt_sig[beg:end], int(SR/5))
        # sd.play() returns immediately; playback continues in the background.
        print("Stop playing")

def _make_button(axes, label, handler):
    """Create a Button labelled ``label`` on ``axes`` and route its clicks to ``handler``."""
    button = Button(axes, label)
    button.on_clicked(handler)
    return button


callback = ppp()

# Each Button is bound to a module-level name: matplotlib widgets must stay
# referenced to keep responding to clicks.
byes = _make_button(axyes, 'Click', callback.yes)
bno = _make_button(axno, 'Noise', callback.no)
bund = _make_button(axund, '?', callback.undefined)
bstart = _make_button(axstart, 'Next file', callback.start)
bprev_clk = _make_button(axprev_click, 'Prev click', callback.prev_click)
bnext_clk = _make_button(axnext_click, 'Next click', callback.next_click)
blisten = _make_button(axlisten, 'Listen', callback.listen)
baud = _make_button(axaud, 'Audacity', callback.audacity)

cid = fig.canvas.mpl_connect('pick_event', callback.onclick)

plt.show()

# Dump the collected annotations once plt.show() returns.
print(l_checked_file)
