import pandas as pd, numpy as np
from scipy import stats
from tqdm import tqdm
import matplotlib.pyplot as plt

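# Call sequences: one row per call with at least 'type', 'totype', 'seq', 'passage', 'date' and 'hydro_' columns.
# The script compares repertoire size (N), unigram entropy (H), entropy rate (ER) and
# entropy of association rates (EAR) across behavioral contexts and across hydrophones.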
df = pd.read_pickle('seqs.pkl')
# # Select only A1 sequences
# df['type_'] = df.type + '-'
# grp = df.groupby('passage').type_.sum()
# df = df[df.passage.isin(grp[(grp.str.contains('N09i-')&(~grp.str.contains('N09iii-|N09ii-')))].index)]

# Get behavior
pods = pd.read_csv('pods_2016_annot.csv')
df[['pod', 'behavior', 'matriline']] = None
pods.behavior = pods.behavior.replace({'socializing':'socialising'})
pods = pods[pods.pod.isin(['A1', 'A4', 'A5'])]
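# Tag each call with the pod / behavior / matriline of the annotated encounter that contains it
# (same hydrophone, date within [start, end]; assumes df.date and pods.start/end are comparable datetimes).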
for i, r in pods.iterrows():
    passages = df.loc[((df.date>r.start)&(df.date<r.end)&(df.hydro_==r.hydrophone))].passage.unique()
    df.loc[df.passage.isin(passages), ['pod', 'behavior', 'matriline']] = r.pod, r.behavior, r.matriline
print(f'{(~df.behavior.isna()).sum()} calls with behavior annotation')

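# Balanced subsample sizes: 80% of the smallest per-behavior sequence count,
# 50% of the smallest per-hydrophone sequence count.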
min_nb_seqs_behav = int(df.groupby('behavior').seq.nunique().min()*0.8)
min_nb_seqs_hydro = int(df.groupby('hydro_').seq.nunique().min()*0.5)
print(min_nb_seqs_hydro, 'min_nb_seqs_hydro', df.groupby('hydro_').seq.nunique().idxmin())
print(min_nb_seqs_behav, 'min_nb_seqs_behav', df.groupby('behavior').seq.nunique().idxmin())

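# get_unigram: relative frequency of each call type; get_entropy: Shannon entropy (base 2)
# normalised by log2 of the alphabet size (0 for a degenerate single-type distribution).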
get_unigram = lambda df, types: df.type.value_counts().loc[types]/len(df)
get_entropy = lambda distrib : 0. if len(distrib) < 2 else -sum([p*np.log2(p) if p > 0 else 0 for p in distrib]) / np.log2(len(distrib))

def get_bigram(df, types):
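    """Transition probabilities P(totype | type) between call types; transitions to 'end' are skipped, so rows may not sum to 1."""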
    type_idx = {t:i for i, t in enumerate(types)}
    values = np.zeros((len(types), len(types)))
    for t1, grp in df.groupby('type'):
        for t2, grpp in grp.groupby('totype'):
            if t2 == 'end':
                continue
            values[type_idx[t1], type_idx[t2]] = len(grpp) / len(grp)
    return values

def get_ER(df):
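    """Entropy rate: unigram-weighted average of the normalised entropy of each type's transition distribution."""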
    types = df.type.unique()
    bigram = get_bigram(df, types)
    unigram = get_unigram(df, types)
    return sum([p1*get_entropy(bigram[i1]) for i1, p1 in enumerate(unigram)])

def get_assoc_rate(df, types):
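    """Association rates: fraction of the sequences containing type t1 that also contain type t2."""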
    type_idx = {t:i for i, t in enumerate(list(types))}
    values = np.zeros((len(types), len(types)))
    for t1, grp in df.groupby('type'):
        for t2, grpp in df[df.seq.isin(grp.seq.unique())].groupby('type'):
            values[type_idx[t1], type_idx[t2]] = grpp.seq.nunique() / grp.seq.nunique()
    return values

def get_EAR(df):
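    """Entropy of association rates: unigram-weighted average of the normalised entropy of each type's association-rate row."""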
    types = df.type.unique()
    assoc_rate = get_assoc_rate(df, types)
    unigram = get_unigram(df, types)
    return sum([p1*get_entropy(assoc_rate[i1]) for i1, p1 in enumerate(unigram)])


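# 4 measures (rows: N, H, ER, EAR) x 2 groupings (columns: behavior, hydrophone)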
fig, ax = plt.subplots(nrows=4, ncols=2, sharex='col', figsize=(10, 5))
ax[0,0].set_ylabel('N')
ax[1,0].set_ylabel('H')
ax[2,0].set_ylabel('ER')
ax[3,0].set_ylabel('EAR')

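# Bootstrap: 100 balanced subsamples of whole sequences, drawn per behavior and per hydrophone.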
Rb, Rh, Eb, Eh, ERb, ERh, EARb, EARh = [], [], [], [], [], [], [], []
for repeat in tqdm(range(100)):
    dfs_behav = [df[df.seq.isin(np.random.choice(df[df.behavior==p].seq.unique(), min_nb_seqs_behav, replace=False))] for p in df.behavior.dropna().unique()]
    dfs_hydro = [df[df.seq.isin(np.random.choice(df[df.hydro_==p].seq.unique(), min_nb_seqs_hydro, replace=False))] for p in df.hydro_.unique()]

    behav_R, hydro_R = ([df_.type.nunique() for df_ in dfs] for dfs in [dfs_behav, dfs_hydro])
    Rb.append(behav_R)
    Rh.append(hydro_R)

    behav_E, hydro_E = ([get_entropy(get_unigram(df_, df_.type.unique())) for df_ in dfs] for dfs in [dfs_behav, dfs_hydro])
    Eb.append(behav_E)
    Eh.append(hydro_E)

    behav_ER, hydro_ER = ([get_ER(df_) for df_ in dfs] for dfs in [dfs_behav, dfs_hydro])
    ERb.append(behav_ER)
    ERh.append(hydro_ER)

    behav_EAR, hydro_EAR = ([get_EAR(df_) for df_ in dfs] for dfs in [dfs_behav, dfs_hydro])
    EARb.append(behav_EAR)
    EARh.append(hydro_EAR)

Rb, Eb, ERb, EARb, Rh, Eh, ERh, EARh = [np.array(a) for a in [Rb, Eb, ERb, EARb, Rh, Eh, ERh, EARh]]

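# Pairwise Kruskal-Wallis tests between groups for each measure, printed as rows of a LaTeX table.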
def print_pairwise_table(arrays, labels=('N', 'H', 'ER', 'EAR')):
    # Upper-triangular table: 'ns' if p > .001, otherwise '+'/'-' depending on which median is larger.
    n_groups = arrays[0].shape[1]
    for i in range(n_groups - 1):
        for name, arr in zip(labels, arrays):
            print('& ' + name, end='')
            for j in range(1, n_groups):
                if i < j:
                    p = stats.kruskal(arr[:, i], arr[:, j]).pvalue
                    print(" & " + (' ns ' if p > .001 else ' + ' if np.median(arr[:, i]) < np.median(arr[:, j]) else ' - '), end='')
                else:
                    print(" & ", end='')
            print('\\\\')
        print('\\hline')

print('Behavior')
print_pairwise_table([Rb, Eb, ERb, EARb])

print('Hydro')
print_pairwise_table([Rh, Eh, ERh, EARh])

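# Violins: bootstrap distributions; dots: values computed on all data for each group.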
ax[0,0].violinplot(Rb, positions=range(df.behavior.nunique()))
ax[1,0].violinplot(Eb, positions=range(df.behavior.nunique()))
ax[2,0].violinplot(ERb, positions=range(df.behavior.nunique()))
ax[3,0].violinplot(EARb, positions=range(df.behavior.nunique()))

dfs = [df[df.behavior==p] for p in df.behavior.dropna().unique()]
ax[0,0].scatter(range(df.behavior.dropna().nunique()), [df_.type.nunique() for df_ in dfs])
ax[1,0].scatter(range(df.behavior.dropna().nunique()), [get_entropy(get_unigram(df_, df_.type.unique())) for df_ in dfs])
ax[2,0].scatter(range(df.behavior.dropna().nunique()), [get_ER(df_) for df_ in dfs])
ax[3,0].scatter(range(df.behavior.dropna().nunique()), [get_EAR(df_) for df_ in dfs])

ax[3,0].set_xlabel('Behavior')
ax[3,0].set_xticks(range(df.behavior.nunique()))
ax[3,0].set_xticklabels(df.behavior.dropna().unique())

ax[0,1].violinplot(Rh, positions=range(df.hydro_.nunique()))
ax[1,1].violinplot(Eh, positions=range(df.hydro_.nunique()))
ax[2,1].violinplot(ERh, positions=range(df.hydro_.nunique()))
ax[3,1].violinplot(EARh, positions=range(df.hydro_.nunique()))

dfs = [df[df.hydro_==p] for p in df.hydro_.unique()]
ax[0,1].scatter(range(df.hydro_.nunique()), [df_.type.nunique() for df_ in dfs])
ax[1,1].scatter(range(df.hydro_.nunique()), [get_entropy(get_unigram(df_, df_.type.unique())) for df_ in dfs])
ax[2,1].scatter(range(df.hydro_.nunique()), [get_ER(df_) for df_ in dfs])
ax[3,1].scatter(range(df.hydro_.nunique()), [get_EAR(df_) for df_ in dfs])

ax[3,1].set_xlabel('Hydrophone')
ax[3,1].set_xticks(range(df.hydro_.nunique()))
ax[3,1].set_xticklabels(df.hydro_.unique())

for i in range(4):
    for j in range(2):
        ax[i,j].grid()
plt.tight_layout()
plt.savefig('complexity_vs_hydro.pdf')
plt.show()
