import pandas as pd, numpy as np
from scipy import stats
from tqdm import tqdm
import matplotlib.pyplot as plt

# Load the sequence dataset.  -- assumes seqs.pkl is a pickled DataFrame with
# columns: seq (sequence id), date (datetime), type, totype, passage -- TODO confirm
df = pd.read_pickle('seqs.pkl')
df['year'] = df.date.dt.year
# Concatenate all type labels of a passage into one '-'-separated string so
# presence/absence of a given type can be tested with a substring match.
df['type_'] = df.type + '-'
grp = df.groupby('passage').type_.sum()
# Keep only passages that contain type N09i but neither N09ii nor N09iii.
# NOTE(review): the trailing '-' separator prevents 'N09i-' from matching
# inside 'N09ii-'/'N09iii-'; the explicit exclusion is a second guard.
df = df[df.passage.isin(grp[(grp.str.contains('N09i-')&(~grp.str.contains('N09iii-|N09ii-')))].index)]

# Bootstrap sample size: half of the smallest per-year count of unique sequences.
min_nb_seqs = int(df.groupby('year').seq.nunique().min()*.5)

def get_unigram(df, types):
    """Return the empirical unigram distribution of `types` in df.type.

    The result is a pandas Series of relative frequencies, ordered like
    `types` (via .loc indexing on the value counts).
    """
    return df.type.value_counts().loc[types] / len(df)


def get_entropy(distrib):
    """Return the Shannon entropy of `distrib`, normalized by log2(len(distrib)).

    Zero-probability entries contribute 0 (the 0*log(0) convention).  For a
    distribution with fewer than two outcomes the normalizer log2(len) is 0,
    which in the original lambda produced nan (0/0); the conventional value
    for a degenerate distribution is 0, so return 0.0 explicitly.
    """
    if len(distrib) < 2:
        return 0.0
    return -sum(p * np.log2(p) if p > 0 else 0 for p in distrib) / np.log2(len(distrib))

def get_bigram(df, types):
    """Build the row-stochastic transition matrix P(totype | type) over `types`.

    Transitions whose target is the sentinel 'end' are skipped, so a row may
    sum to less than one.
    """
    index_of = {name: pos for pos, name in enumerate(types)}
    mat = np.zeros((len(types), len(types)))
    for src, src_rows in df.groupby('type'):
        denom = len(src_rows)
        for dst, dst_rows in src_rows.groupby('totype'):
            if dst == 'end':
                continue
            mat[index_of[src], index_of[dst]] = len(dst_rows) / denom
    return mat

def get_ER(df):
    """Return the unigram-weighted mean normalized entropy of the
    type-to-type transition distributions (entropy rate proxy)."""
    types = df.type.unique()
    transition = get_bigram(df, types)
    weights = get_unigram(df, types)
    total = 0
    for row, weight in enumerate(weights):
        total += weight * get_entropy(transition[row])
    return total

def get_assoc_rate(df, types):
    """Build the association matrix: entry [i, j] is the fraction of the
    sequences containing type i that also contain type j (diagonal is 1)."""
    positions = {name: pos for pos, name in enumerate(list(types))}
    rates = np.zeros((len(types), len(types)))
    for base, base_rows in df.groupby('type'):
        # restrict to the sequences in which `base` occurs at least once
        cooccurring = df[df.seq.isin(base_rows.seq.unique())]
        base_count = base_rows.seq.nunique()
        for other, other_rows in cooccurring.groupby('type'):
            rates[positions[base], positions[other]] = other_rows.seq.nunique() / base_count
    return rates

def get_EAR(df):
    """Return the unigram-weighted mean normalized entropy of each type's
    association-rate profile (see get_assoc_rate)."""
    types = df.type.unique()
    assoc = get_assoc_rate(df, types)
    weights = get_unigram(df, types)
    total = 0
    for row, weight in enumerate(weights):
        total += weight * get_entropy(assoc[row])
    return total


# --- Bootstrap the four complexity measures per year, print a LaTeX
# --- significance table, and plot the per-year distributions.
fig, ax = plt.subplots(nrows=4, sharex=True)

# For each of 100 repeats, subsample min_nb_seqs sequences (without
# replacement) from every year so years are comparable despite unequal sizes.
# NOTE(review): assumes sequence ids are unique to a single year -- otherwise
# the seq.isin filter could pull rows from other years; confirm upstream.
R, E, ER, EAR = [], [], [], []
for repeat in tqdm(range(100)):
    dfs = [df[df.seq.isin(np.random.choice(df[df.year==p].seq.unique(), min_nb_seqs, replace=False))] for p in df.year.unique()]

    R.append([df_.type.nunique() for df_ in dfs])          # repertoire size
    E.append([get_entropy(get_unigram(df_, df_.type.unique())) for df_ in dfs])  # unigram entropy
    ER.append([get_ER(df_) for df_ in dfs])                # transition entropy
    EAR.append([get_EAR(df_) for df_ in dfs])              # association entropy

R, E, ER, EAR = [np.array(a) for a in [R, E, ER, EAR]]

# Upper-triangular LaTeX table of pairwise Kruskal-Wallis tests between years:
# ' ns ' when p > .001, otherwise '+'/'-' for the direction of the median shift.
for i in range(R.shape[1]-1):
    for n, l in zip(['N','H','ER','EAR'], [R, E, ER, EAR]):
        print('& '+n, end='')
        for j in range(1, R.shape[1]):
            if i < j:
                p = stats.kruskal(l[:,i], l[:,j]).pvalue
                print(" & " + (' ns ' if p>.001 else ' + ' if np.median(l[:,i])<np.median(l[:,j]) else ' - '), end='')
            else:
                print(" & ", end='')
        print('\\\\')
    # '\\hline' replaces the original '\hline': \h is an invalid escape
    # sequence (SyntaxWarning since Python 3.12); the printed text is identical.
    print('\\hline')

# One box per year for each measure, sharing the x (year) axis.
ax[0].boxplot(R, positions=df.year.unique())
ax[1].boxplot(E, positions=df.year.unique())
ax[2].boxplot(ER, positions=df.year.unique())
ax[3].boxplot(EAR, positions=df.year.unique())
plt.savefig('complexity_year.pdf')
plt.show()
