import torch


class AE(torch.nn.Module):
    """1-D convolutional autoencoder.

    The encoder applies one stride-1 convolution followed by three stride-2
    convolutions, each followed by a LeakyReLU. The decoder mirrors this with
    three (convolution, nearest-neighbour x2 upsample, LeakyReLU) stages and a
    final convolution that projects back to ``input_dim`` channels.

    Args:
        num_filters: Number of channels used by every hidden convolution.
        input_dim: Number of channels of the input (and of the reconstruction).
    """

    def __init__(self, num_filters: int = 64, input_dim: int = 1) -> None:
        # Initialise the Module machinery before touching any attribute.
        super().__init__()
        self.num_filters = num_filters

        kernel_size = 25
        # Half-kernel padding: stride-1 layers keep the length L, and stride-2
        # layers produce ceil(L / 2) samples.
        padding = kernel_size // 2

        def conv(in_channels, out_channels, stride=1, padding_mode='replicate'):
            """Build one of the network's 25-tap convolutions."""
            return torch.nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                bias=True,
            )

        # NOTE: layer order (and therefore the Sequential indices used as
        # state_dict keys) is part of the checkpoint format; do not reorder.
        self.encoder = torch.nn.Sequential(
            conv(input_dim, num_filters),
            torch.nn.LeakyReLU(),
            conv(num_filters, num_filters, stride=2),
            torch.nn.LeakyReLU(),
            conv(num_filters, num_filters, stride=2),
            torch.nn.LeakyReLU(),
            conv(num_filters, num_filters, stride=2),
            torch.nn.LeakyReLU(),
        )

        self.decoder = torch.nn.Sequential(
            conv(num_filters, num_filters),
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.LeakyReLU(),
            conv(num_filters, num_filters),
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.LeakyReLU(),
            conv(num_filters, num_filters),
            torch.nn.Upsample(scale_factor=2, mode='nearest'),
            torch.nn.LeakyReLU(),
            # The output projection uses zero padding, unlike the replicate
            # padding of every other convolution in the network.
            conv(num_filters, input_dim, padding_mode='zeros'),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return its reconstruction.

        Args:
            x: Input tensor; expected layout is ``(batch, input_dim, length)``.

        Returns:
            Tensor with the same length along the last axis as ``x``.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        # Each stride-2 stage rounds the length up, so after the x8 upsampling
        # the decoder can emit up to 7 extra trailing samples when the input
        # length is not a multiple of 8. Trim them so that the reconstruction
        # always lines up with the input.
        return decoded[..., :x.shape[-1]]
