import matplotlib.pyplot as plt
import numpy as np
import umap
import pandas as pd
import glob
import os

# Load every pickled click session and stack them into one feature matrix.
# glob returns paths in arbitrary order; sorting makes the session index stored
# in `v_labels` (position of the file in `l_files`) reproducible between runs.
# NOTE(review): pd.read_pickle unpickles arbitrary objects -- only point this
# at trusted files.
l_files = sorted(glob.glob('/home/pierre.mahe/clicks/*.pkl'))

# Columns pulled from each session DataFrame, in the order they appear in the
# stacked matrix ('pos' is column 0).
click_columns = [
    'pos', 'peakpeak', 'efpeak', 'fpeak', 'centroid', 'centroid3dB',
    'centroid10dB', 'flatness', 'duration10dB', 'duration20dB', 'rms_time',
    'rms_fft', 'interval_prev_peak',
]

# Seed both lists with empty arrays so the results keep their shapes
# ((0,) labels and (0, 13) data) when no session file is found.
label_chunks = [np.zeros(0)]
data_chunks = [np.empty((0, len(click_columns)))]

for idx_f, file_session in enumerate(l_files):
    print(file_session)
    df = pd.read_pickle(file_session)
    session_click_data = df[click_columns].values
    # One label per click row: the index of the file the row came from.
    label_chunks.append(np.full(session_click_data.shape[0], float(idx_f)))
    data_chunks.append(session_click_data)

# Concatenate once instead of re-copying the growing arrays on every iteration.
v_labels = np.hstack(label_chunks)
all_session_click_data = np.vstack(data_chunks)

# Subsampling stride: only every `each_points`-th click row is embedded.
each_points = 10000

# Column 0 ('pos') is excluded from the features given to UMAP.
subsampled_features = all_session_click_data[::each_points, 1:]

# The interactive ipdb breakpoint that used to sit here was removed: it halted
# every run and required the ipdb package to be installed.
print("Umpa start")
reducer = umap.UMAP()
embedding = reducer.fit_transform(subsampled_features)

# Scatter plot of the 2-D embedding, one colour per session file.
# The number of sessions comes from the file list rather than from
# np.max(v_labels): a trailing file with zero rows would otherwise produce
# fewer ticks than tick labels.
n_sessions = len(l_files)
session_names = [os.path.basename(path) for path in l_files]

plt.figure()
plt.scatter(
    embedding[:, 0],
    embedding[:, 1],
    c=v_labels[::each_points],
    cmap='Paired',
    s=2,
    # Pin the colour range to the full set of session indices so the point
    # colours match the colorbar even if the subsample misses a session.
    vmin=0,
    vmax=max(n_sessions - 1, 0),
)
plt.gca().set_aspect('equal', 'datalim')

# Boundaries at k - 0.5 centre each integer session index inside its colour band.
colorbar = plt.colorbar(
    boundaries=np.arange(n_sessions + 1) - 0.5,
    ticks=np.arange(n_sessions),
)
colorbar.set_ticklabels(session_names)

plt.title('UMAP projection Clicks')

# Display the figure (replaces the trailing ipdb breakpoint that was used to
# keep the session open).
plt.show()
