import argparse
import soundfile as sf
import umap
import librosa as lr
import glob
import os

import numpy as  np
import scipy as scp
from scipy import signal

import matplotlib.pyplot as plt
# Debug flag (off by default).
DEBUG = False
# Machine epsilon of the Python float type, as reported by numpy.
EPS = np.finfo(float).eps
# Length of audio analysed per file. Used as seconds: it is multiplied by the
# sample rate to get a sample count and compared with soundfile's duration.
ANALYSIS_DURATION = 50

def nextpow(a):
    """Return the base-2 logarithm of ``a`` rounded up to an integer.

    The result is the exponent ``p`` of the next power of two, i.e. the
    smallest integer such that ``2**p >= a`` (up to floating-point rounding
    of the logarithm).
    """
    exponent = np.log2(a)
    rounded_up = np.ceil(exponent)
    return int(rounded_up)

def energy_freq_band(stft_sig, freq_min, freq_max, sr, nfft, nb_filters=10):
    """Sum the spectral energy of each row over adjacent frequency bands.

    The range ``[freq_min, freq_max]`` is converted to FFT bin indices
    (``freq / sr * nfft``, truncated to int) and split into ``nb_filters``
    bands with linearly spaced integer edges. Each band is half-open: it
    covers the bins from its lower edge up to, but excluding, its upper edge.

    Parameters
    ----------
    stft_sig : 2-D array indexed as ``[row, bin]``
        Spectrum values; presumably one row per frame. May be complex, only
        the magnitude is used.
    freq_min, freq_max : lower and upper frequency limits, same unit as ``sr``.
    sr : sample rate.
    nfft : FFT size used to map frequencies to bin indices.
    nb_filters : number of bands.

    Returns
    -------
    Array of shape ``(stft_sig.shape[0], nb_filters)`` holding, for every row,
    the sum of squared magnitudes inside each band.
    """
    band_energy = np.zeros((stft_sig.shape[0], nb_filters))

    first_bin = int(freq_min / sr * nfft)
    last_bin = int(freq_max / sr * nfft)
    edges = np.linspace(first_bin, last_bin, nb_filters + 1).astype(int)

    # Walk over consecutive (lower, upper) edge pairs, one pair per band.
    for band, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        magnitude = np.abs(stft_sig[:, lower:upper])
        band_energy[:, band] = np.sum(np.power(magnitude, 2), axis=1)

    return band_energy


if __name__ == '__main__':
    # Script outline:
    #   1. pick up to `nb_files_peer_folder` usable recordings per folder,
    #   2. turn the first ANALYSIS_DURATION seconds of each recording into a
    #      (frames x bands) energy map,
    #   3. flatten each map into one feature vector per file and embed the
    #      files in 2-D with UMAP,
    #   4. save the embedding/labels and plot them coloured by file and folder.
    parser = argparse.ArgumentParser(description="TODO")
    args = parser.parse_args()

    # SHARED PARAMS
    sr = 256_000  # sample rate (Hz) every analysed file must have
    nb_filters = 10

    frame_duration = 1  # in milliseconds
    overlap_ratio = 1/2
    nfft = 2**8

    frame_len = int(frame_duration * sr / 1000)  # in samples
    hop_len = int(np.floor(frame_len * overlap_ratio))
    analysis_len = int(ANALYSIS_DURATION * sr)  # in samples
    # Number of complete frames that fit in the analysis window. This matches
    # what librosa.util.frame returns, so a file that is exactly
    # ANALYSIS_DURATION long still fills its slot in t_band_energy.
    nb_frames = 1 + (analysis_len - frame_len) // hop_len

    # GET AUDIO FILES
    nb_files_peer_folder = 100

    l_filenames = []
    l_folder_idx = []  # for each kept file, index of its folder in l_folders

    # Sorted so that folder/file labels are reproducible between runs.
    l_folders = sorted(glob.glob("/short/CARIMAM/DATA/LOT2/*"))

    for idx_i, folder in enumerate(l_folders):
        print(folder)
        count_f = 0
        for candidate in sorted(glob.glob(folder + "/*.WAV")):
            if count_f >= nb_files_peer_folder:
                break
            try:
                info = sf.info(candidate)
            except (RuntimeError, OSError) as err:
                # Unreadable/corrupt file: skip it, but say so.
                print("WARNING: skipping unreadable file " + candidate + " (" + str(err) + ")")
                continue
            if info.samplerate != sr:
                # All framing parameters above assume `sr`.
                print("WARNING: skipping file with unexpected sample rate : " + candidate)
                continue
            if info.duration >= ANALYSIS_DURATION:
                l_filenames.append(candidate)
                l_folder_idx.append(idx_i)
                count_f += 1
        if count_f < nb_files_peer_folder:
            print("ERROR: TOO FRE VALIDE AUDIO FILE IN : " + folder)

    if not l_filenames:
        raise SystemExit("ERROR: no valid audio file found")

    t_band_energy = np.zeros((len(l_filenames), nb_frames, nb_filters))
    window = np.hanning(frame_len)  # same window for every file

    print("Process STFT and band filters")
    for idx_j, file in enumerate(l_filenames):
        print(file)
        # Only the analysis window is needed, so do not load the whole file.
        v_sig, file_sr = sf.read(file, frames=analysis_len)
        if file_sr != sr:
            raise SystemExit("ERROR: UNEXPECTED SAMPLE RATE :" + file)
        if v_sig.ndim > 1:
            # Keep the first channel only; copy it so the framed view is built
            # on contiguous memory (older librosa versions require this).
            v_sig = np.ascontiguousarray(v_sig[:, 0])
        if len(v_sig) < analysis_len:
            raise SystemExit("ERROR: FILE TOO SHORT :" + file)

        # (frame_len, nb_frames) -> (nb_frames, frame_len)
        frames = lr.util.frame(v_sig, frame_length=frame_len, hop_length=hop_len).T
        frames = frames[:nb_frames] * window
        # Keep the non-negative frequency half of the spectrum.
        stft_sig = scp.fft.fft(frames, nfft, axis=-1)[:, :nfft // 2 + 1]
        t_band_energy[idx_j] = energy_freq_band(stft_sig, 10_000, 120_000, sr, nfft, nb_filters)

    # One feature vector per file: all frames and bands flattened together.
    band_energy = t_band_energy.reshape([len(l_filenames), nb_frames * nb_filters])
    v_labels = np.arange(len(l_filenames))
    # Folder label recorded at selection time, so it stays correct even when a
    # folder contributed fewer than nb_files_peer_folder files.
    v_labels_spot = np.asarray(l_folder_idx, dtype=int)

    print("Umpa start")
    reducer = umap.UMAP()
    embedding = reducer.fit_transform(band_energy)

    np.save("embedding_interfiles.npy", embedding)
    np.save("labels_interfiles.npy", v_labels)
    np.save("labels_session_interfiles.npy", v_labels_spot)

    each_N_points = 1  # plot every N-th point
    plt.figure()
    plt.scatter(embedding[::each_N_points, 0], embedding[::each_N_points, 1], c=v_labels[::each_N_points], cmap='Spectral', s=2)
    plt.gca().set_aspect('equal', 'datalim')
    nb_file_labels = int(np.max(v_labels)) + 1
    plt.colorbar(boundaries=np.arange(nb_file_labels + 1) - 0.5).set_ticks(np.arange(nb_file_labels))
    plt.title('UMAP projection - file labels')

    plt.figure()
    plt.scatter(embedding[::each_N_points, 0], embedding[::each_N_points, 1], c=v_labels_spot[::each_N_points], cmap='Spectral', s=2)
    ax = plt.gca()
    ax.set_aspect('equal', 'datalim')
    # One tick per folder index up to the highest one in use, so the number
    # of tick labels always matches the number of ticks.
    nb_spots = int(np.max(v_labels_spot)) + 1
    spot_names = [os.path.basename(fold) for fold in l_folders[:nb_spots]]
    plt.colorbar(boundaries=np.arange(nb_spots + 1) - 0.5, ticks=np.arange(nb_spots)).set_ticklabels(spot_names)
    plt.title('UMAP projection - spot labels')

    if DEBUG:
        # Interactive inspection of the results, only on request.
        import ipdb
        ipdb.set_trace()
    else:
        plt.show()
