import argparse
import soundfile as sf
import umap
import librosa as lr

import numpy as  np
import scipy as scp
from scipy import signal

import matplotlib.pyplot as plt
# When True, show diagnostic plots (filter frequency responses in
# coef_comb_filters, band-energy image in the __main__ script).
DEBUG = False
# Machine epsilon for Python float (float64); added before log10 so that
# zero-energy bands do not produce log10(0).
EPS = np.finfo(float).eps

def coef_comb_filters(freq_min, freq_max, sr, order=4, nb_filters=10):
    """Design a bank of contiguous Butterworth filters covering [freq_min, freq_max].

    The range is split into ``nb_filters`` equal-width bands whose edges are
    truncated to integer Hz. Every band gets a band-pass filter, except:

    * the first band becomes a low-pass when ``freq_min == 0``;
    * the last band becomes a high-pass when ``freq_max == int(sr / 2)``.

    This is because a Butterworth band-pass edge exactly at 0 Hz or at
    Nyquist is not a valid critical frequency.

    Parameters
    ----------
    freq_min, freq_max : float
        Lower and upper limits of the covered range, in Hz.
    sr : int
        Sampling rate in Hz.
    order : int
        Filter order passed to ``scipy.signal.butter``.
    nb_filters : int
        Number of bands; must be >= 1.

    Returns
    -------
    numpy.ndarray of object, length ``nb_filters``
        Element ``i`` is the second-order-sections array returned by
        ``signal.butter(..., output='sos')`` for band ``i``.

    Raises
    ------
    ValueError
        If ``nb_filters < 1``. Also raised if a single band would span from
        0 Hz to Nyquist, because that is an all-pass and cannot be designed
        as a Butterworth filter.
    """
    if nb_filters < 1:
        raise ValueError(f"nb_filters must be >= 1, got {nb_filters}")

    freq_cuts = np.linspace(freq_min, freq_max, nb_filters + 1).astype(int)
    starts_at_dc = freq_min == 0
    ends_at_nyquist = freq_max == int(sr / 2)

    coef_sos = np.empty(nb_filters, dtype=object)
    for idx_i in range(nb_filters):
        f_low, f_high = freq_cuts[idx_i], freq_cuts[idx_i + 1]
        is_first = idx_i == 0
        is_last = idx_i == nb_filters - 1
        use_lowpass = is_first and starts_at_dc
        use_highpass = is_last and ends_at_nyquist
        if use_lowpass and use_highpass:
            raise ValueError(
                "a single band from 0 Hz to Nyquist is an all-pass; use nb_filters >= 2"
            )
        # Decide per band, so that with a single band the first-band and
        # last-band rules cannot overwrite each other.
        if use_lowpass:
            coef_sos[idx_i] = signal.butter(order, f_high, 'lowpass', output='sos', fs=sr)
        elif use_highpass:
            coef_sos[idx_i] = signal.butter(order, f_low, 'highpass', output='sos', fs=sr)
        else:
            coef_sos[idx_i] = signal.butter(order, [f_low, f_high], 'bandpass', output='sos', fs=sr)

    if DEBUG:
        # Plot the magnitude response of every filter in the bank, in Hz.
        for sos in coef_sos:
            w, h = signal.sosfreqz(sos, fs=sr)
            plt.plot(w, np.abs(h))
        plt.xlabel("Frequency (Hz)")
        plt.ylabel("|H|")
        plt.show()
    return coef_sos


def nextpow(a):
    """Return the base-2 exponent ``ceil(log2(a))`` as a Python int.

    Like MATLAB's ``nextpow2``, this is the exponent, not the power itself:
    ``2 ** nextpow(a)`` is the smallest power of two >= ``a``. ``a`` is
    expected to be a positive scalar. For ``a <= 0`` the logarithm is -inf
    or NaN, and the final ``int()`` conversion raises.
    """
    exponent = np.ceil(np.log2(a))
    return int(exponent)

def energy_freq_band_old(tf_sig, freq_min, freq_max, sr, nfft, nb_filters=10):
    """Return the energy of one spectrum in ``nb_filters`` equal-width bands.

    ``[freq_min, freq_max]`` is mapped to FFT bin indices as
    ``int(f / sr * nfft)``. The resulting bin range is split into
    ``nb_filters`` contiguous slices; each slice's energy is
    ``sum(|tf_sig[k]|**2)``.

    Parameters
    ----------
    tf_sig : 1-D array-like (real or complex)
        Spectrum of a single frame, indexed by FFT bin.
    freq_min, freq_max : float
        Band limits in Hz.
    sr : int
        Sampling rate in Hz.
    nfft : int
        FFT length used to produce ``tf_sig``.
    nb_filters : int
        Number of bands.

    Returns
    -------
    numpy.ndarray, shape (nb_filters,)
        Per-band energy. A band whose bin slice is empty gets 0.
    """
    v_energy = np.zeros(nb_filters)
    bin_min = int(freq_min / sr * nfft)
    bin_max = int(freq_max / sr * nfft)
    bin_cut = np.linspace(bin_min, bin_max, nb_filters + 1).astype(int)
    # Visit all nb_filters bands; bin_cut has nb_filters + 1 edges, so the
    # last band is [bin_cut[-2], bin_cut[-1]).
    for idx_i in range(nb_filters):
        band_mag = np.abs(tf_sig[bin_cut[idx_i]:bin_cut[idx_i + 1]])
        v_energy[idx_i] = np.dot(band_mag, band_mag)
    return v_energy

def energy_freq_band(stft_sig, freq_min, freq_max, sr, nfft, nb_filters=10):
    """Return per-frame energies of an STFT in ``nb_filters`` equal-width bands.

    ``stft_sig`` is indexed as ``[frame, bin]``. The band limits in Hz are
    converted to bin indices with ``int(f / sr * nfft)``. That bin range is
    split into ``nb_filters`` contiguous slices, and the squared magnitude is
    summed over each slice for every frame.

    Returns
    -------
    numpy.ndarray, shape (stft_sig.shape[0], nb_filters)
        Column ``b`` holds the energy of band ``b`` for each frame. Empty
        slices yield 0.
    """
    first_bin = int(freq_min / sr * nfft)
    last_bin = int(freq_max / sr * nfft)
    edges = np.linspace(first_bin, last_bin, nb_filters + 1).astype(int)

    # Squared magnitude of every bin, computed once for all bands.
    power = np.power(np.abs(stft_sig), 2)

    energies = np.zeros((stft_sig.shape[0], nb_filters))
    for band, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        energies[:, band] = power[:, lo:hi].sum(axis=1)
    return energies

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Compute per-frame band energies of an audio file and project "
                    "them with UMAP, colouring frames that contain annotated clicks.")
    parser.add_argument("input", type=str, help="An audio filename")
    args = parser.parse_args()

    v_sig, sr = sf.read(args.input)
    if v_sig.ndim > 1:
        # Keep only the first channel of multi-channel recordings.
        v_sig = v_sig[:, 0]

    frame_duration = 1  # in milliseconds
    overlap_ratio = 1 / 2

    frame_len = int(frame_duration * sr / 1000)  # in samples
    hop_len = int(np.floor(frame_len * overlap_ratio))
    if hop_len < 1:
        parser.error(f"sampling rate {sr} Hz is too low for {frame_duration} ms frames")

    # The FFT length must be at least frame_len; otherwise scipy.fft crops
    # each frame to its first nfft samples. 2**8 is kept as a lower bound.
    nfft = 2 ** max(8, int(np.ceil(np.log2(frame_len))))

    print("STFT")
    frames = lr.util.frame(v_sig, frame_length=frame_len, hop_length=hop_len).T
    nb_frames = frames.shape[0]
    frames = np.multiply(frames, np.hanning(frame_len))
    # One-sided spectrum: bins 0 .. nfft/2 inclusive.
    stft_sig = scp.fft.fft(frames, nfft, axis=-1)[:, :int(nfft / 2) + 1]

    print("ENERGY BAND")
    nb_filters = 12
    band_freq_min = 10_000
    # Frequencies above Nyquist have no bins in the one-sided spectrum, so
    # clip the upper limit; otherwise the top bands would be silently empty.
    band_freq_max = min(120_000, sr / 2)
    if band_freq_max < 120_000:
        print(f"Upper band limit clipped to Nyquist ({band_freq_max:.0f} Hz)")
    band_energy = energy_freq_band(stft_sig, band_freq_min, band_freq_max, sr, nfft, nb_filters)
    if DEBUG:
        plt.imshow(np.flipud(np.log10(band_energy + EPS).T), aspect='auto', interpolation=None)
        plt.colorbar()
        plt.xlabel("Frames")
        plt.ylabel("Band Filters")
        plt.show()

    # Annotated click positions, in seconds.
    pos_clics_time = np.array([
        0.579512, 0.653879, 0.689477, 0.745972, 0.851037, 1.047330, 1.058283, 1.346670, 1.374053, 1.513274,
        1.853257, 1.977345, 2.028076, 2.059783, 2.197130, 2.409997, 2.462314, 2.526303, 2.713085, 3.501717, 3.683021, 3.963770, 4.299140,
        4.715939, 4.852278, 4.904018, 5.208546, 5.226129, 5.286516, 5.352091, 5.377456, 5.419684, 5.496501, 6.582311, 7.098842, 7.183153,
        10.899899, 11.011737, 11.404035, 13.103947, 15.527204, 17.760941, 18.287272, 18.319843, 20.089942, 23.241300, 23.400986, 24.106748,
        24.113954, 24.710761, 24.730650, 25.523029, 25.595090, 26.636943, 26.671532, 26.773426, 26.800953, 27.264159, 28.689808, 32.844827,
        34.360335, 35.001387, 35.091463, 35.487869, 36.262233, 37.320083, 37.689322, 37.706473, 37.724056, 40.767034, 43.232374, 44.343550,
        44.406819, 44.822177, 45.315937, 45.487874, 46.006710, 46.098083, 46.126043, 46.792748, 46.806872, 47.158240, 48.250104, 48.406475,
        48.589077, 50.981204, 52.077102, 52.124807, 52.262298, 52.514871, 55.237396, 56.703686, 58.334131,
    ])
    pos_clics_sample = pos_clics_time * sr
    # Frame index k such that k * hop_len <= click sample < (k + 1) * hop_len.
    pos_clics_frame = (pos_clics_sample / hop_len).astype(int)
    # Ignore annotations past the end of this file instead of raising IndexError.
    pos_clics_frame = pos_clics_frame[pos_clics_frame < nb_frames]

    frame_class = np.zeros(nb_frames)
    frame_class[pos_clics_frame] = 1

    # Optional: also append per-frame normalised band energies as features.
    #band_energy = np.hstack((band_energy, band_energy / np.sum(band_energy, axis=1)[:, np.newaxis]))

    print("UMAP start")
    reducer = umap.UMAP()
    embedding = reducer.fit_transform(band_energy)
    plt.scatter(embedding[:, 0], embedding[:, 1], c=frame_class, cmap='Spectral', s=5)
    plt.gca().set_aspect('equal', 'datalim')
    plt.colorbar()
    plt.title('UMAP projection of frame band energies', fontsize=24)
    plt.show()
