import os
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score, ConfusionMatrixDisplay,multilabel_confusion_matrix,accuracy_score, roc_curve, auc
from torch import nn, utils, device, optim, save
from torch.utils import data
import pandas as pd
import matplotlib.pyplot as plt
import utils as u
import time
from tqdm import tqdm
from models import get
from itertools import cycle
import torchvision
import seaborn as sn
import io
from PIL import Image
import torchvision.transforms as transforms
from tensorflow.keras.layers import Concatenate
from sklearn import preprocessing
import tensorflow as tf
import tensorboard as tb
# Replace TensorFlow's `io.gfile` module with the stub implementation that
# ships with TensorBoard.
# NOTE(review): presumably the usual workaround that lets
# SummaryWriter.add_embedding run when TensorFlow is installed next to
# TensorBoard -- confirm it is still needed.
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

# Disabled alternative: load one CSV as `df`
# (pd.read_csv("./annot/ale/CS_P_H_HP.csv")) and split it afterwards;
# when only training, use `traindf` in place of `df`.

# Class names, one per label column; used as legend entries for the ROC plots.
noms = [
    "Moqueur_grivotte",
    "Paruline_caféiette",
    "Colibri_madère",
    "Colombe_rouviolette",
    "Colombe_à_croissants",
    "Pic_de_Guadeloupe",
    "Tyran_janeau",
    "Pigeon_à_couronne_blanche",
    "Pigeon_à_cou_rouge",
    "Saltator_gros_bec",
    "Grive_à_pieds_jaunes",
]

# Annotation tables for the training set and the evaluation set.
traindf = pd.read_csv("./annot/tri/CS_HP_SP.csv")
testdf = pd.read_csv("./annot/tri/LS_HP_SP.csv")

def matplotlib_imshow(img, one_channel=False):
    """Draw an image tensor on the current matplotlib axes.

    The tensor is rescaled with ``img / 2 + 0.5`` before plotting
    (presumably undoing a mean=0.5 / std=0.5 normalisation -- confirm
    against the dataset transform).

    Args:
        img: image tensor; must support ``.mean(dim=0)`` and ``.numpy()``.
        one_channel: when True, average over the first axis and draw the
            result with the "Greys" colormap; otherwise move the first
            axis last (channels-first to channels-last) and draw as-is.
    """
    tensor = img.mean(dim=0) if one_channel else img
    pixels = (tensor / 2 + 0.5).numpy()  # unnormalize, then hand over to numpy

    if not one_channel:
        plt.imshow(np.transpose(pixels, (1, 2, 0)))
        return
    plt.imshow(pixels, cmap="Greys")

def select_n_random(data, lab, n=100):
    """Pick up to ``n`` random entries of ``data`` with their encoded labels.

    Labels are first mapped to integer codes with a freshly fitted
    ``LabelEncoder`` and wrapped in a tensor; the same random permutation
    is then applied to both ``data`` and the encoded labels.

    Args:
        data: indexable collection that accepts a permutation tensor as index.
        lab: labels, one per entry of ``data`` (not checked here).
        n: maximum number of entries to return.

    Returns:
        Tuple ``(data_subset, label_subset)`` of the first ``n`` shuffled items.
    """
    encoder = preprocessing.LabelEncoder()
    encoded = torch.as_tensor(encoder.fit_transform(lab))
    shuffle = torch.randperm(len(data))
    picked_data = data[shuffle][:n]
    picked_labels = encoded[shuffle][:n]
    return picked_data, picked_labels

def id_lab_ligne1(mat_lab):
    """Convert a one-dimensional label vector into a list of strings.

    Args:
        mat_lab: object exposing ``tolist()`` that yields a flat list
            (e.g. a 1-D tensor or array).

    Returns:
        New list holding ``str(value)`` for every element, in order.
    """
    return [str(value) for value in mat_lab.tolist()]

def id_lab(mat_lab):
    """Convert a two-dimensional label matrix into nested lists of strings.

    Args:
        mat_lab: object exposing ``tolist()`` that yields a list of rows
            (e.g. a 2-D tensor or array).

    Returns:
        New list of rows, each row a list of ``str(value)`` entries.
    """
    return [[str(value) for value in row] for row in mat_lab.tolist()]

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
nepoch = 300      # number of epochs -- change this to train longer/shorter
batch_size = 16   # larger batches run faster (16 is the usual value)
modelname = 'HP_H_P_mod_16_lr97.stdc'  # also used as TensorBoard run name and checkpoint path
writer = SummaryWriter(f'runs/{modelname}')
print(f'Go for model {modelname}')
lr = 0.05         # initial learning rate
wdL2 = 0.002      # L2 weight decay

# Disabled alternative: derive both sets from a single frame with
# train_test_split(df, random_state=24, test_size=.1), i.e. a 10 % hold-out.

# ---------------------------------------------------------------------------
# Model, optimiser and learning-rate schedule
# ---------------------------------------------------------------------------
model = get['main']

gpu = device('cuda:1')
model.to(gpu)
optimizer = optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=.9,
    weight_decay=wdL2,
)
# Multiply the base learning rate by 0.97 ** epoch (0.95 was tried before).
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.97**epoch)
# Disabled alternative: ReduceLROnPlateau(optimizer, patience=20, factor=.8, cooldown=50).

# ---------------------------------------------------------------------------
# Data loaders and loss
# ---------------------------------------------------------------------------
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
loader = utils.data.DataLoader(u.Dataset(traindf), **loader_options)
# The evaluation loader is not needed when only training.
testLoader = utils.data.DataLoader(u.Dataset(testdf), **loader_options)
loss = nn.BCEWithLogitsLoss()  # binary cross-entropy applied to raw logits

#getit = traindf.__getitem__
#print(getit)
#print(loader)
#print(len(loader))


def _keep_batch_axis(output, label):
    """Squeeze singleton axes from ``output`` without losing the batch axis.

    A plain ``squeeze()`` also removes the batch axis when a batch holds a
    single sample (e.g. the last, incomplete batch), which makes the loss
    fail on a shape mismatch and corrupts the list of evaluation
    predictions. In that case the output is reshaped to the label's shape.
    """
    squeezed = output.squeeze()
    if squeezed.dim() < label.dim():
        squeezed = squeezed.reshape(label.shape)
    return squeezed


def _every_class_has_both_outcomes(targets):
    """Return True when each label column has at least one 1 and one 0.

    ROC AUC is undefined for a column that contains a single class, so the
    AUC metrics are only computed when this holds.
    """
    positives = targets.sum(0)
    return bool(((positives > 0) & (positives < len(targets))).all())


def _save_roc_figure(labels, preds, epoch):
    """Plot per-class, micro- and macro-averaged ROC curves and save a PNG.

    Args:
        labels: 2-D array of binary targets, one column per entry of ``noms``.
        preds: 2-D array of scores with the same layout as ``labels``.
        epoch: epoch number, used in the output file name.
    """
    n_classes = len(noms)
    fpr, tpr, roc_auc = {}, {}, {}

    # ROC curve and area for every class.
    for i, nom in enumerate(noms):
        fpr[nom], tpr[nom], _ = roc_curve(labels[:, i], preds[:, i])
        roc_auc[nom] = auc(fpr[nom], tpr[nom])

    # Micro-average: pool every (sample, class) decision together.
    fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), preds.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate each class curve on the union of all
    # false-positive rates, then average the true-positive rates.
    all_fpr = np.unique(np.concatenate([fpr[nom] for nom in noms]))
    mean_tpr = np.zeros_like(all_fpr)
    for nom in noms:
        mean_tpr += np.interp(all_fpr, fpr[nom], tpr[nom])
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    fig = plt.figure(figsize=(25, 10))
    try:
        lw = 2
        plt.plot(
            fpr["micro"],
            tpr["micro"],
            label="micro-average ROC curve (area = {0:0.2f})".format(roc_auc["micro"]),
            color="deeppink",
            linestyle=":",
            linewidth=4,
        )
        plt.plot(
            fpr["macro"],
            tpr["macro"],
            label="macro-average ROC curve (area = {0:0.2f})".format(roc_auc["macro"]),
            color="navy",
            linestyle=":",
            linewidth=4,
        )

        colors = cycle(["aqua", "darkorange", "cornflowerblue", "b", "g", "r", "m", "y", "lime", "pink", "silver"])
        for nom, color in zip(noms, colors):
            plt.plot(
                fpr[nom],
                tpr[nom],
                color=color,
                lw=lw,
                label="ROC curve of {0} (area = {1:0.2f})".format(nom, roc_auc[nom]),
            )

        plt.plot([0, 1], [0, 1], "k--", lw=lw)  # chance diagonal
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Receiver operating characteristic to multiclass")
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

        # Make sure the output directory exists before writing the image.
        os.makedirs("./imagesROC", exist_ok=True)
        plt.savefig(f"./imagesROC/test ROC {epoch} {modelname}.png", bbox_inches='tight')
    finally:
        # Release the figure; otherwise one more figure stays open every
        # time this runs.
        plt.close(fig)


print('Started at ',time.ctime(time.time()))
step = 0  # global training-step counter used as TensorBoard x-axis
for epoch in range(nepoch):
    # ---------------- training pass ----------------
    model.train()
    for batch in tqdm(loader, desc=str(epoch), total=len(loader), leave=False):
        x, label = batch
        x = x.to(gpu).unsqueeze(1)  # insert a singleton axis at position 1
        optimizer.zero_grad()
        # Predictions are moved to the CPU, where the loss is evaluated.
        pred = _keep_batch_axis(model(x).cpu(), label)
        score = loss(pred, label.float())
        score.backward()
        optimizer.step()
        writer.add_scalar('loss', score.item(), step)

        if step % 10 == 0:
            writer.add_scalar('train mAP', average_precision_score(label, pred.detach()), step)

        # AUC needs both outcomes in every class column of this batch.
        if _every_class_has_both_outcomes(label):
            writer.add_scalar('train AUC', roc_auc_score(label, pred.detach()), step)

        step += 1

    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    # ---------------- evaluation pass ----------------
    model.eval()
    preds, labels = [], []
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for batch in tqdm(testLoader, desc=str(epoch), total=len(testLoader), leave=False):
            x, label = batch
            x = x.to(gpu).unsqueeze(1)
            pred = _keep_batch_axis(model(x).cpu(), label)  # drop superfluous axes
            preds.extend(pred.numpy())
            labels.extend(label.numpy())
    scheduler.step()  # advance the learning-rate schedule once per epoch

    preds, labels = np.array(preds), np.array(labels)
    testloss = loss(torch.Tensor(preds), torch.Tensor(labels))  # cross-entropy, i.e. the error
    writer.add_scalar('test loss', testloss, epoch)
    writer.add_scalar('test mAP', average_precision_score(labels, preds), epoch)

    # AUC = area under the ROC curve. It is skipped when a class lacks
    # positives or negatives in the evaluation set (idea noted by the
    # author: add positive examples to such classes, e.g. augmentation 0.1).
    if _every_class_has_both_outcomes(labels):
        writer.add_scalar('test AUC', roc_auc_score(labels, preds), epoch)

        if epoch % 50 == 0:  # every 50 epochs
            _save_roc_figure(labels, preds, epoch)

    # Overwrite the checkpoint after every epoch.
    save(model.state_dict(), modelname)

# Push any buffered events to disk once training has finished.
writer.flush()

# Idea: e.g. a new CSV -- keep the two sets separate, either by assigning them up front or by adding a column.


def plustard():
    """Log sample images, an embedding and the model graph to TensorBoard.

    Work-in-progress helper (the name reads as "plus tard", i.e. "later");
    nothing in this file calls it. It relies on module-level names left
    behind by the training script: ``x`` and ``label`` (last batch),
    ``epoch``, ``writer`` and ``model``.
    """
    # Build an image grid from the current batch, display it and log it.
    img_grid = torchvision.utils.make_grid(x.cpu())
    matplotlib_imshow(img_grid, one_channel=True)
    writer.add_image(f'4 images{epoch}', img_grid)

    images = x
    labs = id_lab_ligne1(label[0])  # first label row only, as strings
    print(labs)

    # One class name per entry of `labs`.
    # NOTE(review): every iteration looks up the same index (first "1" in
    # the first label row), so the same name is repeated len(labs) times;
    # a per-image lookup was probably intended -- confirm before use.
    class_labels = [noms[labs.index("1")] for _ in labs]

    # Log embeddings.
    # NOTE(review): `images` is passed without flattening and
    # `label_img` gets an extra axis; verify both against the shapes
    # add_embedding expects.
    writer.add_embedding(
        images,
        metadata=class_labels,
        label_img=images.unsqueeze(1),
    )
    writer.close()

    writer.add_graph(model, x)
    writer.close()