import numpy as np
import torch
import torch.nn as nn
import math
from torchaudio.transforms import Resample
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn.functional as F


def gen_conv_layer(num):
    seq = [nn.Conv1d(1, 16, 3, 1, 1),
           nn.LeakyReLU(),
           nn.Conv1d(16, 32, 5, 1, 2),
           nn.LeakyReLU(),
           nn.Conv1d(32, 32, 7, 1, 3),
           nn.LeakyReLU(),
           nn.Conv1d(32, 32, 11, 1, 5),
           nn.LeakyReLU(),
           nn.BatchNorm1d(32)]
    for p in np.flip(np.arange(1, num + 1) * 2):        # Insert num MaxPool layers (the larger num is, the more MaxPools are added), back to front so the indices stay valid
        seq.insert(p, nn.MaxPool1d(2, 2))               # Is it better to add the MaxPools at the beginning?
    return nn.Sequential(*seq)
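
# A quick shape sketch (assuming a (batch, 1, 512) input, as in Detector's conv_128 branch):
# the convolutions preserve length (padding = kernel // 2) and each of the num inserted
# MaxPool1d(2, 2) layers halves it, so gen_conv_layer(4) keeps 512 / 2**4 = 32 time steps:
#   feats = gen_conv_layer(4)(torch.randn(8, 1, 512))   # feats.shape == (8, 32, 32)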


class Detector(nn.Module):
    def __init__(self):
        super().__init__()
        self.decimate = Resample(128_000, 64_000)
        self.conv_128 = gen_conv_layer(4)
        self.conv_64 = gen_conv_layer(3)
        self.conv_32 = gen_conv_layer(2)
        self.conv_16 = gen_conv_layer(1)

        self.cnn = nn.Sequential(nn.Conv1d(128, 128, 3, 1, 1),
                                 nn.MaxPool1d(2, 2),
                                 nn.LeakyReLU(),
                                 nn.Conv1d(128, 128, 3, 1, 1),
                                 nn.MaxPool1d(2, 2),
                                 nn.LeakyReLU(),
                                 nn.Conv1d(128, 128, 3, 1, 1),
                                 nn.MaxPool1d(2, 2),
                                 nn.LeakyReLU(),
                                 nn.Flatten())

        self.mlp = nn.Sequential(nn.Linear(512, 128),
                                 nn.LeakyReLU(),
                                 nn.Linear(128, 32),
                                 nn.LeakyReLU(),
                                 nn.Linear(32, 32),
                                 nn.LeakyReLU())
        self.last = nn.Linear(32, 2)            # From the 32-dim embedding, output a classification: (dolphin prob., noise prob.)

    def resample(self, x):
        x64 = self.decimate(x)
        x32 = self.decimate(x64)
        x16 = self.decimate(x32)
        return x, x64, x32, x16

    def forward(self, x):
        # assume x has shape (batch, 1, 512), i.e. a 512-sample window @ 256 kHz
        x128, x64, x32, x16 = self.resample(x)
        cat = torch.cat((self.conv_128(x128),
                         self.conv_64(x64),
                         self.conv_32(x32),
                         self.conv_16(x16)), -2)
        flat = self.cnn(cat)
        emb = self.mlp(flat)
        
        return self.last(emb), emb
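
# Example (sketch, assuming a batch of mono 512-sample windows):
#   det = Detector()
#   logits, emb = det(torch.randn(8, 1, 512))   # logits: (8, 2), emb: (8, 32)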

class Context2(nn.Module):
    def __init__(self, input_size=32, nb_featuremap=256):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(input_size, nb_featuremap, stride=2, kernel_size=5),
            nn.BatchNorm1d(nb_featuremap),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, stride=2, kernel_size=5),
            nn.BatchNorm1d(nb_featuremap),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, 2, stride=2, kernel_size=5),
            nn.LeakyReLU(),
            nn.AdaptiveMaxPool1d(1),
            nn.Flatten()
        )

    def forward(self, x):
        return self.cnn(x)
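
# Example (sketch, assuming a sequence of 32-dim window embeddings laid out channel-first
# as (batch, 32, seq_len); the unpadded stride-2 convolutions need seq_len >= 29):
#   scores = Context2()(torch.randn(8, 32, 128))   # scores.shape == (8, 2)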


class Context(nn.Module):
    def __init__(self, input_size=32, nb_featuremap=64, krnl_size=3):

        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(input_size, nb_featuremap, kernel_size=krnl_size, stride=(1), padding=(krnl_size//2)),
            nn.BatchNorm1d(nb_featuremap),
            nn.MaxPool1d(2),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(1), padding=(krnl_size//2)),
            nn.BatchNorm1d(nb_featuremap),
            nn.MaxPool1d(2),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(1), padding=(krnl_size//2)),
            nn.BatchNorm1d(nb_featuremap),
            nn.MaxPool1d(2),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(1), padding=(krnl_size//2)),
            nn.BatchNorm1d(nb_featuremap),
            nn.MaxPool1d(2),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(1), padding=(krnl_size//2)),
            nn.BatchNorm1d(nb_featuremap),
            nn.LeakyReLU(),
            nn.Flatten()
        )
        self.mlp = nn.Sequential(
            nn.Linear(nb_featuremap * 64, input_size),  # project back to the input embedding size; the * 64 factor assumes a 1024-step input (1024 / 2**4 max-pools)
            nn.LeakyReLU(),
        )
        self.last = nn.Linear(input_size, 2)

    def forward(self, x):
        flat = self.cnn(x)
        emb = self.mlp(flat)
        return self.last(emb), emb
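
# Example (sketch): the Linear(nb_featuremap * 64, ...) head implies a 1024-step input
# with the default arguments (1024 / 2**4 max-pools = 64 remaining steps):
#   logits, emb = Context()(torch.randn(8, 32, 1024))   # logits: (8, 2), emb: (8, 32)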


class Context_dil_2D(nn.Module):
    def __init__(self, input_size=1, input_len=512):
        krnl_size = 3
        nb_featuremap = 64
        emb_size = 32

        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(input_size, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv2d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv2d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv2d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv2d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(1), dilation=(1), padding=(krnl_size//2)),
            nn.LeakyReLU(),
            # nn.BatchNorm2d(nb_featuremap),
            # nn.MaxPool1d(2),        # Should MaxPool layers be added here?
            nn.Flatten()
        )
        self.mlp = nn.Sequential(
            nn.Linear(nb_featuremap * input_len//(2**3), emb_size),  # flattened CNN output size for a (1, 32, input_len) input; emb_size matches the input embedding size
            nn.LeakyReLU(),
        )
        self.last = nn.Linear(emb_size, 2)

    def forward(self, x):
        flat = self.cnn(x)
        emb = self.mlp(flat)
        return self.last(emb), emb
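
# Example (sketch, assuming the input stacks 32-dim embeddings over input_len=512 steps
# as a single-channel image, so the flattened CNN output matches the
# Linear(nb_featuremap * input_len // 8, emb_size) head):
#   logits, emb = Context_dil_2D()(torch.randn(8, 1, 32, 512))   # logits: (8, 2), emb: (8, 32)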

class Context_rnn_1d(nn.Module):
    def __init__(self):
        emb_size = 32
        lstm_hidden_size = 64  # alternatives tried: 16, 128
        lstm_num_layers = 1  # alternative tried: 2
        
        super().__init__()
        self.rnn = nn.Sequential(
            #nn.Conv2d(???, 1, kernel_size=(8, 1), stride=(8, 1)), # TODO: replace the max pool and figure out input_dim
            #nn.MaxPool2d((8, 1)),
            nn.LSTM(input_size= emb_size, hidden_size=lstm_hidden_size//2, num_layers=lstm_num_layers, batch_first=True, bidirectional=True),
        )
        self.mlp = nn.Sequential(
            nn.Linear(lstm_hidden_size, emb_size),
            nn.LeakyReLU()
        )
        self.last = nn.Linear(emb_size, 2)


    def forward(self, x):
        flat, _ = self.rnn(x)
        # max-pool the LSTM outputs over the time dimension
        flat = F.max_pool2d(flat.unsqueeze(dim=1), (flat.shape[1], 1)).squeeze(dim=1)
        emb = self.mlp(flat[:, -1, :])
        return self.last(emb), emb
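
# Example (sketch, assuming batch-first sequences of 32-dim window embeddings):
#   logits, emb = Context_rnn_1d()(torch.randn(8, 100, 32))   # logits: (8, 2), emb: (8, 32)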


class Context_ViT_1d(nn.Module):
    def __init__(self, input_len=128):
        super().__init__()
        emb_size = 32  
        ViT_output_size = 32

        num_layers = 6

        self.classtoken = nn.Parameter(torch.randn(1,1,emb_size), requires_grad=True)
        self.positional_embedding = nn.Parameter(torch.randn(1, input_len+1, emb_size))

        self.encoders = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=emb_size, nhead=8, dim_feedforward=1024,
                                                     dropout=0.1, activation='gelu', batch_first=True,
                                                     norm_first=True),
            num_layers=num_layers)
        
        self.dropout = nn.Dropout(p=0.1)

        self.mlp = nn.Sequential(
            nn.Linear(ViT_output_size, emb_size),
            nn.LeakyReLU()
        )
        self.last = nn.Linear(emb_size, 2)

    def forward(self, x):
        batch_size = x.shape[0]
        cls_token = self.classtoken.expand(batch_size, -1, -1)
        x = torch.cat((cls_token,x), dim=1)
        x = self.positional_embedding + x
        x = self.dropout(x)

        flat = self.encoders(x)
        emb = self.mlp(flat[:, 0,:])
        return self.last(emb), emb
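
# Example (sketch, assuming batch-first sequences of input_len 32-dim embeddings;
# a learned class token is prepended and its encoder output drives the classifier):
#   logits, emb = Context_ViT_1d(input_len=128)(torch.randn(8, 128, 32))   # logits: (8, 2), emb: (8, 32)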


class Context_dil(nn.Module):
    def __init__(self, input_size=32,  input_len=512):
        krnl_size = 3
        nb_featuremap = 64
        
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(input_size, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(2), dilation=(2), padding=(krnl_size//2+1)),
            nn.LeakyReLU(),
            nn.Conv1d(nb_featuremap, nb_featuremap, kernel_size=krnl_size, stride=(1), dilation=(1), padding=(krnl_size//2)),
            nn.LeakyReLU(),
            #nn.BatchNorm1d(nb_featuremap),
            #nn.MaxPool1d(2),        # Should MaxPool layers be added here?
            nn.Flatten()
        )
        self.mlp = nn.Sequential(
            nn.Linear(nb_featuremap * input_len//(2**3)//2, input_size),  # input_len // 16 is the sequence length after the four stride-2 convs in self.cnn
            nn.LeakyReLU(),
        )
        self.last = nn.Linear(input_size, 2)

    def forward(self, x):
        flat = self.cnn(x)
        emb = self.mlp(flat)
        return self.last(emb), emb
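
# Example (sketch, assuming (batch, 32, input_len) inputs with the default input_len=512,
# so the flattened CNN output is nb_featuremap * 512 // 16 = 2048):
#   logits, emb = Context_dil()(torch.randn(8, 32, 512))   # logits: (8, 2), emb: (8, 32)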

class PositionalEncoding(nn.Module):

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        self.register_buffer('div_term', div_term)

    def forward(self, x, pos):
        pos = pos.unsqueeze(2)
        pe = torch.zeros_like(x)
        pe[:, :, 0::2] = torch.sin(pos * self.div_term)
        pe[:, :, 1::2] = torch.cos(pos * self.div_term)
        x = x + pe
        return x
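
# Example (sketch): adds sinusoidal encodings for explicitly supplied (possibly non-uniform)
# positions; x and pos must share their first two dims and d_model is assumed even:
#   pe = PositionalEncoding(32)
#   out = pe(torch.randn(8, 100, 32), torch.arange(100.).expand(8, 100))   # out: (8, 100, 32)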


class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer = nn.Transformer(d_model=32, dim_feedforward=256, num_encoder_layers=3, num_decoder_layers=3)

    def forward(self, x, out):
        return self.transformer(x, out)
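
# Example (sketch): nn.Transformer defaults to batch_first=False, so both source and
# target are laid out as (seq_len, batch, 32):
#   out = Transformer()(torch.randn(100, 8, 32), torch.randn(100, 8, 32))   # out: (100, 8, 32)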


class TransformerModel(nn.Module):
    """Container module with an encoder, a recurrent or transformer module, and a decoder."""

    def __init__(self):
        super().__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        emb_dim = 32
        outntoken = 2
        self.pos_encoder = PositionalEncoding(emb_dim)
        encoder_layers = TransformerEncoderLayer(emb_dim, nhead=8, dim_feedforward=256, dropout=0.5)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers=3)
        self.ninp = emb_dim
        self.decoder = nn.Linear(emb_dim, outntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, pos, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        src = src * math.sqrt(self.ninp)
        src = self.pos_encoder(src, pos)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output).mean(1)
        return output
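
# Example (sketch): the encoder layers here use PyTorch's default (seq, batch, feature)
# layout, so the call below is shape-consistent (the causal mask is sized from src.shape[0]);
# note that the final .mean(1) then averages over dim 1 of the encoder output:
#   model = TransformerModel()
#   pos = torch.arange(100.).unsqueeze(1).expand(100, 8)
#   out = model(torch.randn(100, 8, 32), pos)   # out: (100, 2)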


class Both(nn.Module):
    def __init__(self, dec, trans):
        super().__init__()
        self.dec = dec
        self.trans = trans

    def forward(self, x, pos):
        return self.trans(self.dec(x), pos)
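
# Example (sketch): dec and trans are injected, so the exact shapes depend on the modules
# passed in; with a hypothetical window_encoder that returns a single tensor compatible
# with TransformerModel (note that Detector returns a (logits, emb) tuple, which would
# need unpacking before being fed to trans):
#   model = Both(window_encoder, TransformerModel())
#   out = model(x, pos)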
