import argparse
import soundfile as sf
import umap
import librosa as lr
import glob
import os
import multiprocessing
import pickle

import numpy as  np
import scipy as scp
from scipy import signal

import matplotlib.pyplot as plt
# When True, the script saves the embedding and label arrays to .npy files
# and builds the diagnostic scatter plots.
DEBUG = False
# Machine epsilon of the Python float type.
EPS = np.finfo(float).eps
# Length of audio analysed per file, in seconds; files shorter than this are
# not selected.
ANALYSIS_DURATION = 10

def nextpow(a):
    """Return the exponent of the smallest power of two that is >= ``a``."""
    log_two = np.log2(a)
    exponent = np.ceil(log_two)
    return int(exponent)

def energy_freq_band(stft_sig, freq_min, freq_max, sr, nfft, nb_filters=10):
    """Sum the squared spectrum magnitude into contiguous frequency bands.

    Args:
        stft_sig: 2-D array indexed as ``[frame, frequency bin]``.
        freq_min: Lower edge of the analysed range, in the same unit as ``sr``.
        freq_max: Upper edge of the analysed range, in the same unit as ``sr``.
        sr: Sample rate used to convert frequencies into bin indices.
        nfft: FFT size used to convert frequencies into bin indices.
        nb_filters: Number of bands the range is divided into.

    Returns:
        Float array of shape ``(stft_sig.shape[0], nb_filters)`` holding the
        energy of every band for every frame.
    """
    # Convert the range limits to bin indices (truncated towards zero), then
    # spread nb_filters + 1 integer band edges evenly between them.
    first_bin = int(freq_min / sr * nfft)
    last_bin = int(freq_max / sr * nfft)
    edges = np.linspace(first_bin, last_bin, nb_filters + 1).astype(int)

    band_energy = np.zeros((stft_sig.shape[0], nb_filters))
    for band in range(nb_filters):
        lo, hi = edges[band], edges[band + 1]
        magnitudes = np.abs(stft_sig[:, lo:hi])
        band_energy[:, band] = np.power(magnitudes, 2).sum(axis=1)
    return band_energy

def compute_one_file_starter(data):
    """Unpack one packed argument sequence and forward it to compute_one_file.

    ``data`` is indexed as ``[file, sr, nfft, nb_filters, [start_frame,
    end_frame]]``, which lets the function be used with single-argument
    mappers such as ``Pool.map``.
    """
    path, rate, fft_size, n_bands, frame_range = (data[i] for i in range(5))
    first_frame, last_frame = frame_range
    return compute_one_file(path, rate, fft_size, n_bands, [first_frame, last_frame])

def compute_one_file(file, sr, nfft, nb_filters, frame_idx=(0, -1)):
    """Compute the per-frame band energies of one audio file.

    The file is read with soundfile, reduced to its first channel, cut into
    overlapping frames, multiplied by a ``np.hanning`` window, transformed
    with an FFT and summed into ``nb_filters`` energy bands between 10_000
    and 120_000 (same unit as the file's sample rate).

    NOTE: the framing parameters ``frame_len`` and ``hop_len`` are read from
    module-level globals, so they must be defined before this function is
    called (the ``__main__`` section of the script sets them).

    Args:
        file: Path of the audio file to analyse.
        sr: Kept for interface compatibility only; the sample rate returned
            by ``sf.read`` replaces it.
        nfft: FFT size; only the first ``nfft // 2 + 1`` bins are kept.
        nb_filters: Number of energy bands to compute.
        frame_idx: ``(start, stop)`` pair applied as a Python slice on the
            frame axis, so the default ``(0, -1)`` keeps every frame except
            the last one.

    Returns:
        Array of shape ``(number of selected frames, nb_filters)``.
    """
    print(file)
    # The sample rate stored in the file supersedes the ``sr`` argument so the
    # frequency-to-bin conversion matches the actual recording.
    v_sig, sr = sf.read(file)
    if v_sig.ndim > 1:
        # Multi-channel recording: only the first channel is analysed.
        v_sig = v_sig[:, 0]

    # One row per frame after the transpose.
    frames = lr.util.frame(v_sig, frame_length=frame_len, hop_length=hop_len).T
    frames = frames[frame_idx[0]:frame_idx[1]]
    frames = frames * np.hanning(frame_len)
    stft_sig = scp.fft.fft(frames, nfft, axis=-1)[:, :nfft // 2 + 1]
    return energy_freq_band(stft_sig, 10_000, 120_000, sr, nfft, nb_filters)


if __name__ == '__main__':
    # Script entry point: select audio files, compute their band energies in
    # parallel, fit a UMAP embedding on them and pickle the fitted reducer.
    parser = argparse.ArgumentParser(description="TODO")
    args = parser.parse_args()

    # SHARED PARAMS
    sr = 256_000  # sample rate used to size the frames, in samples/second
    nb_filters = 10

    frame_duration = 1  # in milliseconds
    overlap_ratio = 1/2
    nfft = 2**8

    frame_len = int(frame_duration * sr / 1000)  # in samples
    hop_len = int(np.floor(frame_len * overlap_ratio))
    # Number of COMPLETE frames that fit in ANALYSIS_DURATION seconds. A
    # framing of N samples yields 1 + (N - frame_len) // hop_len frames, so
    # every file lasting at least ANALYSIS_DURATION provides this many frames
    # and all per-file feature arrays end up with the same shape.
    nb_frames = 1 + (ANALYSIS_DURATION * sr - frame_len) // hop_len

    # GET AUDIO FILES
    nb_files_peer_folder = 100

    l_filenames = []
    # Index (in l_folders) of the folder each selected file comes from. It is
    # recorded per file so labels stay right when a folder provides fewer
    # than nb_files_peer_folder files.
    l_spot_labels = []

    l_folders = glob.glob("/short/CARIMAM/DATA/LOT2/*")
    nb_folder = len(l_folders)

    for idx_i, folder in enumerate(l_folders):
        print(folder)
        count_f = 0
        for candidate in glob.glob(folder + "/*.WAV"):
            if count_f >= nb_files_peer_folder:
                break
            try:
                long_enough = sf.info(candidate).duration >= ANALYSIS_DURATION
            except Exception as err:
                # Best effort: an unreadable file is reported and skipped.
                print("WARNING: skipping unreadable file " + candidate + " (" + str(err) + ")")
                continue
            if long_enough:
                l_filenames.append(candidate)
                l_spot_labels.append(idx_i)
                count_f += 1
        if count_f < nb_files_peer_folder:
            print("ERROR: TOO FEW VALIDE AUDIO FILE IN : " + folder)

    if not l_filenames:
        raise SystemExit("No valid audio file found, nothing to process")

    print("Process STFT and band filters")
    # TODO: add a loop to process the frames slice by slice
    start_frame = 0
    end_frame = nb_frames
    args_pool = [[file, sr, nfft, nb_filters, [start_frame, end_frame]] for file in l_filenames]

    nb_cpu = 6  # multiprocessing.cpu_count()
    # The context manager releases the worker processes once map() returns.
    with multiprocessing.Pool(processes=nb_cpu) as pool:
        l_band_energy = pool.map(compute_one_file_starter, args_pool)
    t_band_energy = np.array(l_band_energy)
    del l_band_energy

    # One sample per file: all frames and bands are flattened into a single
    # feature vector.
    band_energy = t_band_energy.reshape([len(l_filenames), nb_frames * nb_filters])
    v_labels = np.arange(len(l_filenames))
    v_labels_spot = np.array(l_spot_labels, dtype=int)

    print("Umpa start")
    reducer = umap.UMAP()
    embedding = reducer.fit_transform(band_energy)

    if DEBUG:
        np.save("embedding_interfiles.npy", embedding)
        np.save("labels_interfiles.npy", v_labels)
        np.save("labels_session_interfiles.npy", v_labels_spot)

        each_N_points = 1
        plt.figure()
        plt.scatter(embedding[::each_N_points, 0], embedding[::each_N_points, 1], c=v_labels[::each_N_points], cmap='Spectral', s=2)
        plt.gca().set_aspect('equal', 'datalim')
        plt.colorbar(boundaries=np.arange(np.max(v_labels)+2)-0.5).set_ticks(np.arange(np.max(v_labels)+1))
        plt.title('UMAP projection - file labels')

        plt.figure()
        # vmin/vmax pin the colour scale to the folder indices so that the
        # colorbar ticks always line up with the folder names, even when some
        # folder contributed no file.
        plt.scatter(embedding[::each_N_points, 0], embedding[::each_N_points, 1], c=v_labels_spot[::each_N_points], cmap='Spectral', s=2, vmin=0, vmax=nb_folder - 1)
        ax = plt.gca()
        ax.set_aspect('equal', 'datalim')
        cbar = plt.colorbar(boundaries=np.arange(nb_folder + 1) - 0.5, ticks=np.arange(nb_folder))
        cbar.set_ticklabels([os.path.basename(fold) for fold in l_folders])
        plt.title('UMAP projection - spot labels')

        #plt.show()

    # Persist the fitted reducer; the file is closed as soon as it is written.
    with open("save_umap.pkl", "wb") as f_out:
        pickle.dump(reducer, f_out)
