import os
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score, ConfusionMatrixDisplay,multilabel_confusion_matrix,accuracy_score, roc_curve, auc
from torch import nn, utils, device, optim, save
from torch.utils import data
import pandas as pd
import matplotlib.pyplot as plt
import utils as u
import time
from tqdm import tqdm
from models import get
from itertools import cycle
import seaborn as sn
import io
from PIL import Image
import torchvision.transforms as transforms
from tensorflow.keras.layers import Concatenate


# ---------------------------------------------------------------------------
# Experiment setup: annotations, hyper-parameters, model, optimizer, loaders.
# ---------------------------------------------------------------------------

# Full annotation table; it is split into train/test subsets further down.
# (To train only, this table would be used as `traindf` instead of `df`.)
df = pd.read_csv("./annot/tri/CSLS_HP_SPCG.csv")

# Display names of the classes (species); used for the ROC-curve legend.
noms = [
    "Moqueur_grivotte",
    "Paruline_caféiette",
    "Colombe_rouviolette",
    "Colombe_à_croissants",
    "Pic_de_Guadeloupe",
    "Tyran_janeau",
    "Saltator_gros_bec",
]

# Disabled alternative: load pre-split annotation files instead of splitting `df`:
#   traindf <- "./annot/tri/CS_HP_SPC.csv"
#   testdf  <- "./annot/tri/LS_HP_SPC.csv"

# Hyper-parameters.
nepoch = 1000  # number of epochs; change as needed
batch_size = 16  # larger = faster
lr = 0.05  # initial learning rate
wdL2 = 0.002  # weight decay passed to the optimizer

# Run name: used for the TensorBoard directory and the checkpoint file name.
modelname = 'test4.stdc'
writer = SummaryWriter(f'runs/{modelname}')
print(f'Go for model {modelname}')

# Hold out 40% of the annotations for testing; the seed is fixed.
traindf, testdf = train_test_split(df, test_size=0.4, random_state=24)

model = get['main']

gpu = torch.device('cuda:1')
model.to(gpu)

optimizer = optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=wdL2,
)
# Learning-rate factor 0.99**epoch, applied each time scheduler.step() is called.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)
# Disabled alternative scheduler:
#   optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=20, factor=.8, cooldown=50)

# Both loaders share the same batching options.
_loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
loader = data.DataLoader(u.Dataset(traindf), **_loader_options)
# The test loader is not needed when only training.
testLoader = data.DataLoader(u.Dataset(testdf), **_loader_options)

loss = nn.BCEWithLogitsLoss()  # binary cross-entropy on raw logits

#getit = traindf.__getitem__
#print(getit)
#print(loader)
#print(len(loader))

# ---------------------------------------------------------------------------
# Training / evaluation loop.
#
# For every epoch: train on `loader`, evaluate on `testLoader`, log metrics to
# TensorBoard through `writer`, save a ROC plot every 50 epochs and write the
# model weights to the file named by `modelname`.
# ---------------------------------------------------------------------------


def _each_class_has_both_outcomes(targets):
    """Return True when every label column has at least one positive and one negative.

    ``roc_auc_score`` is undefined for a column containing a single class, so
    this guard is evaluated before any AUC / ROC computation.

    Args:
        targets: 2-D array-like indexed ``[sample, class]``.
    """
    targets = np.asarray(targets)
    positives = (targets > 0).sum(0)
    return bool(((positives > 0) & (positives < len(targets))).all())


def _save_roc_plot(labels, preds, epoch):
    """Plot per-class, micro- and macro-averaged ROC curves and save them as a PNG.

    The image is written to ``./imagesROC/`` (created if missing) and the
    figure is always closed afterwards so figures do not accumulate over the
    run.

    Args:
        labels: 2-D array ``[sample, class]`` of ground-truth labels.
        preds: 2-D array ``[sample, class]`` of model scores.
        epoch: Epoch index, used in the output file name.
    """
    n_classes = len(noms)  # one curve per named class
    fpr, tpr, roc_auc = {}, {}, {}

    # ROC curve and area for each class.
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(labels[:, i], preds[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: all (sample, class) pairs pooled together.
    fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), preds.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate every class curve on the union of the false
    # positive rates, then average the true positive rates.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    fig = plt.figure(figsize=(25, 10))
    try:
        lw = 2
        plt.plot(
            fpr["micro"],
            tpr["micro"],
            label="micro-average ROC curve (area = {0:0.2f})".format(roc_auc["micro"]),
            color="deeppink",
            linestyle=":",
            linewidth=4,
        )
        plt.plot(
            fpr["macro"],
            tpr["macro"],
            label="macro-average ROC curve (area = {0:0.2f})".format(roc_auc["macro"]),
            color="navy",
            linestyle=":",
            linewidth=4,
        )

        colors = cycle(["aqua", "darkorange", "b", "g", "r", "m", "pink"])
        for i, color in zip(range(n_classes), colors):
            plt.plot(
                fpr[i],
                tpr[i],
                color=color,
                lw=lw,
                label="ROC curve of {0} (area = {1:0.2f})".format(noms[i], roc_auc[i]),
            )

        plt.plot([0, 1], [0, 1], "k--", lw=lw)  # chance diagonal
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Receiver operating characteristic")
        plt.legend(loc='lower right')

        os.makedirs("./imagesROC", exist_ok=True)
        plt.savefig("./imagesROC/test ROC " + str(epoch) + " " + str(modelname) + ".png", bbox_inches='tight')
    finally:
        # Without this, a new figure would stay alive every time a plot is saved.
        plt.close(fig)


print('Started at ', time.ctime(time.time()))
step = 0  # global training-step counter used as the TensorBoard x-axis
for epoch in range(nepoch):
    # ----------------------------- training ------------------------------
    model.train()
    for x, label in tqdm(loader, desc=str(epoch), total=len(loader), leave=False):
        x = x.to(gpu).unsqueeze(1)  # insert a new axis at dimension 1
        optimizer.zero_grad()
        # Reshape to the label shape rather than a bare squeeze(): squeeze()
        # would also drop the batch axis when the last batch holds one sample.
        pred = model(x).cpu().reshape(label.shape)
        score = loss(pred, label.float())
        score.backward()
        optimizer.step()
        writer.add_scalar('loss', score.item(), step)

        targets = label.numpy()
        scores = pred.detach().numpy()
        if step % 10 == 0:
            writer.add_scalar('train mAP', average_precision_score(targets, scores), step)
        # AUC is only defined when each class has positives and negatives in the batch.
        if _each_class_has_both_outcomes(targets):
            writer.add_scalar('train AUC', roc_auc_score(targets, scores), step)

        step += 1

    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    # ---------------------------- evaluation -----------------------------
    model.eval()
    preds, labels = [], []
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, label in tqdm(testLoader, desc=str(epoch), total=len(testLoader), leave=False):
            x = x.to(gpu).unsqueeze(1)
            # Drop the extra dimensions while keeping the batch axis.
            pred = model(x).cpu().reshape(label.shape)
            preds.append(pred.numpy())
            labels.append(label.numpy())
    scheduler.step()  # one scheduler step per epoch

    preds, labels = np.concatenate(preds), np.concatenate(labels)
    # Cross-entropy over the whole test set, i.e. the test error.
    testloss = loss(
        torch.as_tensor(preds, dtype=torch.float32),
        torch.as_tensor(labels, dtype=torch.float32),
    )
    writer.add_scalar('test loss', testloss.item(), epoch)
    writer.add_scalar('test mAP', average_precision_score(labels, preds), epoch)
    # Disabled: test accuracy and confusion-matrix image logging.

    # AUC = area under the ROC curve. It is skipped when some class has a
    # single outcome in the test set.
    # NOTE(review): the original remark suggests adding positive examples for
    # such a class (e.g. through augmentation) - confirm with the author.
    if _each_class_has_both_outcomes(labels):
        writer.add_scalar('test AUC', roc_auc_score(labels, preds), epoch)

        if epoch % 50 == 0:  # every 50 epochs
            _save_roc_plot(labels, preds, epoch)

    # Overwrite the checkpoint after every epoch.
    save(model.state_dict(), modelname)

# Make sure all pending TensorBoard events reach the disk.
writer.close()

# TODO: e.g. use a new CSV that separates the two sets (train/test), either by assigning them up front or by adding a column
