import matplotlib.pyplot as plt
import numpy as np
import umap
import pandas as pd
import os

# Multiplier that converts the annotation times below into sample positions
# (presumably the recording sampling rate in Hz -- confirm against the WAV files).
sr = 256000

# Click-detection table, and the subset of rows whose "fn" column matches the
# single annotated recording.
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
df = pd.read_pickle('/home/pierre.mahe/clicks/LOT2_JAM_20210406_20320510.pkl')
anot_file_df = df.loc[df["fn"].eq("LOT2/JAM_20210406_20320510/21310412_103746UTC_V00OS11.WAV")]

# Annotated click times for that recording (presumably seconds from the start
# of the file -- confirm with the annotation source).
anot_cliks_time = np.array([
    0.579512, 0.653879, 0.689477, 0.745972, 0.851037, 1.047330, 1.058283, 1.346670,
    1.374053, 1.513274, 1.853257, 1.977345, 2.028076, 2.059783, 2.197130, 2.409997,
    2.462314, 2.526303, 2.713085, 3.501717, 3.683021, 3.963770, 4.299140, 4.715939,
    4.852278, 4.904018, 5.208546, 5.226129, 5.286516, 5.352091, 5.377456, 5.419684,
    5.496501, 6.582311, 7.098842, 7.183153, 10.899899, 11.011737, 11.404035, 13.103947,
    15.527204, 17.760941, 18.287272, 18.319843, 20.089942, 23.241300, 23.400986, 24.106748,
    24.113954, 24.710761, 24.730650, 25.523029, 25.595090, 26.636943, 26.671532, 26.773426,
    26.800953, 27.264159, 28.689808, 32.844827, 34.360335, 35.001387, 35.091463, 35.487869,
    36.262233, 37.320083, 37.689322, 37.706473, 37.724056, 40.767034, 43.232374, 44.343550,
    44.406819, 44.822177, 45.315937, 45.487874, 46.006710, 46.098083, 46.126043, 46.792748,
    46.806872, 47.158240, 48.250104, 48.406475, 48.589077, 50.981204, 52.077102, 52.124807,
    52.262298, 52.514871, 55.237396, 56.703686, 58.334131,
])

# Scale the times by `sr`; astype(int) truncates the fractional sample.
anot_click_pos = np.multiply(anot_cliks_time, sr).astype(int)

# Columns extracted from the detection table.  Column 0 ('pos') is only used
# to match detections against the annotated positions; the code further down
# slices it off with [:, 1:] before handing the rest to UMAP.
FEATURE_COLUMNS = [
    'pos', 'peakpeak', 'efpeak', 'fpeak', 'centroid', 'centroid3dB',
    'centroid10dB', 'flatness', 'duration10dB', 'duration20dB',
    'rms_time', 'rms_fft', 'interval_prev_peak',
]

# Same column selection for the annotated file and for the whole table, so the
# two matrices are guaranteed to share one column layout.
click_data = anot_file_df[FEATURE_COLUMNS].values
all_click_data = df[FEATURE_COLUMNS].values

# Without any detection for the annotated file there is nothing to match
# against; fail with an explicit message instead of NumPy's generic
# "argmin of an empty sequence" error (same exception type).
if click_data.shape[0] == 0:
    raise ValueError(
        "No detections found for the annotated file; cannot match annotations"
    )

# For each annotated position, pick the detection whose 'pos' is closest.
# Broadcasting builds an (n_annotations, n_detections) distance matrix;
# argmin along axis 1 returns the first minimum, as the per-row loop did.
nearest_idx = np.abs(
    click_data[:, 0][np.newaxis, :] - anot_click_pos[:, np.newaxis]
).argmin(axis=1)

# One detection row per annotation, stored as float64 like the original
# np.zeros-backed buffer.
anot_click_data = np.asarray(click_data[nearest_idx], dtype=float)


print("Umpa start")
reducer = umap.UMAP()

# Fit the projection on every 1000th row of the full table (column 0, 'pos',
# is dropped), then project the annotated file's detections and the
# annotation-matched detections with the same fitted reducer.
# Alternative kept from the original: fit on click_data[:, 1:] instead to
# learn the projection from the annotated file only.
embedding = reducer.fit_transform(all_click_data[::1000, 1:])
file_embedding = reducer.transform(click_data[:, 1:])
anot_embedding = reducer.transform(anot_click_data[:, 1:])

# Stack the three groups row-wise.  (Alternative kept from the original:
# stack only `embedding` and `anot_embedding`.)
all_embedding = np.concatenate(
    (embedding, file_embedding, anot_embedding), axis=0
)

# One float label per stacked row: 0 = subsample of the whole table,
# 1 = annotated file, 2 = annotation-matched detections.
v_labels = np.repeat(
    [0.0, 1.0, 2.0],
    [len(embedding), len(file_embedding), len(anot_embedding)],
)

# Colorbar tick text, one entry per label value in v_labels (0, 1, 2).
# For a two-group plot (file + annotations only) use instead:
#   ['File : 21310412_103746UTC_V00OS11.WAV', 'Click anotation']
ticklabels = ['All files : LOT2/JAM_20210406_20320510', 'File : 21310412_103746UTC_V00OS11.WAV', 'Click anotation']

# Number of distinct groups; labels are consecutive integers starting at 0.
n_groups = int(np.max(v_labels)) + 1

fig, ax = plt.subplots()
points = ax.scatter(all_embedding[:, 0], all_embedding[:, 1], cmap='Paired', c=v_labels, s=2)
ax.set_aspect('equal', adjustable='datalim')

# Discrete colorbar: boundaries sit half-way between label values so every
# label gets its own band, with a tick centred on each band.  The scatter is
# passed explicitly rather than relying on pyplot's "current image".
cbar = fig.colorbar(
    points,
    ax=ax,
    boundaries=np.arange(n_groups + 1) - 0.5,
    ticks=np.arange(n_groups),
)
cbar.set_ticklabels(ticklabels)
ax.set_title('UMAP projection Clicks')

# Display the figure.  This replaces a leftover `ipdb.set_trace()` breakpoint
# that halted the script in a debugger without ever showing the plot.
plt.show()
