
from datetime import datetime
import librosa as lib
from librosa import display
import maad
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from maad import sound, features, util
import numpy as np
import os
import pandas as pd
from plotnine import *
import scipy.signal
from scipy.io import wavfile
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import soundfile as sf


def acoustic_indices(
    wav_dir='/nfs/NAS4/SABIOD/SITE/greenpraxis/indonesia/plot2/20220930/wav/',
    out_dir='/nfs/NAS4/SABIOD/SITE/greenpraxis/indonesia/results/acoustic_indices/',
    site='plot2',
    n_channels=4,
):
    """Compute maad acoustic indices for every listed recording, one CSV per channel.

    The recordings are listed in ``<wav_dir>/filename.txt`` (one file name per
    line). The first 15 characters of each name must encode the recording
    start as ``%Y%m%d_%H%M%S``.

    For each channel ``c`` (1-based), two ';'-separated CSV files are written:

    * ``<out_dir>/indices/AI_<site>_<YYYYMMDD>_<c>.csv``: one row per
      recording. It holds ``Date`` followed by the columns returned by
      ``features.all_temporal_alpha_indices`` and
      ``features.all_spectral_alpha_indices``.
    * ``<out_dir>/indices_per_bin/AI_bin_<site>_<YYYYMMDD>_<c>.csv``: ``Date``
      followed by the per-bin frame returned by
      ``features.all_spectral_alpha_indices``.

    ``YYYYMMDD`` is the first 8 characters of the last listed file name.
    Each CSV's first (unnamed) column is the row index 0..N-1.

    A recording is skipped, with a warning, in two cases:

    * it cannot be read, or does not have ``n_channels`` channels;
    * computing its indices raises ``ZeroDivisionError`` on a channel (only
      that channel is skipped).

    Every ``Date`` always stays on the same row as its own indices.

    The defaults reproduce the original hard-coded plot2 / 2022-09-30 run.
    """
    import warnings

    listing_path = os.path.join(wav_dir, 'filename.txt')
    # data upload (recommended to first test on a small subset, e.g. 5 recordings)
    listing = pd.read_csv(listing_path, sep=";", names=['file'])
    if listing.empty:
        warnings.warn(f'no recordings listed in {listing_path}')
        return
    # recording start date and time, encoded in the file name
    listing['Date'] = pd.to_datetime(listing['file'].str.slice(start=0, stop=15),
                                     format='%Y%m%d_%H%M%S')
    day_tag = listing['file'].iloc[-1][0:8]

    # Per channel: one single-row frame per successfully processed recording.
    # Each frame carries its own Date, so skipped recordings cannot shift
    # dates onto another recording's indices.
    indices_rows = [[] for _ in range(n_channels)]
    per_bin_rows = [[] for _ in range(n_channels)]

    for fullfilename, date in zip(listing['file'], listing['Date']):
        # Load the original sound (24 bits) once and reuse it for every channel.
        try:
            wave, fs = sf.read(os.path.join(wav_dir, fullfilename))
        except Exception as err:  # missing file, truncated WAV (no EOF), ...
            warnings.warn(f'skipping {fullfilename}: {err}')
            continue
        # soundfile returns a 1-D array for mono files, (frames, channels) otherwise
        if wave.ndim != 2 or wave.shape[1] < n_channels:
            warnings.warn(f'skipping {fullfilename}: expected {n_channels} channels, '
                          f'got data of shape {wave.shape}')
            continue

        for i in range(n_channels):
            channel = wave[:, i]
            # compute all the audio indices for this channel
            try:
                df_audio_ind = features.all_temporal_alpha_indices(
                    channel, fs, verbose=False, display=False)
                Sxx_power, tn, fn, _ = sound.spectrogram(
                    channel, fs, window='hanning', nperseg=2048, noverlap=2048 // 2,
                    verbose=False, display=False, savefig=None)
                df_spec_ind, df_spec_ind_per_bin = features.all_spectral_alpha_indices(
                    Sxx_power, tn, fn,
                    flim_low=[0, 1500],
                    flim_mid=[1500, 10000],
                    flim_hi=[10000, 128000],
                    verbose=False,
                    R_compatible='soundecology',
                    display=False)
            except ZeroDivisionError:
                warnings.warn(f'skipping channel {i + 1} of {fullfilename}: '
                              'ZeroDivisionError while computing indices')
                continue
            row = pd.concat([df_audio_ind, df_spec_ind], axis=1)
            row.insert(0, 'Date', date)
            indices_rows[i].append(row)
            df_spec_ind_per_bin.insert(0, 'Date', date)
            per_bin_rows[i].append(df_spec_ind_per_bin)

    date_format = '%Y-%m-%d %H:%M:%S'
    for i in range(n_channels):
        suffix = f'{day_tag}_{i + 1}.csv'
        # ignore_index gives the 0..N-1 row index written as the first column
        indices = (pd.concat(indices_rows[i], ignore_index=True)
                   if indices_rows[i] else pd.DataFrame(columns=['Date']))
        per_bin = (pd.concat(per_bin_rows[i], ignore_index=True)
                   if per_bin_rows[i] else pd.DataFrame(columns=['Date']))
        indices.to_csv(os.path.join(out_dir, 'indices', f'AI_{site}_{suffix}'),
                       sep=";", date_format=date_format)
        per_bin.to_csv(os.path.join(out_dir, 'indices_per_bin', f'AI_bin_{site}_{suffix}'),
                       sep=";", date_format=date_format)

def ai_display(folder='/nfs/NAS4/SABIOD/SITE/greenpraxis/indonesia/results/acoustic_indices/indices/'):
    """Aggregate the per-channel acoustic-index CSVs, normalise them and plot BI per site type.

    Every ``*.csv`` in ``folder`` is read with ``sep=';'``. Each file must
    contain an ``Unnamed: 0`` row-index column and a ``Date`` column. All
    other columns are treated as acoustic indices. File names are expected to
    look like ``AI_<location>_2022<MMDD>_<channel>.csv``. ``location`` is the
    text between ``AI_`` and ``_2022``, and ``channel`` is the last character
    before ``.csv``, kept as a string.

    Steps:

    1. Stack all files into ``database`` and derive ``date``/``day``/``time``/
       ``hour`` from ``Date``.
    2. Average every index per (hour, location, channel, day) into
       ``hourlymeans`` and label each row with a ``period``: dawn 5-7 h,
       day 8-15 h, dusk 16-18 h, night otherwise.
    3. Min-max normalise each index over ``hourlymeans`` and melt the result
       to long format (``piv_hour``).
    4. Keep BI and ACI and drop the non-functional channels (control_plot2
       channel 1, plot2 channel 4). Map locations to site types (``sub3``).
    5. Print a boxplot of normalised BI per site type, faceted by period.

    Returns:
        (database, hourlymeans, piv_hour, sub3)

    Raises:
        FileNotFoundError: if ``folder`` contains no ``.csv`` file.
    """
    filelist = sorted(file for file in os.listdir(folder) if file.endswith('.csv'))
    if not filelist:
        raise FileNotFoundError(f'no .csv files in {folder}')

    frames = []
    for file in filelist:
        df = pd.read_csv(os.path.join(folder, file), sep=';')
        df = df.drop(columns='Unnamed: 0')  # row index written by to_csv
        df['location'] = file.split('AI_')[1].split('_2022')[0]
        df['channel'] = file.split('.csv')[0][-1]
        frames.append(df)
    # One concat instead of re-concatenating inside the loop (quadratic copying).
    database = pd.concat(frames, axis=0)
    # Every column that is not metadata is an acoustic index.
    index_cols = [c for c in database.columns if c not in ('Date', 'location', 'channel')]
    # database.to_csv(folder+'../'+'all_indices.csv',sep=';')

    database['date'] = pd.to_datetime(database['Date'])  # to have datetime series
    database['day'] = database['date'].dt.day
    database['time'] = database['date'].dt.time
    database['hour'] = database['date'].dt.hour

    keys = ['hour', 'location', 'channel', 'day']
    # Select by name rather than by position. 'date' and 'time' are left out:
    # datetime.time objects cannot be averaged (groupby().mean() raises on pandas >= 2).
    sub1 = database[keys + index_cols]
    hourlymeans = sub1.groupby(keys).mean().reset_index()
    hourlymeans['period'] = 'night'  # hours 19-23 and 0-4 keep this label
    hourlymeans.loc[hourlymeans['hour'].between(5, 7), 'period'] = 'dawn'
    hourlymeans.loc[hourlymeans['hour'].between(8, 15), 'period'] = 'day'
    hourlymeans.loc[hourlymeans['hour'].between(16, 18), 'period'] = 'dusk'

    # hourlymeans.isnull().values.any()
    # means = hourlymeans[hourlymeans.isna().any(axis=1)] #to display where na are

    only_ind = hourlymeans[index_cols]
    # Min-max normalisation column by column (per acoustic index). An index
    # that is constant over all rows yields NaN (0/0).
    normalized_df = (only_ind - only_ind.min()) / (only_ind.max() - only_ind.min())
    id_hour = ['hour', 'location', 'period', 'channel', 'day']
    normalized_all = pd.concat([hourlymeans[id_hour], normalized_df], axis=1)
    piv_hour = normalized_all.melt(id_vars=id_hour, var_name="acoustic_index")

    # sub = piv_hour[(piv_hour["acoustic_index"] == 'EVNtCount') | (piv_hour["acoustic_index"] == 'ECU') | (piv_hour["acoustic_index"] == 'HFC')| (piv_hour["acoustic_index"] == 'ROIcover')]
    # .copy(): a new column is added below, so work on an independent frame.
    sub = piv_hour[piv_hour['acoustic_index'].isin(['BI', 'ACI'])].copy()
    sub['normalized value'] = sub['value']

    # Remove the non-functional channels.
    broken = (((sub['location'] == 'control_plot2') & (sub['channel'] == '1'))
              | ((sub['location'] == 'plot2') & (sub['channel'] == '4')))
    sub3 = sub[~broken].copy()

    # sub 3 with not-functional channels removed
    # k = ggplot(sub3) + geom_boxplot(aes(x='location',y='normalized value', fill='channel')) + facet_grid('period ~ acoustic_index')
    # print(k)

    # temporal variation with not-functional channels removed
    # z = ggplot(sub3) + geom_point(aes(x='hour',y='normalized value',color='channel')) + facet_grid('location ~ acoustic_index')
    # print(z)

    # for the presentation
    sub3['type'] = sub3['location'].replace({
        'production_plot2': 'production',
        'plot2': 'conservation',
        'control_plot2': 'pristine',
    })
    # color_dict = {'pristine': 'white', 'conservation': 'white','production':'grey'}
    color_dict = {'pristine': 'blue', 'conservation': 'green', 'production': 'orange'}
    # k = ggplot(sub3[(sub3.acoustic_index=='ECU') & (sub3.period=='day') | (sub.acoustic_index=='HFC') & (sub.period=='day') | (sub.acoustic_index=='ROIcover') & (sub.period=='day')]) + geom_boxplot(aes(x='type',y='normalized value', fill='type')) + facet_grid('period ~ acoustic_index') + theme_classic() + scale_fill_manual(values=color_dict)+ scale_x_discrete(limits = ["pristine", "conservation", "production"]) + theme(axis_text = element_text(size = 24)) + theme(axis_title = element_text(size = 24)) + theme(strip_text = element_text(size = 24)) 
    # m = ggplot(sub3[(sub3.acoustic_index=='ROIcover') & (sub3.period=='dawn') | (sub3.acoustic_index=='HFC') & (sub3.period=='dawn') | (sub3.acoustic_index=='EVNtCount') & (sub3.period=='dawn')]) + geom_boxplot(aes(x='type',y='normalized value', fill='type')) + facet_grid('period ~ acoustic_index') + theme_classic() + scale_fill_manual(values=color_dict)+ scale_x_discrete(limits = ["pristine", "conservation", "production"]) + theme(axis_text = element_text(size = 18)) + theme(axis_title = element_text(size = 24)) + theme(strip_text = element_text(size = 24)) 
    bi_plot = (ggplot(sub3[sub3.acoustic_index == 'BI'])
               + geom_boxplot(aes(x='type', y='normalized value', fill='type'))
               + theme_classic()
               + scale_x_discrete(limits=["pristine", "conservation", "production"])
               + theme(axis_text=element_text(size=24))
               + theme(axis_title=element_text(size=24))
               + scale_fill_manual(values=color_dict)
               + facet_wrap('~period', ncol=2, nrow=2))
    print(bi_plot)

    # z = ggplot(sub3[(sub3.acoustic_index=='ECU')|(sub3.acoustic_index=='ROIcover')|(sub3.acoustic_index=='EVNtCount') | (sub3.acoustic_index=='HFC')]) + geom_point(aes(x='hour',y='normalized value',color='channel')) + facet_grid('type ~ acoustic_index') + theme(axis_text = element_text(size = 24)) + theme(axis_title = element_text(size = 24)) + theme(strip_text = element_text(size = 24)) 
    # print(z)

    # print(m)

    # correlation map
    # maad.util.plot_correlation_map(hourlymeans[(hourlymeans.location == 'plot2') | (hourlymeans.location == 'control_plot2')])
    # maad.util.plot_correlation_map(hourlymeans[hourlymeans.location == 'production_plot2'])
    # maad.util.plot_correlation_map(hourlymeans.loc[(hourlymeans.location == 'production_plot2') & (hourlymeans.period == 'day') ,['ROIcover','EVNtCount','HFC']])

    # loc_h=hourlymeans.loc[(hourlymeans.location=='control_plot2') & (hourlymeans.channel=='1')]
    # hourlymeans2= hourlymeans.drop(loc_h.index)
    # loc_h2=hourlymeans2.loc[(hourlymeans2.location=='plot2') & (hourlymeans2.channel=='4')]
    # hourlymeans3 = hourlymeans2.drop(loc_h2.index)

    # correlation indices
    # targets= ['production_plot2','control_plot2','plot2']
    # periodes = sub3.period.unique()
    # create = []
    # for i in targets:
    #     for j in periodes:
    #         choice = hourlymeans3.loc[(hourlymeans3.location == i) & (hourlymeans3.period == j), ['ROIcover','EVNtCount','HFC']]
    #         corr_matrix = choice.corr()
    #         create.append(i)
    #         create.append(j)
    #         create.append(corr_matrix.values[0][1])
    #         create.append(corr_matrix.values[0][2])
    #         create.append(corr_matrix.values[1][2])
    # first = create[2:len(create):5]
    # second = create[3:len(create):5]
    # third = create[4:len(create):5]
    # fourth = create[0:len(create):5]
    # fifth = create[1:len(create):5]
    # df = pd.concat([pd.DataFrame(first),pd.DataFrame(second),pd.DataFrame(third),pd.DataFrame(fourth),pd.DataFrame(fifth)],axis=1)
    # df.columns=['ROI_EVNtCount','ROI_HFC','HFC_EVNtCount','site','period']
    # # print(df.sort_values(by=['ROI_EVNtCount','ROI_HFC'],ascending=False))
    # # create.to_csv('/nfs/NAS4/SABIOD/SITE/greenpraxis/indonesia/results/acoustic_indices/correlation.csv', sep=';')

    # PCA with selected indices
    # pca = PCA(n_components=3)
    # x = StandardScaler().fit_transform(hourlymeans[['EVNtCount','EVNtFraction','HFC','MFC','LFC', 'ECU','ROIcover','ACTtCount','ACTtFraction','NDSI']])
    # # x = StandardScaler().fit_transform(hourlymeans3.iloc[:,3:63])
    # principalComponents = pca.fit_transform(x)
    # principalDf = pd.DataFrame(data = principalComponents, columns = ['principal component 1', 'principal component 2','principal component 3'])
    # finalDf = pd.concat([principalDf, hourlymeans[['location','hour','period','channel']]], axis = 1)
    # pca.explained_variance_ratio_ # 43%, 25%, 11%

    # fig = plt.figure(figsize = (8,8))
    # ax = fig.add_subplot(1,1,1)
    # ax.set_xlabel('Principal Component 1', fontsize = 15)
    # ax.set_ylabel('Principal Component 2', fontsize = 15)
    # ax.set_title('3 component PCA', fontsize = 20)
    # colors = ['r', 'g', 'b']

    # for target, color in zip(targets,colors):
    #     indicesToKeep = finalDf['location'] == target
    #     ax.scatter(finalDf.loc[indicesToKeep, 'principal component 1'], finalDf.loc[indicesToKeep, 'principal component 2'], c = color, s = 50)

    # ax.legend(targets)
    # ax.grid()

    # by location/period (not enough points to display it by channel)
    # h = ggplot(finalDf) + geom_point(aes(x=finalDf['principal component 1'],y=finalDf['principal component 2'], colour='location')) + facet_wrap('period')
    # print(h)

    # daily plot
    # val = sub1.iloc[:,0:-3] # only values
    # norm_data = (database.iloc[:,1:-6]-database.iloc[:,1:-6].min())/(database.iloc[:,1:-6].max()-database.iloc[:,1:-6].min()) 
    # join = pd.concat([database[['location','hour','date','day','channel','time','Date']],norm_data],axis=1)
    # id_col = ['location','hour','date','day','channel','time','Date']
    # melting = join.melt(id_vars=id_col, var_name="acoustic_index")
    # melting['value']= pd.to_numeric(melting.value)
    # one_day=melting[(melting.day == 28) & (melting.channel == '1') & (melting.location == 'plot2') & (melting.acoustic_index=='ROIcover')]

    # plt.plot(one_day.date,one_day.value, linewidth=5,color='red')
    # xformatter = mdates.DateFormatter('%H:%M')
    # plt.gcf().axes[0].xaxis.set_major_formatter(xformatter)

    # doesn't work with time data
    # a = ggplot(one_day) + geom_point(aes(x=one_day.hour,y=one_day.value))
    # print(a)

    return database, hourlymeans, piv_hour, sub3
    
def check_channels(
    path='/nfs/NAS4/SABIOD/SITE/greenpraxis/indonesia/control_plot2/20220927/wav/20220927_120715UTC_V12.wav',
    channel=3,
):
    """Plot the spectrogram of one channel of a WAV file, to check it visually.

    Parameters
    ----------
    path : str
        WAV file to inspect. The default is the control_plot2 recording that
        was previously hard-coded here.
    channel : int
        Channel index passed to ``x[:, channel]``. The default 3 is the 4th
        channel, as before.

    Raises
    ------
    IndexError
        If the file is mono or does not have the requested channel.
    """
    sr, x = wavfile.read(path)
    # wavfile.read returns a 1-D array for mono files, (samples, channels) otherwise
    if x.ndim != 2 or not -x.shape[1] <= channel < x.shape[1]:
        raise IndexError(f'{path}: channel {channel} not available (data shape {x.shape})')
    plt.specgram(x[:, channel], Fs=sr)  # spectrogram of the selected channel
    plt.show()


# Run the analysis only when executed as a script (including IPython/Spyder
# "%run"), not when this module is imported.
if __name__ == "__main__":
    # acoustic_indices()  # uncomment to (re)compute the per-channel index CSVs first
    database, hourlymeans, piv_hour, sub3 = ai_display()
