import os
from tqdm.contrib.concurrent import process_map
from multiprocessing import cpu_count
from resampy import resample
import soundfile as sf
from scipy import signal
import numpy as np
import pandas as pd
import argparse

# Machine epsilon; added to energy sums before they are used as divisors
EPS = np.finfo(float).eps

# Directory receiving one result pickle per processed session
RES_DIR = '/home/pierre.mahe/clicks/'
os.makedirs(RES_DIR, exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("session", type=str,
                    help="session identifier, matched against the 'session' column of the file list")
# The default must be a real int: argparse only applies `type=` to string
# defaults, so a float default (10.) would be handed to the worker pool as-is.
parser.add_argument('--number_cpu', type=int, default=10,
                    help="maximum number of worker processes")
args = parser.parse_args()

# Nothing to do when the result pickle of this session is already on disk
existing_pkl = RES_DIR + args.session.replace('/', '_') + '.pkl'
if os.path.isfile(existing_pkl):
    print("Session already exists")
    exit()

# Load the list of all audio files, then keep only the usable files of the
# requested session
df = pd.read_csv('/short/CARIMAM/DATA/files.csv')
corrupted = df.samplerate.isna()          # files without a sample rate
too_short = df.duration < 10.             # files shorter than 10
in_session = df.session == args.session
df = df[~corrupted & ~too_short & in_session]

# 3rd-order Butterworth high-pass, 5 kHz cutoff expressed relative to the
# Nyquist frequency of a 256 kHz sample rate
sos = signal.butter(N=3, Wn=5e3 / (256000 / 2), btype='highpass', output='sos')

# 256-point Hann window applied before each FFT
hann = np.hanning(256)


def moving_avg(x, s):
    """Return the centred moving average of ``x`` over ``s`` samples (same length as ``x``)."""
    kernel = np.ones(s)
    return np.convolve(x, kernel, 'same') / s

def _band_limits(spectrum, peak_bin, threshold):
    """Return the ``[low, high)`` bin range around ``peak_bin`` that stays at or above ``threshold``.

    Each side is walked away from the peak until the first bin below
    ``threshold``. A side that never drops below the threshold extends to
    that end of the spectrum, so the range always contains the peak bin.
    """
    below_left = np.flatnonzero(spectrum[peak_bin::-1] < threshold)
    low = peak_bin - below_left[0] + 1 if below_left.size else 0
    below_right = np.flatnonzero(spectrum[peak_bin:] < threshold)
    high = peak_bin + below_right[0] if below_right.size else len(spectrum)
    return low, high


def get(fn):
    """Detect clicks in one audio file and compute a feature dict for each.

    The file is read from '/short/CARIMAM/DATA/' + fn, resampled to 256 kHz
    when needed and high-pass filtered with ``sos``. Peaks are searched on the
    smoothed Hilbert envelope; peaks within 128 samples of either end of the
    signal are ignored.

    Parameters
    ----------
    fn : str
        File name relative to the data directory.

    Returns
    -------
    list of dict
        One dict per kept click, with keys 'fn', 'pos', 'width', 'peakpeak',
        'efpeak', 'fpeak', 'centroid', 'centroid3dB', 'centroid10dB',
        'flatness', 'duration10dB', 'duration20dB', 'rms_time', 'rms_fft' and
        'interval_prev_peak'. An empty list is returned when the file holds
        less than one second of samples, or less than 50 s once at 256 kHz.

    Notes
    -----
    * 'duration10dB' / 'duration20dB' are sample counts measured from the
      envelope maximum of the window to the first later sample below the
      threshold; they are 0 when the envelope never drops below it.
    * 'interval_prev_peak' of the first kept click is ``(pos + 1) / fs``
      because the previous-peak position starts at -1.
    * NOTE(review): the processing appears to assume a mono (1-D) signal -
      confirm for multi-channel files.
    """
    # load, resample, filter
    sig, fs = sf.read('/short/CARIMAM/DATA/'+fn)
    if len(sig) < fs:
        print("WARNNING, empty file : " + fn)
        return []
    if fs != 256000:
        sig = resample(sig, fs, 256000)
        fs = 256000
    if len(sig) < 50*fs:
        return []
    sig = signal.sosfiltfilt(sos, sig)

    # find peaks on the smoothed envelope; the noise spread is estimated from
    # the samples below the median level only
    ssig = moving_avg(np.abs(signal.hilbert(sig)), int(50e-6*fs))
    ssig_db = 20*np.log10(ssig)
    med = np.median(ssig_db)
    std = np.sqrt(np.square(ssig_db[ssig_db < med] - med).mean())
    peaks, data = signal.find_peaks(ssig, height=10**((med + 3*std + 3)/20),
                                    width=[8e-6*fs, 1200e-6*fs], distance=100e-6*fs)

    # Loop invariants: fs is fixed from here on
    freqs = np.fft.rfftfreq(256, 1/fs)       # frequency of each rfft bin
    t = np.arange(-128, 128) / fs            # window time axis, 0 at the peak
    click_half = int(10e-6*fs)               # half-length of the peak-to-peak window

    # Keep the indices (not only the positions) of the peaks away from the
    # borders, so that find_peaks properties stay aligned with each peak
    keep = np.flatnonzero((peaks > 128) & (peaks < len(sig) - 128))

    ret = []
    prev_p = -1
    for idx in keep:
        p = peaks[idx]
        click = sig[p - click_half: p + click_half]
        cur_sig = sig[p - 128: p + 128]      # 256 samples (1 ms) around the peak
        cur_ssig = ssig[p - 128: p + 128]

        # Spectral features
        fft = np.abs(np.fft.rfft(cur_sig * hann))
        fmax, emax = np.argmax(fft), np.max(fft)   # bin and magnitude of the spectral peak
        troidBlow, troidBHigh = _band_limits(fft, fmax, emax*0.708)   # -3 dB band
        dixdBlow, dixdBHigh = _band_limits(fft, fmax, emax*0.316)     # -10 dB band
        centroid = np.average(freqs, weights=fft)

        # Envelope decay after its maximum inside the window. The envelope is
        # an amplitude, so -10 dB is a factor sqrt(10) and -20 dB a factor 10.
        tmax, temax = np.argmax(cur_ssig), np.max(cur_ssig)
        lower_10dB_time = np.argmax(cur_ssig[tmax:] - temax/np.sqrt(10) < 0)
        lower_20dB_time = np.argmax(cur_ssig[tmax:] - temax/10 < 0)

        # RMS duration: second moment of the energy about the detected peak
        # (t is already zero at the peak, so no further offset is removed).
        # Same formula as "medium term acoustic monitoring of Patagonian coastal dolphins"
        ener_sig = np.sum(cur_sig**2) + EPS
        mo2_time = np.sum(t**2 * cur_sig**2)
        rms_time = np.sqrt(mo2_time / ener_sig)

        # RMS bandwidth about the spectral centroid
        ener_fft = np.sum(fft**2) + EPS
        mo2_fft = np.sum((freqs - centroid)**2 * fft**2)
        rms_fft = np.sqrt(mo2_fft / ener_fft)

        # Interval to the previously kept peak, in seconds
        raw_ici = (p - prev_p) / fs
        prev_p = p

        ret.append({'fn': fn,
                    'pos': p,
                    'width': data['widths'][idx]/fs,
                    'peakpeak': click.max() - click.min(),
                    'efpeak': emax,
                    'fpeak': freqs[fmax],
                    'centroid': centroid,
                    'centroid3dB': np.average(freqs[troidBlow:troidBHigh], weights=fft[troidBlow:troidBHigh]),
                    'centroid10dB': np.average(freqs[dixdBlow:dixdBHigh], weights=fft[dixdBlow:dixdBHigh]),
                    'flatness': np.exp(2*np.log(fft).mean())/np.square(fft).mean(),
                    'duration10dB': lower_10dB_time,
                    'duration20dB': lower_20dB_time,
                    'rms_time': rms_time,
                    'rms_fft': rms_fft,
                    'interval_prev_peak': raw_ici
                    })
    return ret

# Guarded so that worker processes re-importing this script (spawn/forkserver
# start methods) do not launch their own pool.
if __name__ == "__main__":
    # The executor needs a plain int worker count
    nb_cpu = int(min(cpu_count(), args.number_cpu))
    ret = process_map(get, df.fn, max_workers=nb_cpu, chunksize=1)

    # Flatten the per-file click lists; this also copes with a session that
    # has no file or no click at all.
    clicks = [click for file_clicks in ret for click in file_clicks]
    df = pd.DataFrame(clicks)
    df.to_pickle(RES_DIR+args.session.replace('/', '_')+'.pkl')
