#
# Create master index files across plots in data directory
# Create box plots of audio indices in data/graphs directory
#

import os
from pathlib import Path
import soundfile as sf
import maad
import numpy as np 
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt 
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from scipy.io import wavfile
from scipy import signal
import json
from soundscape_IR.soundscape_viewer import lts_viewer, lts_maker, interactive_matrix, audio_visualization
from soundscape_IR.soundscape_viewer.utility import matrix_operation
from plotnine import ggplot, aes, geom_boxplot, theme_classic, scale_x_discrete, theme, element_text, scale_fill_manual, facet_wrap, ggsave
import argparse

# Default settings. sub_dir and ch can be overridden on the command line.

# Root folders holding the plots; SERVER_ROOT is chosen when the GPSVR
# environment variable equals "SERVER", LOCAL_ROOT otherwise.
LOCAL_ROOT = "/Users/patrickmclean/GPDev/bioacoustics/mks/plots/"
SERVER_ROOT = "D:/Bioacoustics/Projects/mks/plots/"

# Sub directory (below the root) to process.
# Set as blank to process everything in the root directory.
sub_dir = "plot1/20220926/"

# Channel to be processed, using numbering [1,2,3,4]
ch = 2

# Acoustic indices that get normalized and plotted
indices = [
    'BI',
    'ACI',
    'HFC',
    'Ht',
    'EVNtMean',
]

# Create box plots of audio indices
def ai_display(df, plot_types, index):
    """Build a box plot of one acoustic index per plot type, faceted by period.

    Rows are averaged per (hour, plot, channel, day). Each hourly mean is
    labelled with a period of the day (dawn 5-7h, day 8-16h, dusk 17-19h,
    night otherwise) and the requested index is min-max scaled across those
    hourly means before being plotted.

    Args:
        df: DataFrame with a 'Date' column of "%Y%m%d_%H%M%S" strings,
            'plot' and 'channel' columns, and a numeric column named by
            `index`. The frame is not modified.
        plot_types: Mapping from plot name to plot type. The x axis is
            limited to "control", "conservation" and "production".
        index: Name of the column to plot.

    Returns:
        The plotnine ggplot object (nothing is drawn or saved here).

    Raises:
        KeyError: If `index` is not a numeric column of `df`.
    """
    key_cols = ['hour', 'plot', 'channel', 'day']

    # Work on a copy so the caller's frame keeps its own columns (including
    # any pre-existing 'date' column, which is not an acoustic index).
    data = df.drop(columns=['date'], errors='ignore').copy()
    timestamps = pd.to_datetime(df['Date'], format="%Y%m%d_%H%M%S")
    data['day'] = timestamps.dt.day
    data['hour'] = timestamps.dt.hour

    # Calculate period means. numeric_only=True keeps text columns
    # (e.g. 'Date', 'type') out of the mean; without it pandas >= 2.0 raises.
    hourly = data.groupby(key_cols).mean(numeric_only=True).reset_index()

    index_cols = [c for c in hourly.columns if c not in key_cols]
    if index not in index_cols:
        raise KeyError(f"'{index}' is not a numeric index column of the data")

    # Label each hourly mean with its period of the day (bounds inclusive)
    hourly['period'] = 'night'
    for label, first_hour, last_hour in (("dawn", 5, 7), ("day", 8, 16), ("dusk", 17, 19)):
        in_period = (hourly["hour"] >= first_hour) & (hourly["hour"] <= last_hour)
        hourly.loc[in_period, 'period'] = label

    # Min-max normalization of the requested index only; the other index
    # columns are never plotted by this call.
    column = hourly[index]
    sub = hourly[['hour', 'plot', 'period', 'channel', 'day']].copy()
    sub['acoustic_index'] = index
    sub['value'] = (column - column.min()) / (column.max() - column.min())

    # add type column to sub based on plot_types
    sub['type'] = sub['plot'].map(plot_types)

    # make box plot
    color_dict = {'control': 'blue', 'conservation': 'green', 'production': 'orange'}
    return (
        ggplot(sub)
        + geom_boxplot(aes(x='type', y='value', fill='type'))
        + theme_classic()
        + scale_x_discrete(limits=["control", "conservation", "production"])
        + theme(axis_text=element_text(size=20))
        + theme(axis_title=element_text(size=20))
        + scale_fill_manual(values=color_dict)
        + facet_wrap('~period', ncol=2, nrow=2)
    )


####################
# Main #
####################

# Read command line arguments, otherwise we use the default values
env = os.getenv("GPSVR")
argParser = argparse.ArgumentParser()
argParser.add_argument("-d", "--dir", required=False, help="sub directory to process")
argParser.add_argument("-c", "--channel", required=False, type=int, help="channel to process")
args = vars(argParser.parse_args())
if args['dir'] is not None:
    sub_dir = args['dir']
if args['channel'] is not None:
    ch = args['channel']
root_dir = (SERVER_ROOT if env == "SERVER" else LOCAL_ROOT)
sub_root_dir = os.path.join(root_dir, sub_dir)
data_dir = os.path.join(root_dir, "data")

# Index of what plots are what type
with open(os.path.join(data_dir, "plot_types.json"), encoding="utf-8") as f:
    plot_types = json.load(f)

# Iterate over all "indices" directories below sub_root_dir and collect the
# per-recording and per-bin index files of the selected channel.
# Expected layout: <plot>/<date>/indices/<...>ch<N>.csv
index_frames = []
bin_frames = []
csv_suffix = f"ch{ch}.csv"
for root, _dirs, data_files in os.walk(sub_root_dir):
    if os.path.basename(root) != "indices":
        continue
    parent = Path(root).parent.absolute()
    cur_date = parent.name
    cur_plot = parent.parent.name
    print(f"Adding indices for {cur_plot}:{cur_date} {str(parent)}")

    # Sorted so the file picked is the same on every run and platform
    channel_files = sorted(f for f in data_files if f.endswith(csv_suffix))
    index_csv_files = [f for f in channel_files if "bin" not in f]
    index_bin_csv_files = [f for f in channel_files if "bin" in f]
    if not index_csv_files or not index_bin_csv_files:
        print("No index files found for this channel")
        continue

    plot_type = plot_types.get(cur_plot, "unknown")

    # Read both files by full path (no chdir needed) and tag each row with
    # where it came from.
    for csv_name, frames in ((index_csv_files[0], index_frames),
                             (index_bin_csv_files[0], bin_frames)):
        df = pd.read_csv(os.path.join(root, csv_name), sep=";")
        df["date"] = cur_date
        df["plot"] = cur_plot
        df["channel"] = ch
        df["type"] = plot_type
        frames.append(df)

if not index_frames:
    raise SystemExit(f"No index files found for channel {ch} under {sub_root_dir}")

# Perform the concatenations once; DataFrame.append was removed in pandas 2.0
# and copies the whole frame on every call.
master_df = pd.concat(index_frames, ignore_index=True)
master_bin_df = pd.concat(bin_frames, ignore_index=True)

# Normalize each column of the master index files to 1.5 x IQR
for i in indices:
    q1 = master_df[i].quantile(0.25)
    q3 = master_df[i].quantile(0.75)
    iqr = q3 - q1  # Interquartile range
    lower = q1 - 1.5 * iqr
    upper = q3 + 1.5 * iqr
    # Values outside the range are clipped to [0, 1] rather than having their
    # rows removed, so no rows are lost for the other indices.
    master_df['n_' + i] = ((master_df[i] - lower) / (upper - lower)).clip(0, 1)

# Output the master index files
print('Outputting master index files')
master_df.to_csv(os.path.join(data_dir, f"master_index_ch{ch}.csv"))
master_bin_df.to_csv(os.path.join(data_dir, f"master_bin_index_ch{ch}.csv"))

# One box plot per normalized index
graph_dir = os.path.join(data_dir, "graphs")
os.makedirs(graph_dir, exist_ok=True)
for i in indices:
    print(f"Plotting {i}")
    graph = ai_display(master_df, plot_types, index=('n_' + i))
    ggsave(graph, filename=f"{i}_plot_ch{ch}.png", path=graph_dir, width=20, height=20, dpi=300, verbose=False)

