import numpy as np
import scipy.signal as sg
import soundfile as sf
import argparse
import os
import librosa as lr

import glob
from multiprocessing import Pool
import sys
import pandas as pd

# Small offset added to the band power before log10 so a silent band yields a
# finite value instead of -inf.
EPS = 1e-16
# FFT size passed to the STFT; also used to convert frequencies to bin indices.
NFFT = 2048
    
def energy_one_file(args):
    """Compute the per-band STFT energy of one audio file.

    Takes a single packed argument so it can be used directly as a
    ``multiprocessing.Pool.map`` worker.

    Parameters
    ----------
    args : tuple
        ``(filename, idx_f, ch, freq_min, freq_max, nb_band)``:

        - ``filename``: path of the audio file to read.
        - ``idx_f``: index of the file, only used for the progress display.
        - ``ch``: channel kept when the file has more than one channel.
        - ``freq_min``, ``freq_max``: bounds of the analysed frequency range,
          expressed in the same unit as the file's sample rate.
        - ``nb_band``: number of equal-width bands the range is split into.

    Returns
    -------
    tuple
        ``(filename, power_band)`` where ``power_band`` is a float array of
        length ``nb_band``. If the file cannot be read or processed, the
        array is filled with NaN so one bad file does not abort the batch.
    """
    filename, idx_f, ch, freq_min, freq_max, nb_band = args
    print("\r%d" % idx_f, end="")

    # nb_band + 1 evenly spaced edges give contiguous, equal-width bands even
    # when (freq_max - freq_min) is not a multiple of nb_band.
    edges = np.linspace(freq_min, freq_max, nb_band + 1)
    freq_beg, freq_end = edges[:-1], edges[1:]

    try:
        sig, sr = sf.read(filename)

        if sig.ndim > 1:
            sig = sig[:, ch]

        stft_sig = np.abs(lr.stft(sig, n_fft=NFFT))
        # Frequency -> STFT row index (truncated towards zero).
        bin_beg = ((freq_beg / sr) * NFFT).astype(int)
        bin_end = ((freq_end / sr) * NFFT).astype(int)

        power_band = np.zeros(nb_band)
        for idx_b in range(nb_band):
            band = stft_sig[bin_beg[idx_b]:bin_end[idx_b]]
            # NOTE(review): 20*log10 is applied to a mean of squared
            # magnitudes (a power), which is twice the usual 10*log10 dB
            # scale. Kept unchanged so values stay comparable with earlier
            # results -- confirm this is intended.
            power_band[idx_b] = 20 * np.log10(np.mean(np.power(band, 2)) + EPS)
    except Exception as err:
        # Best effort: report the failure and mark the file with NaN rather
        # than crashing the whole pool. KeyboardInterrupt/SystemExit are
        # deliberately not caught.
        print("\nCould not process %s: %s" % (filename, err), file=sys.stderr)
        return filename, np.full(nb_band, np.nan)

    return filename, power_band



def main(args):
    """Compute the band energy of every matching file and save it as CSV.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments. Reads ``input`` (passed to
        ``glob.glob``), ``output`` (inserted into the CSV file name), ``ch``,
        ``freq_min``, ``freq_max``, ``nb_band`` and ``nb_cpu``.

    Side effects
    ------------
    Writes ``results_noise_<output>.csv`` in the current working directory,
    with one row per file. Each ``band_energy`` cell holds the whole per-band
    array returned by the worker.
    """
    files = sorted(glob.glob(args.input))
    print("Total file : %d" % len(files))

    proc_args = [
        (f, idx, args.ch, args.freq_min, args.freq_max, args.nb_band)
        for idx, f in enumerate(files)
    ]

    # os.cpu_count() may return None when the count cannot be determined, and
    # Pool rejects a process count below 1, so clamp to at least one worker.
    nb_proc = max(1, min(args.nb_cpu, os.cpu_count() or 1))
    with Pool(processes=nb_proc) as pool:
        res = pool.map(energy_one_file, proc_args, chunksize=1)

    df = pd.DataFrame(res, columns=['file', 'band_energy'])
    df.to_csv("results_noise_%s.csv" % (args.output))



if __name__ == "__main__":
    # Command-line entry point: parse the options and run the batch analysis.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Compute the STFT energy per frequency band of audio '
                    'files and save the result as a CSV file.')
    # The value is handed to glob.glob, so it must be a pattern matching the
    # files themselves; a bare directory path would only match the directory.
    parser.add_argument('input', type=str,
                        help='Glob pattern of the input audio files, e.g. "dir/*.wav" '
                             '(quote it to avoid shell expansion)')
    parser.add_argument('output', type=str,
                        help='Suffix of the output file: results_noise_<output>.csv')
    parser.add_argument('--ch', type=int, default=0, help='Channel')
    parser.add_argument('--freq_min', type=float, default=20_000, help='Minimal frequency')
    parser.add_argument('--freq_max', type=float, default=100_000, help='Maximal frequency')
    parser.add_argument('--nb_band', type=int, default=8, help='Number of frequency bands')
    parser.add_argument('--nb_cpu', type=int, default=16, help='Number of CPU used during multiprocessing')

    sys.exit(main(parser.parse_args()))

