import numpy as np
import scipy.signal as sg
import soundfile as sf
import argparse
import os

import glob
from multiprocessing import Pool
import sys
import pandas as pd
import gc


def clickness_one_file(args):
    """Return the number of click-like envelope peaks per minute in one file.

    Takes a single tuple so it can be used directly with ``Pool.map``.

    Parameters
    ----------
    args : tuple
        ``(filename, idx, ch, hp, dist)``:
        ``filename`` -- audio file read with ``soundfile``;
        ``idx`` -- position of the file, only printed as progress;
        ``ch`` -- channel kept when the file has several channels;
        ``hp`` -- cut-off frequency (Hz) of the 6th-order Butterworth high-pass;
        ``dist`` -- minimal distance between two detected peaks, in ms.

    Returns
    -------
    tuple
        ``(filename, clickness)`` where ``clickness`` is the number of
        detected peaks per minute, or NaN when the file cannot be read or
        processed (the reason is reported on stderr).
    """
    filename, idx, ch, hp, dist = args
    print("\r%d" % idx, end="")
    try:
        sig, sr = sf.read(filename)
        if sig.ndim > 1:
            sig = sig[:, ch]
        if len(sig) == 0:
            # Nothing to analyse, and the per-minute rate would divide by 0.
            raise ValueError("empty signal")

        # Zero-phase (forward-backward) high-pass filtering, then the
        # amplitude envelope via the analytic signal.
        sos = sg.butter(6, hp, fs=sr, btype='hp', output='sos')
        sig = sg.sosfiltfilt(sos, sig)
        amp = np.abs(sg.hilbert(sig))
    except Exception as err:
        # Best effort over a large batch: one bad file must not stop the run.
        # np.nan (not np.NAN, which was removed in NumPy 2.0).
        print("\nSkipping %s: %s" % (filename, err), file=sys.stderr)
        return filename, np.nan

    # Noise-level estimate. The p-quantile of a Rayleigh(sigma) variable is
    # sigma * sqrt(-2 ln(1 - p)); inverting it on the 15-40 % percentiles of
    # the envelope and averaging gives sigma. This assumes those low
    # percentiles are dominated by background noise rather than clicks.
    probs = np.linspace(15, 40, 1024)
    sigma = np.mean(np.percentile(amp, probs)
                    / np.sqrt(-2 * np.log(1 - probs / 100)))

    # Detection threshold: Rayleigh 0.99 quantile raised by 5 dB
    # (amplitude factor 10 ** (5 / 20)).
    threshold = sigma * np.sqrt(-2 * np.log(1 - 0.99)) * 10 ** (5 / 20)

    # find_peaks requires distance >= 1 sample; a tiny dist * sr would give 0.
    min_distance = max(1, int(dist / 1000 * sr))
    peaks = sg.find_peaks(amp, height=threshold, distance=min_distance)[0]

    duration_min = len(sig) / sr / 60
    return filename, len(peaks) / duration_min

def main(args):
    """Compute the clickness of every file matching ``args.input`` and save it.

    Files are processed in batches of 250, each batch with a fresh worker
    pool (presumably to bound the memory held by long-lived workers — TODO
    confirm). Results are written to ``results_clickness_<args.output>.csv``
    in the current directory, sorted by file name.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``input`` (glob pattern), ``output`` (suffix of the CSV
        name), ``ch``, ``hp``, ``dist`` (forwarded to ``clickness_one_file``)
        and ``nb_cpu`` (maximum number of worker processes).
    """
    files = sorted(glob.glob(args.input))
    print("Total file : %d" % len(files))

    proc_args = [(f, idx, args.ch, args.hp, args.dist)
                 for idx, f in enumerate(files)]

    # os.cpu_count() may return None, and Pool needs at least one process.
    nb_proc = max(1, min(args.nb_cpu, os.cpu_count() or 1))

    batch = 250
    results = []
    # Upper bound is len(proc_args): going further only produced empty batches.
    for start in range(0, len(proc_args), batch):
        sub_proc_args = proc_args[start:start + batch]
        # The context manager releases the pool's workers and resources.
        with Pool(processes=nb_proc) as pool:
            results.extend(pool.map(clickness_one_file, sub_proc_args,
                                    chunksize=1))
        gc.collect()
    print()  # terminate the "\r<idx>" progress line

    # Build the frame once instead of concatenating onto an empty DataFrame.
    df = pd.DataFrame(results, columns=['file', 'clickness'])
    df.sort_values("file", inplace=True)
    df.reset_index(drop=True, inplace=True)
    df.to_csv("results_clickness_%s.csv" % (args.output))

if __name__ == "__main__":
    # Command-line entry point: parse options and run main().
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Compute the clickness (detected clicks per minute) of '
                    'audio files and save it to results_clickness_<output>.csv')
    parser.add_argument('input', type=str,
                        help='Glob pattern of the audio files to process '
                             '(quote it so the shell does not expand it)')
    parser.add_argument('output', type=str,
                        help='Suffix of the output file results_clickness_<output>.csv')
    parser.add_argument('--hp', type=float, default=1e4,
                        help='High-pass filter cut-off frequency in Hz')
    parser.add_argument('--ch', type=int, default=0,
                        help='Channel index used for multi-channel files')
    parser.add_argument('--dist', type=float, default=2,
                        help='Minimum distance between two clicks in milliseconds')
    parser.add_argument('--nb_cpu', type=int, default=8,
                        help='Maximum number of worker processes')

    sys.exit(main(parser.parse_args()))
