import os
import librosa
import glob
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from p_tqdm import p_map
import soundfile as sf
import scipy.signal as signal
from tqdm import tqdm

import warnings
# Ignore every warning category for the whole process, not just one library's warnings.
# NOTE(review): this also hides numerical warnings (e.g. divide-by-zero in log10) — consider narrowing.
warnings.filterwarnings('ignore')

def arg_directory(path):
    """argparse ``type=`` callback accepting only existing directories.

    Args:
        path: command-line value to validate.

    Returns:
        *path*, unchanged, when it names an existing directory.

    Raises:
        argparse.ArgumentTypeError: if *path* is not a directory, so argparse
            reports it as a usage error.
    """
    if not os.path.isdir(path):
        raise argparse.ArgumentTypeError(f'`{path}` is not a valid path')
    return path

def create_spectrogram(y, directory, filename, offset, duration, window_arg, hop_length_arg):
    """Render the log10-magnitude STFT of *y* and save it as a JPEG.

    The image is written to ``<directory>/Spectrogram/<stem>_<offset>.jpg``, where
    ``<stem>`` is *filename* with '/' replaced by '_' and its extension removed.

    Args:
        y: 1-D audio signal for one segment.
        directory: output root; its ``Spectrogram`` sub-folder is created if missing.
        filename: path of the source recording, used to build the image name.
        offset: segment start (seconds), appended to the image name.
        duration: not used here; kept for call compatibility.
        window_arg: FFT size, also used as the Hann window length.
        hop_length_arg: number of samples between successive STFT frames.
    """
    window = np.hanning(window_arg)
    stft = librosa.core.spectrum.stft(y, n_fft=window_arg, hop_length=hop_length_arg, window=window)

    # Zero-magnitude bins (e.g. digital silence) give log10(0) = -inf, which made the
    # mean -inf and broke the colour scale. Derive the limits from finite bins only and
    # paint the non-finite bins at the bottom of the scale.
    with np.errstate(divide='ignore'):
        log_stft = np.log10(np.abs(stft))
    finite = np.isfinite(log_stft)
    if finite.any():
        vmin, vmax = log_stft[finite].mean(), log_stft[finite].max()
    else:
        vmin, vmax = 0.0, 1.0
    log_stft = np.where(finite, log_stft, vmin)

    out_dir = os.path.join(directory, 'Spectrogram')
    os.makedirs(out_dir, exist_ok=True)
    # splitext strips only the real extension; split('.')[0] cut at the first dot anywhere
    # in the path ('./data/a.wav' -> '', 'rec.2020.wav' -> 'rec'), colliding image names.
    stem = os.path.splitext(filename.replace('/', '_'))[0]
    name = os.path.join(out_dir, f"{stem}_{offset}")

    fig = plt.figure()
    try:
        # Rows are reversed so the first STFT row (lowest frequency) ends up at the bottom.
        plt.imshow(log_stft[::-1], aspect="auto", interpolation=None, cmap='jet', vmin=vmin, vmax=vmax)
        plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
        fig.savefig(name + '.jpg')
    finally:
        # Always release the figure so long runs do not accumulate open figures.
        plt.close(fig)

def process_recordings(data, img_per_rec, args):
    """Cut one recording into overlapping segments and save a spectrogram per segment.

    Segment ``k`` starts at ``k * (args.duration - args.overlap)`` seconds and lasts
    ``args.duration`` seconds. At most *img_per_rec* segments are attempted.

    Args:
        data: ``(index, item)`` pair where ``item[0]`` is the recording path. This fits
            both ``[num, [path]]`` (sequential loop) and the ``(i, (path, group))``
            items produced by ``enumerate(df.groupby('Path'))`` (parallel path).
        img_per_rec: maximum number of segments to extract.
        args: parsed CLI namespace; reads duration, overlap, sr, up, low, directory,
            window and hop. It is not modified.

    Errors are printed and the affected recording or segment is skipped.
    """
    _, item = data
    filename = str(item[0])
    duration = args.duration
    overlap = args.overlap

    try:
        info = sf.info(filename)
    except Exception as error:
        print(f'`{filename}` cannot be open... : {error}')
        # Without metadata there is nothing to do; continuing would hit an
        # unbound `file_duration` and crash the whole run.
        return
    file_duration, fs = info.duration, info.samplerate

    # Resolve the target rate per recording instead of writing it into `args`:
    # mutating the shared namespace made the first file's rate apply to all later files.
    target_sr = args.sr if args.sr else fs

    for count in range(img_per_rec):
        offset = count * (duration - overlap)
        # A segment starting at (or beyond) the end of the file has no samples.
        if offset >= file_duration:
            continue
        try:
            sig, fs = sf.read(filename, start=int(offset * fs), stop=int((offset + duration) * fs), always_2d=True)
            sig = sig[:, 0]  # first channel only
            sig = signal_processing(sig, target_sr, fs, args.up, args.low)
            create_spectrogram(sig, args.directory, filename, offset, duration, args.window, args.hop)
        except Exception as error:
            print(f'`{filename}` cannot be processed at offset {offset}s... : {error}')

def signal_processing(sig, sr, fs, up, low):
    """Resample *sig* from *fs* to *sr* Hz, then optionally high- and low-pass filter it.

    Args:
        sig: 1-D signal sampled at *fs* Hz.
        sr: target sampling rate in Hz; filter cutoffs are normalised against it.
        fs: original sampling rate of *sig* in Hz.
        up: high-pass cutoff in Hz (2nd-order Butterworth); falsy disables it.
        low: low-pass cutoff in Hz (1st-order Butterworth); falsy disables it.

    Returns:
        The processed signal at *sr* Hz. When no resampling or filtering is needed,
        *sig* itself is returned.
    """
    if sr != fs:
        # FFT-based resampling. It is skipped when the rates already match, where it
        # would only be an fft/ifft round trip adding floating-point noise.
        sig = signal.resample(sig, int(len(sig) * sr / fs))
    if up:
        sos_hp = signal.butter(2, up / (sr / 2), 'hp', output='sos')
        sig = signal.sosfilt(sos_hp, sig)
    if low:
        sos_lp = signal.butter(1, low / (sr / 2), 'lp', output='sos')
        sig = signal.sosfilt(sos_lp, sig)
    return sig


if __name__ == "__main__":
    # CLI entry point: collect recording paths (from a CSV/pickle table or a folder
    # listing) and write spectrogram images to <directory>/Spectrogram.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter, description='Extract spectrogram for each .wav file')
    parser.add_argument('path_to_data', type=arg_directory, help='Path of the folder that contains the recordings')
    parser.add_argument('directory', type=arg_directory, help='Directory to which spectrograms will be stored')
    parser.add_argument('-m', '--mode', type=str, choices=['unique', 'multiple'], help='if unique, only 1 image per file', default='multiple')
    parser.add_argument('-n', '--columns_name', type=str, help='Name of the columns that contain the path of the .wav', default='Path')
    parser.add_argument('-f', '--file', type=str, help='Name of the file that contains the recording to print', default='None')
    parser.add_argument('--frames', type=int, help='Number of spectrogram per file', default=30)
    parser.add_argument('--duration', type=int, help='Duration for each spectrogram', default=8)
    parser.add_argument('--overlap', type=int, help='Overlap between 2 spectrograms', default=2)
    parser.add_argument('--sr', type=int, help='Sampling rate for the spectrogram. If no argument, '
                                            'SR will be original SR of the recording', default=None)
    parser.add_argument('--window', type=int, help='Window size for the Fourier Transform', default=1024)
    parser.add_argument('--hop', type=int, help='Hop length for the Fourier Transform', default=512)
    parser.add_argument('--cpu', type=int, help='To speed up the process, write 2 or more', default=1)
    parser.add_argument('--up', type=int, help='High Pass Filter value in Hz', default=10)
    parser.add_argument('--low', type=int, help='Low Pass Filter value in Hz', default=None)
    args = parser.parse_args()

    # --mode is restricted by `choices`, so it is either 'multiple' or 'unique'.
    img_per_rec = args.frames if args.mode == 'multiple' else 1

    path_to_data = args.path_to_data

    # The literal string 'None' is the "no table given" sentinel (see the --file default).
    if args.file != 'None':
        try:
            df = pd.read_csv(args.file, low_memory=False)
        except Exception:
            print('Try to load as pickle...')
            # read_pickle has no `low_memory` parameter; passing it raised TypeError,
            # so this fallback could never succeed.
            df = pd.read_pickle(args.file)
        df['Path'] = df[args.columns_name]
    else:
        # NOTE(review): the pattern has no '**', so recursive=True has no effect and only
        # the top level of path_to_data is listed (sub-folders included) — confirm intent.
        df = pd.DataFrame(glob.glob(os.path.join(path_to_data, '*'), recursive=True), columns=['Path'])

    final_dest = os.path.join(args.directory, 'Spectrogram')
    if args.cpu == 1:
        for num, row in tqdm(df.iterrows(), total=len(df)):
            process_recordings([num, [row.Path]], img_per_rec, args)
    else:
        # One task per unique path. p_map zips its iterables, so the shared arguments
        # are repeated once per group.
        groups = df.groupby('Path')
        n_groups = len(groups)
        p_map(process_recordings, enumerate(groups), [img_per_rec] * n_groups, [args] * n_groups,
              num_cpus=args.cpu, total=n_groups)
    print(f'Saved to {final_dest}')