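#
# Calculate the Long Term Spectrograms (LTS)
# Input is all the wav files in each /wav directory found under the root directory
# Output is a .mat file per plot/date in the /lts/mat/ch<N> directory (N = channel)
# Output is 6 png files in /lts: a labelled figure (_fig.png) and a raw image (_lts.png)
#   for each of 3 frequency bands (0-128 kHz, 0-20 kHz, 20-128 kHz)
# Output is 2 json files in /data listing the wav files and their start times
#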

import os
from pathlib import Path
import maad
import numpy as np 
import matplotlib as mpl
import matplotlib.pyplot as plt 
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from scipy.io import wavfile
from scipy import signal
import json
from soundscape_IR.soundscape_viewer import lts_viewer, lts_maker, interactive_matrix, audio_visualization
from soundscape_IR.soundscape_viewer.utility import matrix_operation
# Have matplotlib's Agg backend render long paths in chunks of 10000 vertices
# (rcParam `agg.path.chunksize`)
mpl.rcParams['agg.path.chunksize'] = 10000
import argparse

# Defaults (sub_dir and ch can be overridden with -d/--dir and -c/--channel)
LOCAL_ROOT = "/Users/patrickmclean/GPDev/bioacoustics/mks/plots/"   # Used unless env var GPSVR == "SERVER"
SERVER_ROOT = "D:/Bioacoustics/Projects/mks/plots/"                  # Used when env var GPSVR == "SERVER"
sub_dir = "plot1/20220926/"             # Set as blank to process everything in the root_dir
ch = 2                                  # Channel to be processed, using numbering [1,2,3,4]

# Create json of file times & get start and stop time
def get_start_and_stop_time(dir, cur_plot, cur_date):
    """Record the day's wav files as JSON and return the first/last start times.

    File names are expected to carry the recording time as ``HHMMSS`` at
    character positions 9-14. This presumably matches names like
    ``yyyymmdd_HHMMSS...``, the dateformat passed to ``filename_check`` in
    ``create_lts``; confirm against the recorder's naming.

    Two JSON files are written to ``<parent of dir>/data`` (created if needed):
    ``{cur_plot}_{cur_date}_files.json`` (sorted file names) and
    ``{cur_plot}_{cur_date}_file_times.json`` (their ``HHMMSS`` substrings).

    Args:
        dir: Directory holding the wav files. Hidden (dot) files are ignored.
        cur_plot: Plot name, used in the JSON file names.
        cur_date: Date string, used in the JSON file names.

    Returns:
        ``(start_time, stop_time)`` in decimal hours, taken from the start
        times of the first and last files (seconds and the last file's
        duration are not included), or ``(-1, -1)`` when no files are found.
    """
    print(f"Processing start and stop on {cur_plot}:{cur_date} ")
    # Resolve the output folder absolutely so the JSON files land next to the
    # wav directory regardless of the current working directory.
    data_dir = Path(dir).parent.absolute() / "data"
    data_dir.mkdir(parents=True, exist_ok=True)

    files = sorted(name for name in os.listdir(dir) if not name.startswith('.'))
    if not files:
        print("No files found")
        return (-1, -1)

    def _decimal_hours(name):
        # HH at [9:11], MM at [11:13]; seconds are ignored.
        return int(name[9:11]) + int(name[11:13]) / 60

    file_times = [name[9:15] for name in files]
    start_time = _decimal_hours(files[0])
    stop_time = _decimal_hours(files[-1])
    print(f"Day range (decimal hours): {start_time:.2f} to {stop_time:.2f}")
    with open(data_dir / f"{cur_plot}_{cur_date}_files.json", 'w') as outfile:
        json.dump(files, outfile)
    with open(data_dir / f"{cur_plot}_{cur_date}_file_times.json", 'w') as outfile:
        json.dump(file_times, outfile)
    return (start_time, stop_time)
    
## Production of a matrix containing the info to produce a LTS
def create_lts(dir, cur_plot, cur_date, channel):
    """Compute the LTS matrix for one wav directory and save it as a .mat file.

    The output is ``<parent of dir>/lts/mat/ch{channel}/{cur_plot}_{cur_date}_ch{channel}.mat``.
    Nothing is computed when the directory has no visible (non-dot) files,
    or when that .mat file already exists.

    Args:
        dir: Directory holding the wav files.
        cur_plot: Plot name, used in the output file name.
        cur_date: Date string, used in the output file name.
        channel: 1-based channel number; ``lts_maker`` receives ``channel - 1``.
    """
    wav_dir = Path(dir)
    # Ignore hidden files (e.g. .DS_Store) like get_start_and_stop_time does.
    if not any(not name.startswith('.') for name in os.listdir(wav_dir)):
        print(f"Not creating LTS - No files in {dir}")
        return

    # Absolute path so the result does not depend on the current working directory.
    mat_path = (wav_dir.parent.absolute() / "lts" / "mat" / f"ch{channel}"
                / f"{cur_plot}_{cur_date}_ch{channel}.mat")
    # Check before building the LTS maker so skipped days cost no processing.
    if mat_path.exists():
        print(f"LTS already exists for {cur_plot}:{cur_date}:{channel}, skipping")
        return
    mat_path.parent.mkdir(parents=True, exist_ok=True)

    LTS_run = lts_maker(sensitivity=0, environment='air', FFT_size=2048, channel=(channel - 1), initial_skip=0)
    LTS_run.collect_folder(path=dir)
    LTS_run.filename_check(dateformat='yyyymmdd_HHMMSS', year_initial=2000)
    LTS_run.run(save_filename=str(mat_path))

## Visualization of the LTS
def visualize_lts(f_low, f_high, dir, plot, date, channel, start_time=0, stop_time=24):
    """Render the 'diff' LTS of one frequency band and save it as two PNGs.

    Loads the LTS .mat data in ``{dir}/mat/ch{channel}/`` with ``lts_viewer``.
    It log-scales the data to 0-255 and zero-pads the time axis so the image
    spans a full 0-24 h day. It then writes:

    * ``{dir}/{plot}_{date}_{f_low}_{f_high}_ch{channel}_fig.png``: labelled
      figure with hour ticks, frequency ticks and a colour bar.
    * ``{dir}/{plot}_{date}_{f_low}_{f_high}_ch{channel}_lts.png``: the raw
      padded image.

    Args:
        f_low, f_high: Frequency band passed to ``input_selection`` (and used
            for the y tick labels).
        dir: The ``lts`` directory of one plot/date.
        plot, date, channel: Used in the input path and output file names.
        start_time, stop_time: Recorded span in decimal hours. When
            ``stop_time <= start_time`` (e.g. a single file) the span is
            unknown and no padding is added.
    """
    # Viz Parameters
    f_tickgap = 10000       # Frequency interval between Y axis ticks
    fig_size = (30, 30)     # Display size of figure

    # Read and transform mat data
    LTS = lts_viewer(path=f"{dir}/mat/ch{channel}/")
    print(f"lts: {dir}/mat/ch{channel}")
    input_data, f = LTS.input_selection('diff', f_range=[f_low, f_high], prewhiten_percent=0)  # 3 types of lts "median","mean", or "diff"
    # Drop column 0 (NOTE(review): presumably the time stamps - confirm), mirror
    # the remaining columns and transpose so rows run from the last frequency
    # bin to the first; assumes f is ascending, i.e. the top row is f_high.
    flipped_array = np.fliplr(input_data[:, 1:])
    with np.errstate(divide='ignore', invalid='ignore'):
        image_array = np.log(np.array(flipped_array.T))
    # clip() maps -inf (log of 0) to 0 but leaves NaN (log of a negative value)
    # untouched; one NaN would make max() NaN and blank the whole image.
    image_array = np.nan_to_num(image_array.clip(0, 100), nan=0.0)
    peak = image_array.max()
    if peak > 0:  # avoid 0/0 on an all-zero image
        image_array *= (255.0 / peak)
    image_height, image_width = image_array.shape

    # Pad the image with empty columns if partial day so it covers 0-24 h
    if stop_time > start_time:
        columns_per_hour = image_width / (stop_time - start_time)
        left_pad_len = max(0, int(start_time * columns_per_hour))
        right_pad_len = max(0, int((24 - stop_time) * columns_per_hour))
    else:
        # Recorded span unknown: no padding, and no division by zero.
        left_pad_len = right_pad_len = 0
    left_pad = np.zeros(shape=(image_height, left_pad_len))
    right_pad = np.zeros(shape=(image_height, right_pad_len))
    image_array = np.hstack((left_pad, image_array, right_pad))
    image_width = image_array.shape[1]
    if left_pad_len or right_pad_len:
        print(f"Partial day. Padding left {left_pad_len}, right {right_pad_len}")

    print(f"Visualizing array: w {image_width} x h {image_height}")

    # Set up plots
    fig, ax = plt.subplots(1, 1, figsize=fig_size)
    try:
        # aspect = width/height stretches the image to fill a square plot area
        im = ax.imshow(X=image_array, cmap='jet', extent=[0, image_width, 0, image_height],
                       aspect=image_width / image_height)

        # Y axis: one tick every f_tickgap, positioned in frequency-bin units.
        # Deriving locations from the label values guarantees equal counts.
        f_steps = (f_high - f_low) / (len(f) - 1)   # frequency per bin
        ytick_values = np.arange(f_low, f_high + 1, f_tickgap)
        ax.set_yticks((ytick_values - f_low) / f_steps)
        ax.set_yticklabels(list(map(str, ytick_values)))

        # X axis: hours 0..24 spread evenly over the padded day. 25 locations
        # for 25 labels (matplotlib rejects mismatched tick/label counts).
        ax.set_xticks(np.linspace(0, image_width, 25))
        ax.set_xticklabels([str(hour) for hour in range(25)])

        # Plot it & save it
        fig.colorbar(im, ax=ax)
        fig.savefig(f"{dir}/{plot}_{date}_{f_low}_{f_high}_ch{channel}_fig.png")
    finally:
        # Each call opens a large figure; close it so the processing loop
        # does not accumulate open figures.
        plt.close(fig)
    plt.imsave(f"{dir}/{plot}_{date}_{f_low}_{f_high}_ch{channel}_lts.png", image_array)

def visualize_lts_set(root, plot, date, channel, start_time, stop_time):
    """Render the LTS of one plot/date/channel for each standard frequency band.

    Calls ``visualize_lts`` once per band, in this order: the full band
    0-128000, the low band 0-20000, then the high band 20000-128000.
    ``root`` is passed through as the ``dir`` argument.
    """
    frequency_bands = [(0, 128000), (0, 20000), (20000, 128000)]
    for band_low, band_high in frequency_bands:
        visualize_lts(
            f_low=band_low,
            f_high=band_high,
            dir=root,
            plot=plot,
            date=date,
            channel=channel,
            start_time=start_time,
            stop_time=stop_time,
        )

####################
# Main #
####################

# Run only when executed as a script, so importing this module neither parses
# sys.argv nor starts walking the file system.
if __name__ == "__main__":
    # Read command line arguments, otherwise we use the default values
    env = os.getenv("GPSVR")
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-d", "--dir", required=False, help="sub directory to process")
    # Only ch1..ch4 output folders are created below, so restrict to 1-4.
    arg_parser.add_argument("-c", "--channel", required=False, type=int, choices=range(1, 5),
                            help="channel to process")
    args = vars(arg_parser.parse_args())
    if args['dir'] is not None:
        sub_dir = args['dir']
    if args['channel'] is not None:
        ch = args['channel']
    # Absolute, so os.walk keeps working after the os.chdir() calls below.
    root_dir = os.path.abspath((SERVER_ROOT if env == "SERVER" else LOCAL_ROOT) + sub_dir)

    # Iterate over all the plots and dates under a root directory and create lts.
    # Expected layout: <root_dir>/.../<plot>/<date>/wav/
    for root, _subdirs, _files in os.walk(root_dir):
        if os.path.basename(root) != "wav":
            continue
        # Work from inside the wav folder so any ../data and ../lts relative
        # paths used by the processing functions resolve next to it.
        os.chdir(root)
        parent = Path(root).parent.absolute()
        cur_date = os.path.basename(parent)
        cur_plot = os.path.basename(parent.parent.absolute())
        for channel_no in range(1, 5):
            os.makedirs(f"{parent}/lts/mat/ch{channel_no}", exist_ok=True)
        print(f"Processing {cur_plot}:{cur_date} {str(parent)}")
        (start_time, stop_time) = get_start_and_stop_time(root, cur_plot, cur_date)
        if start_time != -1:  # -1 indicates no files found
            create_lts(root, cur_plot, cur_date, ch)
            visualize_lts_set(f"{str(parent)}/lts", cur_plot, cur_date, ch, start_time, stop_time)