import pandas as pd, numpy as np
from scipy import stats
from tqdm import tqdm
import matplotlib.pyplot as plt

# Load the call table (one row per call; columns used below include
# date, hydro_, passage, seq, type, totype).
df = pd.read_pickle('seqs.pkl')

# Get pod information from annots
# Each annotation row gives a time window [start, end] on one hydrophone
# during which a given pod was observed; every call belonging to a passage
# recorded in such a window inherits that pod label.
pods_df = pd.read_csv('pods_2016_annot.csv')
df['pod'] = ''  # '' marks "no pod assigned yet"; replaced by None below
pods_df = pods_df[pods_df.pod.isin(['A1', 'A4', 'A5'])]
for i, r in pods_df.iterrows():
    # Passages recorded on this annotation's hydrophone inside its window.
    passages = df.loc[((df.date>r.start)&(df.date<r.end)&(df.hydro_==r.hydrophone))].passage.unique()
    # First pod observed for a passage is assigned directly ...
    df.loc[(df.passage.isin(passages)&(df.pod=='')), 'pod'] = r.pod
    # ... further distinct pods are appended, yielding labels like 'A5-A4'.
    # NOTE(review): the resulting label order depends on pods_df row order.
    df.loc[(df.passage.isin(passages)&(df.pod!='')&(df.pod!=r.pod)), 'pod'] += '-'+r.pod
df.pod = df.pod.replace({'': None})
pods = ['A5','A1']  # annotated pod groups kept for the subsampled comparison
print(f'{(~df.pod.isna()).sum()} calls with pod annotation')

# Infer pod information from call types
# Concatenate all call-type labels of a passage into one string, then use
# substring tests on the N09 subtype markers. The appended '-' separator
# is what keeps 'N09i-' from matching inside 'N09ii-' or 'N09iii-'.
# Mapping: N09i -> A1, N09ii -> A4, N09iii -> A5; co-occurring markers map
# to joint-pod labels. Passages with no N09 marker keep pod_ = NaN.
df['type_'] = df.type + '-'
grp = df.groupby('passage').type_.sum()
df.loc[df.passage.isin(grp[(grp.str.contains('N09i-')&(~grp.str.contains('N09iii-|N09ii-')))].index), 'pod_'] = 'A1'
df.loc[df.passage.isin(grp[(grp.str.contains('N09ii-')&(~grp.str.contains('N09iii-|N09i-')))].index), 'pod_'] = 'A4'
df.loc[df.passage.isin(grp[(grp.str.contains('N09iii-')&(~grp.str.contains('N09i-|N09ii-')))].index), 'pod_'] = 'A5'
df.loc[df.passage.isin(grp[(grp.str.contains('N09i-')&grp.str.contains('N09ii-')&(~grp.str.contains('N09iii-')))].index), 'pod_'] = 'A1-A4'
df.loc[df.passage.isin(grp[(grp.str.contains('N09i-')&grp.str.contains('N09iii-')&(~grp.str.contains('N09ii-')))].index), 'pod_'] = 'A1-A5'
df.loc[df.passage.isin(grp[(grp.str.contains('N09ii-')&grp.str.contains('N09iii-')&(~grp.str.contains('N09i-')))].index), 'pod_'] = 'A4-A5'
df.loc[df.passage.isin(grp[(grp.str.contains('N09i-')&grp.str.contains('N09ii-')&grp.str.contains('N09iii-'))].index), 'pod_'] = 'A1-A4-A5'
pods_ = ['A5', 'A4', 'A1', 'A1-A5', 'A1-A4', 'A1-A4-A5'] # dropped A4-A5 because too few seqs

print(f'{(~df.pod_.isna()).sum()} calls with inferred pod')

# Subsample size for the bootstrap: half the sequence count of the scarcest
# retained pod group (A4 / A5-A4 and A4-A5 are excluded for having too few
# sequences).
annot_seq_counts = df[~df.pod.isin(['A4', 'A5-A4'])].groupby('pod').seq.nunique()
inferr_seq_counts = df[df.pod_!='A4-A5'].groupby('pod_').seq.nunique()
annot_min_nb_seqs = int(annot_seq_counts.min()*.5)
inferr_min_nb_seqs = int(inferr_seq_counts.min()*.5)
print(annot_min_nb_seqs, 'annot_min_nb_seqs', annot_seq_counts.idxmin())
print(inferr_min_nb_seqs, 'inferr_min_nb_seqs', inferr_seq_counts.idxmin())

# sizes = [16+22, 13+7+20, 12+10+13] # tot nb indiv by pod A1(A34+A50+A54), A4(A35+A56+A24), A5(A23+A25+A42)
sizes = [11+4+8, 10+1+8, 5+3+7] # nb living indiv by pod A1(A34+A50+A54), A4(A35+A56+A24), A5(A23+A25+A42)

def get_unigram(df, types):
    """Return the relative frequency of each call type in *df*.

    Result is a pandas Series indexed by *types* (each count divided by
    the total number of calls). *types* is expected to be drawn from
    ``df.type`` — a label absent from the counts would raise a KeyError.
    (Converted from a lambda assignment, PEP 8 E731; behavior unchanged.)
    """
    return df.type.value_counts().loc[types] / len(df)
def get_entropy(distrib):
    """Shannon entropy (base 2) of *distrib*, normalized by log2(len).

    *distrib* is a non-negative vector (probabilities or rates); zero
    entries contribute 0. Returns 0.0 for vectors of length < 2 — the
    original lambda divided by log2(1) == 0 there, yielding NaN/inf.
    (Converted from a lambda assignment, PEP 8 E731.)
    """
    n = len(distrib)
    if n < 2:
        return 0.0
    return -sum(p * np.log2(p) if p > 0 else 0 for p in distrib) / np.log2(n)

def get_bigram(df, types):
    """Row-normalized transition matrix between call types.

    Entry (i, j) is the fraction of calls of type ``types[i]`` whose next
    call ('totype') is ``types[j]``. Transitions to the sentinel 'end' are
    skipped, so rows may sum to less than 1.
    """
    index_of = {name: k for k, name in enumerate(types)}
    matrix = np.zeros((len(types), len(types)))
    for src, src_calls in df.groupby('type'):
        n_src = len(src_calls)
        for dst, transitions in src_calls.groupby('totype'):
            if dst != 'end':
                matrix[index_of[src], index_of[dst]] = len(transitions) / n_src
    return matrix

def get_ER(df):
    """Transition entropy rate of the repertoire in *df*.

    Unigram-probability-weighted mean of the normalized entropy of each
    row of the bigram (transition) matrix.
    """
    types = df.type.unique()
    unigram = get_unigram(df, types)
    transition_rows = get_bigram(df, types)
    return sum(p * get_entropy(row) for p, row in zip(unigram, transition_rows))

def get_assoc_rate(df, types):
    """Association-rate matrix between call types.

    Entry (i, j) is the fraction of sequences containing ``types[i]`` that
    also contain ``types[j]`` (diagonal is therefore 1).
    """
    index_of = {name: k for k, name in enumerate(list(types))}
    rates = np.zeros((len(types), len(types)))
    for src, src_calls in df.groupby('type'):
        n_src_seqs = src_calls.seq.nunique()
        cooccurring = df[df.seq.isin(src_calls.seq.unique())]
        for dst, dst_calls in cooccurring.groupby('type'):
            rates[index_of[src], index_of[dst]] = dst_calls.seq.nunique() / n_src_seqs
    return rates

def get_EAR(df):
    """Entropy of association rates for the repertoire in *df*.

    Unigram-probability-weighted mean of the normalized entropy of each
    row of the association-rate matrix.
    """
    types = df.type.unique()
    unigram = get_unigram(df, types)
    assoc_rows = get_assoc_rate(df, types)
    return sum(p * get_entropy(row) for p, row in zip(unigram, assoc_rows))


# 4 metric rows (N, H, ER, EAR) x 2 columns (annotated vs inferred pods).
fig, ax = plt.subplots(nrows=4, ncols=2, figsize=(10, 7), sharex='col', sharey='row')
for col, title in enumerate(['Sequences with annoted pods', 'Sequences with inferred pods']):
    ax[0, col].set_title(title)
for row, label in enumerate(['N', 'H', 'ER', 'EAR']):
    ax[row, 0].set_ylabel(label)

# Bootstrap: repeatedly subsample an equal number of sequences per pod
# group and recompute every complexity metric, to control for sample size.
Rb, Rh, Eb, Eh, ERb, ERh, EARb, EARh = ([] for _ in range(8))
for _ in tqdm(range(100)):
    # Draw a fixed-size random subset of sequences for every pod group.
    dfs_annot = [df[df.seq.isin(np.random.choice(df[df.pod == p].seq.unique(), annot_min_nb_seqs, replace=False))] for p in pods]
    dfs_inferr = [df[df.seq.isin(np.random.choice(df[df.pod_ == p].seq.unique(), inferr_min_nb_seqs, replace=False))] for p in pods_]

    for annot_store, inferr_store, metric in (
        (Rb, Rh, lambda d: d.type.nunique()),                            # repertoire size N
        (Eb, Eh, lambda d: get_entropy(get_unigram(d, d.type.unique()))),  # unigram entropy H
        (ERb, ERh, get_ER),                                              # transition entropy rate
        (EARb, EARh, get_EAR),                                           # association-rate entropy
    ):
        annot_store.append([metric(d) for d in dfs_annot])
        inferr_store.append([metric(d) for d in dfs_inferr])

# Shape: (n_repeats, n_pod_groups) per metric.
Rb, Eb, ERb, EARb, Rh, Eh, ERh, EARh = (np.array(a) for a in (Rb, Eb, ERb, EARb, Rh, Eh, ERh, EARh))

def _print_pairwise_table(label, metrics):
    """Print LaTeX table rows of pairwise Kruskal-Wallis comparisons.

    *metrics* is a list of (name, values) pairs where values has shape
    (n_repeats, n_groups). For every ordered group pair i < j and every
    metric: ' ns ' when p > .001, otherwise ' + '/' - ' depending on
    whether group i's median is below/above group j's.
    """
    print(label)
    n_groups = metrics[0][1].shape[1]
    for i in range(n_groups - 1):
        for name, values in metrics:
            print('& ' + name, end='')
            for j in range(1, n_groups):
                if i < j:
                    p = stats.kruskal(values[:, i], values[:, j]).pvalue
                    cell = ' ns ' if p > .001 else ' + ' if np.median(values[:, i]) < np.median(values[:, j]) else ' - '
                    print(" & " + cell, end='')
                else:
                    print(" & ", end='')
            print('\\\\')
        # '\\hline' fixes the invalid '\h' escape sequence of the original
        # (SyntaxWarning in recent Python); printed output is identical.
        print('\\hline')

# Same table for the annotated and the automatically inferred pod groups
# (the original duplicated this loop verbatim).
_print_pairwise_table('Annotated', [('N', Rb), ('H', Eb), ('ER', ERb), ('EAR', EARb)])
_print_pairwise_table('Automatic', [('N', Rh), ('H', Eh), ('ER', ERh), ('EAR', EARh)])

# Bootstrap distributions: single pods at x = 0..2, pod combinations at
# x = 4..6 in the inferred column; annotated column only has A5 and A1.
annot_positions = [0, 2]
inferr_positions = [0, 1, 2, 4, 5, 6]
for row, (annot_vals, inferr_vals) in enumerate([(Rb, Rh), (Eb, Eh), (ERb, ERh), (EARb, EARh)]):
    ax[row, 0].violinplot(annot_vals, positions=annot_positions)
    ax[row, 1].violinplot(inferr_vals, positions=inferr_positions)


pods_ = ['A5', 'A4', 'A1', 'A4-A5', 'A1-A5', 'A1-A4', 'A1-A4-A5'] # readd A4-A5
pods = ['A5', 'A4', 'A1', 'A5-A4'] # readd A4 and A5-A4
dfs_annot = [df[df.pod == p] for p in pods]
dfs_inferr = [df[df.pod_ == p] for p in pods_]

# Overlay each metric computed on the full (non-subsampled) data,
# one row per metric, matching the violin rows above.
full_data_metrics = [
    lambda d: d.type.nunique(),
    lambda d: get_entropy(get_unigram(d, d.type.unique())),
    get_ER,
    get_EAR,
]
for row, metric in enumerate(full_data_metrics):
    ax[row, 0].scatter(range(len(pods)), [metric(d) for d in dfs_annot])
    ax[row, 1].scatter(range(len(pods_)), [metric(d) for d in dfs_inferr])

# X labels: pod combination and its total individual count.
ax[3, 0].set_xticks(range(len(pods)))
ax[3, 0].set_xticklabels(['A5\n15', 'A4\n19', 'A1\n23', 'A5-A4\n34'])
ax[3, 1].set_xticks(range(len(pods_)))
ax[3, 1].set_xticklabels(['A5\n15', 'A4\n19', 'A1\n23', 'A4-A5\n34', 'A1-A5\n38', 'A1-A4\n42', 'A1-A4-A5\n57'])

ax[3, 0].set_xlabel('Pod(s) and size')
ax[3, 1].set_xlabel('Pod(s) and size')
# Grid on every panel of the 4x2 figure.
for panel in ax.flat:
    panel.grid()
plt.tight_layout()
plt.savefig('complexity_vs_grpsize.pdf')
plt.show()
