from torch import tensor, nn, exp, log, ones, stack, log1p, zeros_like


class PCENLayer(nn.Module):
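    """Trainable Per-Channel Energy Normalization (PCEN) layer.

    Computes PCEN(t, f) = (E(t, f) / (eps + M(t, f))**alpha + delta)**r - delta**r,
    where M is an exponential moving average of the input E over time with
    smoothing coefficient s (Wang et al., "Trainable Frontend For Robust and
    Far-Field Keyword Spotting", ICASSP 2017). The parameters s, alpha, delta
    and r are learned per frequency band and kept positive by storing their
    logarithms.
    """
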
    def __init__(self, num_bands,
                 s=0.025,
                 alpha=.8,
                 delta=10.,
                 r=.25,
                 eps=1e-6,
                 init_smoother_from_data=True):
        super(PCENLayer, self).__init__()
        self.num_bands = num_bands
        # parameters are stored as logarithms so that exp() keeps them positive
        self.log_s = nn.Parameter(log(ones((1, 1, num_bands)) * s))
        self.log_alpha = nn.Parameter(log(ones((1, 1, num_bands, 1)) * alpha))
        self.log_delta = nn.Parameter(log(ones((1, 1, num_bands, 1)) * delta))
        self.log_r = nn.Parameter(log(ones((1, 1, num_bands, 1)) * r))
        self.register_buffer('eps', tensor(eps))  # registered so it follows .to(device)
        self.init_smoother_from_data = init_smoother_from_data

    def forward(self, input):  # expected input (batch, freqs, time)
        batchsize = input.shape[0]
        input = input.unsqueeze(1)  # (batch, 1, freqs, time)
        init = input[:, :, :, 0]  # initialize the smoother with the first frame
        if not self.init_smoother_from_data:
            init = zeros_like(init)  # initialize with zeros instead

        # exponential moving average smoother over time, computed frame by frame:
        #   M[0] = init,   M[t] = (1 - s) * M[t-1] + s * E[t]
        s = exp(self.log_s)
        filtered = [init]
        for iframe in range(1, input.shape[-1]):
            filtered.append((1 - s) * filtered[iframe - 1] + s * input[:, :, :, iframe])
        filtered = stack(filtered).permute(1, 2, 3, 0)  # back to (batch, 1, freqs, time)

        # stable reformulation due to Vincent Lostanlen; original formula was:
        #   return (input / (self.eps + filtered)**alpha + delta)**r - delta**r
        alpha, delta, r = exp(self.log_alpha), exp(self.log_delta), exp(self.log_r)
        # (self.eps + filtered)**(-alpha), computed in the log domain for stability
        filtered = exp(-alpha * (log(self.eps) + log1p(filtered / self.eps)))
        return ((input * filtered + delta)**r - delta**r).reshape(batchsize, self.num_bands, -1)
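

# Minimal usage sketch (illustration only, not part of the module above): the
# input shape (batch, num_bands, time), the random "mel spectrogram" and the
# tolerance below are assumptions made for this example. It runs the layer once
# and compares the stable reformulation against the naive PCEN formula quoted in
# the comment inside forward().
if __name__ == "__main__":
    import torch

    torch.manual_seed(0)
    num_bands, num_frames = 64, 100
    layer = PCENLayer(num_bands)
    mel = torch.rand(2, num_bands, num_frames)  # fake non-negative mel magnitudes
    out = layer(mel)
    print(out.shape)  # torch.Size([2, 64, 100])

    # recompute PCEN with the naive formula and compare
    with torch.no_grad():
        s = torch.exp(layer.log_s)
        alpha = torch.exp(layer.log_alpha)
        delta = torch.exp(layer.log_delta)
        r = torch.exp(layer.log_r)
        x = mel.unsqueeze(1)
        m = [x[:, :, :, 0]]
        for t in range(1, x.shape[-1]):
            m.append((1 - s) * m[t - 1] + s * x[:, :, :, t])
        m = torch.stack(m).permute(1, 2, 3, 0)
        naive = (x / (layer.eps + m) ** alpha + delta) ** r - delta ** r
        naive = naive.reshape(2, num_bands, -1)
        print(torch.allclose(out, naive, atol=1e-4))  # should print True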
