import argparse
import soundfile as sf
import umap
import librosa as lr
import glob

import numpy as  np
import scipy as scp
from scipy import signal

import matplotlib.pyplot as plt
# Global debug flag for the script.
DEBUG = False
# Machine epsilon of the Python float type.
EPS = np.finfo(float).eps
# Length of audio analysed per file, in seconds (it is multiplied by the
# sample rate to obtain a number of samples).
ANALYSIS_DURATION = 10

def nextpow(a):
    """Return the base-2 logarithm of *a*, rounded up, as a Python int.

    For a positive *a* this is the smallest integer ``p`` such that
    ``2**p >= a``.
    """
    exponent = np.log2(a)
    rounded_up = np.ceil(exponent)
    return int(rounded_up)

def energy_freq_band(stft_sig, freq_min, freq_max, sr, nfft, nb_filters=10):
    """Sum the squared magnitude of a spectrum over contiguous frequency bands.

    The frequency range ``[freq_min, freq_max]`` is converted to FFT bin
    indices with ``int(freq / sr * nfft)``; that bin range is then cut into
    ``nb_filters`` bands whose edges are linearly spaced and truncated to
    integers.

    Parameters
    ----------
    stft_sig : 2-D array
        Spectrum values, one row per frame and one column per FFT bin.
    freq_min, freq_max : number
        Lower and upper limits of the analysed range, in the unit of ``sr``.
    sr : number
        Sample rate used to map frequencies to bin indices.
    nfft : int
        FFT size used to map frequencies to bin indices.
    nb_filters : int, optional
        Number of bands (default 10).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(stft_sig.shape[0], nb_filters)`` holding, for each
        row, the energy found in each band.
    """
    v_energy = np.zeros((stft_sig.shape[0], nb_filters))

    first_bin = int(freq_min / sr * nfft)
    last_bin = int(freq_max / sr * nfft)
    edges = np.linspace(first_bin, last_bin, nb_filters + 1).astype(int)

    # Consecutive edges delimit one band: [edges[k], edges[k + 1]).
    for band, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        magnitude = np.abs(stft_sig[:, lo:hi])
        v_energy[:, band] = np.sum(np.power(magnitude, 2), axis=1)

    return v_energy


if __name__ == '__main__':
    # Script entry point: pick a few long-enough recordings in every
    # sub-folder of the data directory, compute per-frame band energies for
    # each of them, then project all frames with UMAP and plot the result
    # coloured by source file.
    #
    # Every option defaults to the value that used to be hard-coded, so
    # running the script without arguments targets the same data set.
    parser = argparse.ArgumentParser(
        description="UMAP projection of the per-frame band energies of audio files.")
    parser.add_argument("--data-dir", default="/short/CARIMAM/DATA/LOT2",
                        help="directory whose sub-folders contain the .WAV files")
    parser.add_argument("--files-per-folder", type=int, default=10,
                        help="number of valid files to take from each sub-folder")
    args = parser.parse_args()

    # SHARED PARAMS
    sr = 256_000  # sample rate (Hz) that every analysed file must have
    nb_filters = 12

    frame_duration = 1  # in milliseconds
    overlap_ratio = 1/2  # hop size expressed as a fraction of the frame length
    nfft = 2**8

    frame_len = int(frame_duration * sr / 1000)  # in samples
    hop_len = int(np.floor(frame_len * overlap_ratio))
    nb_samples = ANALYSIS_DURATION * sr
    # Number of complete frames that fit in nb_samples samples. This is what
    # librosa.util.frame returns for a signal of that length, so every file
    # fills its slot of t_band_energy exactly.
    nb_frames = 1 + (nb_samples - frame_len) // hop_len
    window = np.hanning(frame_len)

    # GET AUDIO FILES
    nb_files_peer_folder = args.files_per_folder

    l_filenames = list()

    # Sorted so that the selected files do not depend on directory order.
    l_folders = sorted(glob.glob(args.data_dir + "/*"))

    if DEBUG:
        import ipdb
        ipdb.set_trace()

    for folder in l_folders:
        print(folder)
        count_f = 0
        for path in sorted(glob.glob(folder + "/*.WAV")):
            if count_f >= nb_files_peer_folder:
                break
            # Read only the header: no need to decode the audio here.
            info = sf.info(path)
            if info.samplerate != sr:
                # Frame length, hop and frequency bins are all derived from
                # sr, so a file at another rate cannot be analysed with them.
                print("WARNING: UNEXPECTED SAMPLE RATE, FILE SKIPPED : " + path)
                continue
            if info.duration >= ANALYSIS_DURATION:
                l_filenames.append(path)
                count_f += 1
        if count_f < nb_files_peer_folder:
            print("ERROR: TOO FRE VALIDE AUDIO FILE IN : " + folder)

    if not l_filenames:
        raise SystemExit("ERROR: NO VALID AUDIO FILE FOUND IN : " + args.data_dir)

    # One label per frame: the index of the file the frame comes from.
    v_labels = np.repeat(np.arange(len(l_filenames), dtype=float), nb_frames)
    t_band_energy = np.zeros((len(l_filenames), nb_frames, nb_filters))

    for idx_j, file in enumerate(l_filenames):
        print(file)
        # Only the analysed part of the file is read. The returned sample
        # rate is ignored because it was already checked against sr above.
        v_sig, _ = sf.read(file, frames=nb_samples)
        if v_sig.ndim > 1:
            # Multi-channel file: keep the first channel only.
            v_sig = v_sig[:, 0]
        if len(v_sig) < nb_samples:
            # Exit with a non-zero status and the message on stderr.
            raise SystemExit("ERROR: FILE TOO SHORT :" + file)

        # (nb_frames, frame_len) array of overlapping frames.
        frames = lr.util.frame(v_sig, frame_length=frame_len, hop_length=hop_len).T
        frames = frames * window
        # Real-input FFT: returns the nfft/2 + 1 non-negative frequency bins.
        stft_sig = scp.fft.rfft(frames, nfft, axis=-1)
        t_band_energy[idx_j] = energy_freq_band(stft_sig, 10_000, 120_000, sr, nfft, nb_filters)

    if DEBUG:
        import ipdb
        ipdb.set_trace()

    print("Umpa start")
    reducer = umap.UMAP()
    # UMAP expects a 2-D (samples, features) matrix: stack the frames of all
    # files, in the same file-major order as v_labels.
    band_energy = t_band_energy.reshape(-1, nb_filters)
    embedding = reducer.fit_transform(band_energy)
    plt.scatter(embedding[:, 0], embedding[:, 1], c=v_labels, cmap='Spectral', s=5)
    plt.gca().set_aspect('equal', 'datalim')
    plt.colorbar()
    plt.title('UMAP projection of the band energies', fontsize=24)
    plt.show()

    if DEBUG:
        import ipdb
        ipdb.set_trace()
