import os
from torch.utils.tensorboard import SummaryWriter
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score, ConfusionMatrixDisplay,multilabel_confusion_matrix,accuracy_score, roc_curve, auc
from torch import nn, utils, device, optim, save
from torch.utils import data
import pandas as pd
import matplotlib.pyplot as plt
import utils as u
import time
from tqdm import tqdm
from models import get
from itertools import cycle
import seaborn as sn
import io
from PIL import Image
import torchvision.transforms as transforms
import tensorflow as tf
from tensorflow.keras.layers import Concatenate




# Species (class) names; noms[i] labels column i of the targets in the ROC plot.
noms = ["Moqueur_grivotte","Paruline_caféiette","Colibri_madère","Colombe_rouviolette","Colombe_à_croissants","Pic_de_Guadeloupe","Tyran_janeau","Pigeon_à_couronne_blanche","Pigeon_à_cou_rouge","Saltator_gros_bec","Grive_à_pieds_jaunes"]

# Each split has two inputs: "left" reads the *_HP annotation file, "right" the *_H one.
# train
traindf_left_input = pd.read_csv("./annot/grad/CS_HP.csv")
traindf_right_input = pd.read_csv("./annot/grad/CS_H.csv")

# test
testdf_left_input = pd.read_csv("./annot/grad/LS_HP.csv")
testdf_right_input = pd.read_csv("./annot/grad/LS_H.csv")


nepoch = 250  # number of training epochs
batch_size = 16  # larger = fewer steps per epoch
modelname = 'HP+H_LF.stdc'  # TensorBoard run name and checkpoint file name
writer = SummaryWriter('runs/' + modelname)
print('Go for model ' + modelname)
lr = 0.05
wdL2 = 0.002  # L2 weight decay passed to SGD


model = get['main']

gpu = device('cuda:1')
model.to(gpu)
optimizer = optim.SGD(model.parameters(), weight_decay=wdL2, lr=lr, momentum=.9)
# LambdaLR multiplies the initial lr by the lambda: lr(epoch) = lr * 0.97**epoch.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.97**epoch)


def _make_loader(df, shuffle):
    """Wrap `df` in a `u.Dataset` and return a DataLoader with the script's shared settings."""
    return utils.data.DataLoader(u.Dataset(df), batch_size=batch_size, shuffle=shuffle, num_workers=4, pin_memory=True)


loader_l = _make_loader(traindf_left_input, shuffle=True)
loader_r = _make_loader(traindf_right_input, shuffle=True)

# Evaluation metrics do not depend on sample order, so keep the test order fixed
# (reproducible evaluation, no pointless per-epoch permutation).
testLoader_l = _make_loader(testdf_left_input, shuffle=False)
testLoader_r = _make_loader(testdf_right_input, shuffle=False)

loss = nn.BCEWithLogitsLoss()  # binary cross-entropy applied to raw logits


def _forward(batch):
    """Run `model` on one (inputs, labels) batch and return (logits, labels) on the CPU.

    A channel axis is inserted at dim 1 before the forward pass. The logits are
    reshaped to the label tensor's shape rather than `.squeeze()`d, so a final
    batch of size 1 keeps its batch axis (squeeze would drop it and break the
    concatenation / loss below).
    """
    inputs, labels = batch
    logits = model(inputs.to(gpu).unsqueeze(1)).cpu()
    return logits.reshape(labels.shape), labels


def _save_roc_plot(labels, preds, epoch):
    """Save per-class plus micro/macro-average ROC curves for one evaluation pass.

    Args:
        labels: 2-D array of binary targets, one column per entry of `noms`.
        preds: 2-D array of scores with the same shape as `labels`.
        epoch: Epoch number, used in the output file name.
    """
    n_classes = len(noms)
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(labels[:, i], preds[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: treat every (sample, class) pair as one binary decision.
    fpr["micro"], tpr["micro"], _ = roc_curve(labels.ravel(), preds.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate each class curve onto the union of all FPR
    # points, then average the TPRs across classes.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes
    fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    fig = plt.figure(figsize=(25, 10))
    lw = 2
    plt.plot(
        fpr["micro"],
        tpr["micro"],
        label="micro-average ROC curve (area = {0:0.2f})".format(roc_auc["micro"]),
        color="deeppink",
        linestyle=":",
        linewidth=4,
    )
    plt.plot(
        fpr["macro"],
        tpr["macro"],
        label="macro-average ROC curve (area = {0:0.2f})".format(roc_auc["macro"]),
        color="navy",
        linestyle=":",
        linewidth=4,
    )
    colors = cycle(["aqua", "darkorange", "cornflowerblue", "b", "g", "r", "m", "y", "lime", "pink", "silver"])
    for i, color in zip(range(n_classes), colors):
        plt.plot(
            fpr[i],
            tpr[i],
            color=color,
            lw=lw,
            label="ROC curve of {0} (area = {1:0.2f})".format(noms[i], roc_auc[i]),
        )
    plt.plot([0, 1], [0, 1], "k--", lw=lw)  # chance-level diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver operating characteristic")
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    # savefig does not create directories; make sure the target exists.
    os.makedirs("./imagesROC", exist_ok=True)
    plt.savefig("./imagesROC/test ROC " + str(epoch) + " " + str(modelname) + ".png", bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps every figure alive for the whole run.
    plt.close(fig)


print('Started at ', time.ctime(time.time()))
step = 0
for epoch in range(nepoch):  # TRAIN LOOP
    model.train()
    # zip() stops at the shorter loader, so that is the real number of steps.
    n_steps = min(len(loader_l), len(loader_r))
    for batch_l, batch_r in tqdm(zip(loader_l, loader_r), desc=str(epoch), total=n_steps, leave=False):
        optimizer.zero_grad()
        pred_l, label_l = _forward(batch_l)
        pred_r, label_r = _forward(batch_r)

        # Train on both inputs jointly: stack the two batches along the batch axis.
        pred = torch.cat([pred_l, pred_r], dim=0)
        label = torch.cat([label_l, label_r], dim=0)

        score = loss(pred, label.float())
        score.backward()
        optimizer.step()

        if step % 10 == 0:
            writer.add_scalar('train mAP', average_precision_score(label, pred.detach()), step)
        # ROC AUC is undefined for a class with no positive in the batch.
        if (label.sum(0) > 0).all():
            writer.add_scalar('train AUC', roc_auc_score(label, pred.detach()), step)

        step += 1

    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

    # EVAL: no autograd graph is needed, so skip building one.
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        # Run each test loader to completion instead of zip()-ing them (which
        # would drop the tail of the longer one); the metrics below do not
        # depend on sample order.
        for test_loader in (testLoader_l, testLoader_r):
            for batch in tqdm(test_loader, desc=str(epoch), leave=False):
                pred, label = _forward(batch)
                preds.append(pred.numpy())
                labels.append(label.numpy())

    scheduler.step()

    preds, labels = np.concatenate(preds), np.concatenate(labels)

    testloss = loss(torch.as_tensor(preds, dtype=torch.float32), torch.as_tensor(labels, dtype=torch.float32)).item()
    writer.add_scalar('test loss', testloss, epoch)
    writer.add_scalar('test mAP', average_precision_score(labels, preds), epoch)

    # ROC AUC (and the ROC plot) need at least one positive example per class.
    if (labels.sum(0) > 0).all():
        writer.add_scalar('test AUC', roc_auc_score(labels, preds), epoch)
        if epoch % 50 == 0:  # every 50 epochs
            _save_roc_plot(labels, preds, epoch)

    save(model.state_dict(), modelname)

# Idea: build a new CSV that separates the two inputs, either by assigning each one up front or by adding a column.





def uvhsqijbvijbdqbfdfbrhergzrv():
    # Earlier draft of the two-input training loop; nothing in this file calls it.
    # NOTE(review): the function may continue past this point — confirm before relying on these notes.
    for epoch in range(nepoch): # TRAIN LOOP
        model.train()
        for i in tqdm(range(len(loader_l))):

            # NOTE(review): iter() is re-created on every step, so each step draws the
            # *first* batch of a fresh pass over loader_l (and restarts its workers);
            # the loader is never iterated past its first batch here.
            LL = iter(loader_l)
            x, label_l = next(LL)
            x = x.to(gpu).unsqueeze(1)  # add an axis at dimension 1

            # Same fresh-iterator pattern for the right-hand loader.
            RR = iter(loader_r)
            xx, label_r = next(RR)
            xx = xx.to(gpu).unsqueeze(1)  # add an axis at dimension 1
            optimizer.zero_grad()

            pred_l = model(x).cpu().squeeze()
            pred_r = model(xx).cpu().squeeze()

            # NOTE(review): no loss/backward() runs between zero_grad() and step(),
            # so this step has no freshly computed gradients to apply.
            optimizer.step()