import torch
import torchvision
import torchsummary
from torch.utils.tensorboard import SummaryWriter
import time

from Audiodataset import *
from Model1d import AE

import mdct
import matplotlib.pyplot as plt

# Dataset locations. NOTE(review): all three currently point at the same
# folder (and the trailing "//" looks accidental) — confirm before a real
# train/valid/test split is expected.
path_train_dataset = "../dataset-voice//"
path_valid_dataset = "../dataset-voice//"
path_test_dataset = "../dataset-voice//"

SR = 44100  # Audio signal sample rate

def train(nb_epochs):
    """
    Train the denoising autoencoder for nb_epochs, then save the trained weights.

    Every 10 batches the training loss is printed, logged to Tensorboard,
    target/output plots and audio excerpts are dumped into "res/", and a
    checkpoint is written. A validation score is logged once per epoch
    (valid() is currently a stub returning 0).

    Args:
        nb_epochs (int): number of epochs
    """
    # Test if GPU is available
    if torch.cuda.is_available():
        # By default the GPU cluster uses the current GPU device
        device = torch.device("cuda:"+str(torch.cuda.current_device()))
        print("Running on the GPU")
    else:
        device = torch.device("cpu")
        print("Running on the CPU")

    # Init Tensorboard: one run directory per launch, keyed by timestamp
    writer = SummaryWriter("runs/run_%d"%int(time.time()))

    # Load datasets: random 3 s crop + 50 ms fade in/out, then to tensor
    composed_train = torchvision.transforms.Compose([RandomCrop(3, SR), FadeInOut(0.05, SR), ToTensor()])
    composed_valid = torchvision.transforms.Compose([RandomCrop(3, SR), FadeInOut(0.05, SR), ToTensor()])

    trainset = AudioDataset(path_train_dataset, transform=composed_train)
    validset = AudioDataset(path_valid_dataset, transform=composed_valid)

    batch_size = 8
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, pin_memory=True, shuffle=True, drop_last=True)#, num_workers=8, prefetch_factor=2)
    validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, pin_memory=True, shuffle=False, drop_last=True)#, num_workers=8, prefetch_factor=2)
    nb_batch = len(trainloader)

    # Init model (16 is the AE's constructor argument — presumably a channel
    # or latent width; see Model1d.AE)
    autoencoder = AE(16)
    autoencoder.to(device)

    # Print model information; input is a mono 3-second signal
    if torch.cuda.is_available():
        torchsummary.summary(autoencoder, (1, 3*SR), device='cuda')
    else:
        torchsummary.summary(autoencoder, (1, 3*SR), device='cpu')


    optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-4)
    criterion = torch.nn.MSELoss(reduction='mean').to(device)

    for epoch in range(nb_epochs):
        # NOTE: removed a leftover `import ipdb; ipdb.set_trace()` debugger
        # breakpoint that halted training at the start of every epoch.
        for idx_batch, data in enumerate(trainloader):
            optimizer.zero_grad()

            # Add a channel dim: (batch, samples) -> (batch, 1, samples)
            input_data = torch.unsqueeze(data['noisy'], 1)
            input_data = input_data.to(device)
            target_data = torch.unsqueeze(data['clean'], 1)
            target_data = target_data.to(device)

            # Forward; crop the output back to the input length in case the
            # model emits a few extra samples
            denoise_data = autoencoder(input_data)
            denoise_data = denoise_data[:, :, :input_data.shape[-1]]

            # Compute Loss ((input, target) order per torch convention;
            # MSE is symmetric so the value is unchanged)
            loss = criterion(denoise_data, target_data)

            # Back propagation
            loss.backward()
            # Update weights
            optimizer.step()

            # Monitor learning process every 10 batches
            if (idx_batch+1) % 10 == 0:
                print("%03d / %03d"%(idx_batch, nb_batch-1))
                # reduction='mean' already averages over the whole batch;
                # the former extra division by batch_size under-reported
                # the loss by a factor of batch_size.
                mean_batch_loss = loss.item()
                print(mean_batch_loss)
                writer.add_scalar('Train/MSE', mean_batch_loss, (epoch*nb_batch+idx_batch))

                plt.figure()
                plt.plot(target_data[0, 0].detach().cpu())
                plt.savefig("res/%02d_%03d_target.png"%(epoch, idx_batch))
                plt.close()

                plt.figure()
                plt.plot(denoise_data[0, 0].detach().cpu())
                plt.savefig("res/%02d_%03d_denoise.png"%(epoch, idx_batch))
                plt.close()

                export_audio(input_data[0], target_data[0], denoise_data[0])

                # Save weights of the model
                save_model("model1", autoencoder, optimizer, epoch)

        # Compute validation test (valid() is a stub for now)
        valid_loss = valid()
        mean_valid_loss = valid_loss / len(validset)
        # Log per epoch: the former `epoch*nb_batch+idx_batch` step leaked
        # the loop variable and raised NameError on an empty loader.
        writer.add_scalar('Valid/MSE', mean_valid_loss, epoch)

        # Save the model at the end of each epoch
        save_model("model1", autoencoder, optimizer, epoch)

    writer.close()



def save_model(save_name, model, optimizer, epoch):
    """
    Save a training checkpoint (epoch + model and optimizer state) to
    "model_save/<save_name>.pth", creating the folder if needed.

    Args:
        save_name (string): name of the [save_file].pth
        model (torch model): trained model
        optimizer (torch optim): current optimizer
        epoch (int): current epoch
    """
    import os
    # torch.save raises FileNotFoundError if the directory is missing.
    os.makedirs("model_save", exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, "model_save/%s.pth"%(save_name))

def load_model(save_name, model, optimizer):
    """
    Restore a checkpoint written by save_model from "model_save/<save_name>.pth".

    Args:
        save_name (string): name of the [save_file].pth
        model (torch model): initialized model to load the weights into
        optimizer (torch optim): optimizer to load the saved state into
    Returns:
        (model, optimizer, epoch): the restored model and optimizer, and the
        epoch at which the checkpoint was saved.
    """
    checkpoint = torch.load("model_save/%s.pth"%(save_name))
    epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, epoch


def export_audio(data_input, data_target, data_output):
    """
    Convert the torch tensors (input, target, output) to audio signals and
    write the wave files into the "res" folder, tagged with a timestamp.

    Args:
        data_input (torch.tensor): one input signal of the batch
        data_target (torch.tensor): one target signal of the batch
        data_output (torch.tensor): one output signal of the batch
    """
    # NOTE(review): np and sf are not imported here directly — presumably
    # they come from the `Audiodataset` star import; verify.
    def _as_signal(tensor):
        # Drop batch/channel dims and move to host memory as a numpy array.
        return np.array(tensor.detach().cpu().squeeze())

    stamp = int(time.time())
    sf.write("res/noisy_%d.wav"%stamp, _as_signal(data_input), SR)
    sf.write("res/clean_%d.wav"%stamp, _as_signal(data_target), SR)
    sf.write("res/denoise_%d.wav"%stamp, _as_signal(data_output), SR)


#TODO see the Colab notebook
#TODO write the text
#TODO forward

def valid():
    """
    Validation pass — NOT IMPLEMENTED YET.

    Currently returns 0, so the 'Valid/MSE' curve logged by train() stays
    flat at zero. Should presumably iterate validloader and accumulate the
    MSE (train() divides this result by len(validset)) — TO BE DONE.
    """
    return 0



if __name__ == "__main__":
    # Script entry point: run a short training session.
    train(nb_epochs=5)
