import torch
import torch.nn as nn
from torchaudio.transforms import MelScale

from torchsummary import summary

SR = 256 * 1000  # sample rate passed to MelScale; its Nyquist (SR // 2) is used as f_max

def cnn_block(n_input, n_feat, last=False):
    """Build a two-convolution stage followed by max pooling.

    Each 3x3 convolution (padding=1, so the spatial size is kept) is followed
    by LeakyReLU and then BatchNorm2d; only the final MaxPool2d shrinks the
    spatial size.

    Args:
        n_input: number of input channels.
        n_feat: number of output channels of both convolutions.
        last: if truthy, pool with kernel (1, 2), i.e. keep the second-to-last
            axis and halve only the last axis; otherwise pool both by (2, 2).

    Returns:
        nn.Sequential of 7 modules: conv, act, bn, conv, act, bn, pool.
    """
    # Truthiness test instead of `is True`, so e.g. last=1 also selects (1, 2).
    mp_size = (1, 2) if last else (2, 2)

    return nn.Sequential(
        nn.Conv2d(n_input, n_feat, kernel_size=3, padding=1),
        nn.LeakyReLU(),
        nn.BatchNorm2d(n_feat),

        nn.Conv2d(n_feat, n_feat, kernel_size=3, padding=1),
        nn.LeakyReLU(),
        nn.BatchNorm2d(n_feat),
        nn.MaxPool2d(mp_size),
    )

def context_block(n_input, n_feat):
    """Return a pointwise Conv1d (kernel_size=1) -> LeakyReLU -> BatchNorm1d stage.

    Args:
        n_input: channel count of the tensor fed to the Conv1d.
        n_feat: channel count produced by the 1x1 convolution and normalized
            by the BatchNorm1d.
    """
    layers = [
        nn.Conv1d(n_input, n_feat, kernel_size=1),
        nn.LeakyReLU(),
        nn.BatchNorm1d(n_feat),
    ]
    return nn.Sequential(*layers)

class DetectorErbs(nn.Module):
    """
    Deep Learning Detector.

    Pipeline, for an input of shape (batch, 1, 2049, n_frames):

    1. ``MelScale`` maps the 2049 frequency bins onto ``n_mels`` mel bands.
    2. Five ``cnn_block`` stages: the first four halve both the mel and frame
       axes, the last halves only the frame axis, giving
       (batch, 160, n_mels // 16, n_frames // 32). The channel and mel axes
       are then flattened together.
    3. A stack of 1x1 Conv1d layers ending in Sigmoid yields per-frame scores
       of shape (batch, 2, n_frames // 32).
    4. ``last`` flattens those scores and maps them to 2 outputs with a Linear
       layer, which is why the frame count must be fixed at construction.

    Args:
        n_mels: number of mel bands produced by ``MelScale`` (default 128).
        n_frames: length of the input's last axis that the final Linear layer
            is sized for (default 512).

    Raises:
        ValueError: if ``n_mels < 16`` or ``n_frames < 32``; the pooling
            stages would reduce the mel or frame axis to zero.
    """
    # Four (2, 2) poolings shrink the mel axis by 16; those four plus one
    # (1, 2) pooling shrink the frame axis by 32 (MaxPool2d floors each step).
    _MEL_REDUCTION = 16
    _FRAME_REDUCTION = 32

    def __init__(self, *, n_mels=128, n_frames=512):
        super().__init__()
        if n_mels < self._MEL_REDUCTION or n_frames < self._FRAME_REDUCTION:
            raise ValueError(
                f"n_mels must be >= {self._MEL_REDUCTION} and n_frames >= "
                f"{self._FRAME_REDUCTION}, got n_mels={n_mels}, n_frames={n_frames}"
            )

        self.mel_scale = MelScale(n_mels=n_mels,
                                  sample_rate=SR,
                                  f_min=100,
                                  f_max=SR // 2,
                                  n_stft=2049)

        last_channels = 160
        self.all_cnn_blocks = nn.Sequential(
            cnn_block(1, 32),
            cnn_block(32, 64),
            cnn_block(64, 96),
            cnn_block(96, 128),
            cnn_block(128, last_channels, True),
            # (B, C, M, T) -> (B, C * M, T): channel and mel axes become features.
            nn.Flatten(1, 2)
        )

        n_scores = 2
        # With the defaults this is 160 * 8 = 1280 input features.
        self.all_context_block = nn.Sequential(
            context_block(last_channels * (n_mels // self._MEL_REDUCTION), 256),
            context_block(256, 256),
            nn.Conv1d(256, n_scores, kernel_size=1),
            nn.Sigmoid()
        )

        # With the defaults this is 2 * 16 = 32 input features.
        self.last = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_scores * (n_frames // self._FRAME_REDUCTION), 2),
        )

    def forward(self, x):
        """Run the detector.

        Args:
            x: tensor of shape (batch, 1, 2049, n_frames). This is presumably
                a linear-frequency spectrogram; TODO confirm with the caller.

        Returns:
            Tuple ``(out, scores)``. ``out`` is the (batch, 2) output of the
            final Linear layer and is not squashed by any activation.
            ``scores`` holds the per-frame Sigmoid outputs, with shape
            (batch, 2, n_frames // 32).
        """
        mel_x = self.mel_scale(x)
        flat = self.all_cnn_blocks(mel_x)
        cntxt = self.all_context_block(flat)
        return self.last(cntxt), cntxt



def main():
    """Print a torchsummary of DetectorErbs for inputs of shape (1, 2049, 512).

    The model runs on CUDA when a GPU is available and on CPU otherwise, so
    the script also works on CPU-only machines.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    context_cnn = DetectorErbs().to(device)
    # torchsummary defaults to device="cuda", so pass the chosen device explicitly.
    summary(context_cnn, input_size=(1, 2049, 512), batch_size=32, device=device)


if __name__ == "__main__":
    main()
