#from wavelets_pytorch.transform import WaveletTransformTorch
from torch import nn
import torch
import utils as u
from binarized_modules import BinarizeConv1d
import numpy as np
from filterbank import STFT, MelFilter, Log1p
from PCEN_pytorch import PCENLayer
import sys
sys.path.append('/nfs/NAS5/best/cacha_detec/sincnet')
import sincnet

# Dropout probability shared by the architectures below.  Lambda entries read
# it each time they are called; entries built directly in the dict read it once,
# when this module is imported.
pDropout = .25


def _pool():
    """Return the AdaptiveMaxPool1d(1) layer that ends every non-empty stack."""
    return nn.AdaptiveMaxPool1d(output_size=1)


def _log_mel(rate, low, high):
    """Return the [STFT, MelFilter, Log1p] front end as a list of layers.

    ``rate``, ``low`` and ``high`` are forwarded to ``MelFilter`` as its first,
    fourth and fifth positional arguments (presumably sample rate and the
    lower/upper band edges in Hz -- confirm against filterbank.MelFilter).
    The remaining arguments are the constants every entry used.
    """
    return [
        STFT(512, 256),
        MelFilter(rate, 512, 64, low, high),
        Log1p(a=np.log10(1e7), trainable=True),
    ]


def _dw_body(nfeat, kernel):
    """Return the depthwise-separable stack with BatchNorm, LeakyReLU, Dropout1d.

    Three stride-2 ``u.depthwise_separable_conv1d`` layers (64 -> nfeat,
    nfeat -> nfeat, nfeat -> 1), the first two each followed by
    BatchNorm1d(nfeat), LeakyReLU and u.Dropout1d(pDropout), then the pool.
    """
    layers = []
    for cin in (64, nfeat):
        layers += [
            u.depthwise_separable_conv1d(cin, nfeat, kernel, stride=2),
            nn.BatchNorm1d(nfeat),
            nn.LeakyReLU(),
            u.Dropout1d(pDropout),
        ]
    layers += [u.depthwise_separable_conv1d(nfeat, 1, kernel, stride=2), _pool()]
    return layers


def _dw_k11_body(nfeat, batch_norm):
    """Return the kernel-11 depthwise-separable stack with nn.Dropout.

    Same three-convolution layout as ``_dw_body`` but with a fixed kernel of
    11, no activation, ``nn.Dropout`` instead of ``u.Dropout1d``, and a
    BatchNorm1d(nfeat) before each dropout only when ``batch_norm`` is true.
    """
    layers = []
    for cin in (64, nfeat):
        layers.append(u.depthwise_separable_conv1d(cin, nfeat, 11, stride=2))
        if batch_norm:
            layers.append(nn.BatchNorm1d(nfeat))
        layers.append(nn.Dropout(p=pDropout))
    layers += [u.depthwise_separable_conv1d(nfeat, 1, 11, stride=2), _pool()]
    return layers


def _conv_body(in_ch, width, ksize, act, dropout=True):
    """Return the plain Conv1d stack: (Conv, BN, act[, Dropout]) x2, Conv, pool.

    All three convolutions use ``kernel_size=ksize`` and ``stride=2`` and map
    in_ch -> width -> width -> 1.  ``act`` is an activation class called with
    no arguments (nn.ReLU or nn.LeakyReLU).  ``dropout=False`` leaves out the
    two nn.Dropout(p=pDropout) layers.
    """
    layers = []
    for cin in (in_ch, width):
        layers += [
            nn.Conv1d(cin, width, kernel_size=ksize, stride=2),
            nn.BatchNorm1d(width),
            act(),
        ]
        if dropout:
            layers.append(nn.Dropout(p=pDropout))
    layers += [nn.Conv1d(width, 1, kernel_size=ksize, stride=2), _pool()]
    return layers


def _two_conv_head(in_ch, width, ksize):
    """Return Conv1d(in_ch, width), BN, LeakyReLU, Conv1d(width, 1), LeakyReLU, pool.

    Both convolutions use ``kernel_size=ksize`` and the default stride.
    """
    return [
        nn.Conv1d(in_ch, width, kernel_size=ksize),
        nn.BatchNorm1d(width),
        nn.LeakyReLU(),
        nn.Conv1d(width, 1, kernel_size=ksize),
        nn.LeakyReLU(),
        _pool(),
    ]


# Registry of model architectures, keyed by name.
#
# * Lambda values are factories: call them (with ``nfeat`` and, for some,
#   ``kernel``) to get a freshly built nn.Sequential.
# * Non-lambda values are nn.Sequential instances built once, in the order
#   written here, when this module is imported; every lookup of such a key
#   returns that same module object.
#
# Every stack is a flat nn.Sequential, so the position of each layer is also
# its index in the module (and the prefix of its state_dict keys).  The helper
# functions above return plain lists that are unpacked in place, which keeps
# those positions exactly as written.  Do not reorder layers or entries without
# checking saved checkpoints.
get = {
    # Placeholder: the wavelet front end is commented out, so this currently
    # builds an empty nn.Sequential.
    'cwt_depthwise_64kHz': lambda nfeat, kernel: nn.Sequential(
        # WaveletTransformTorch(1/64000, 0.125)
    ),

    'stft_depthwise_64kHz_ksize_specBN': lambda nfeat, kernel: nn.Sequential(
        *_log_mel(64000, 2000, 25000),
        nn.BatchNorm1d(64),
        *_dw_body(nfeat, kernel)),

    # Works on the raw input: u.Reshape(1, -1) followed by five stride-2
    # Conv1d layers with kernel sizes 5, 7, 11, 15, 15.
    'raw': lambda nfeat: nn.Sequential(
        u.Reshape(1, -1),
        nn.Conv1d(1, nfeat // 2, 5, stride=2),
        nn.BatchNorm1d(nfeat // 2),
        u.Dropout1d(pDropout),
        nn.Conv1d(nfeat // 2, nfeat, 7, stride=2),
        nn.BatchNorm1d(nfeat),
        u.Dropout1d(pDropout),
        nn.Conv1d(nfeat, nfeat, 11, stride=2),
        nn.BatchNorm1d(nfeat),
        u.Dropout1d(pDropout),
        nn.Conv1d(nfeat, nfeat, 15, stride=2),
        nn.BatchNorm1d(nfeat),
        u.Dropout1d(pDropout),
        nn.Conv1d(nfeat, 1, 15, stride=2),
        _pool()),

    # Note the parameter names: this factory takes (f, k), not (nfeat, kernel).
    'quat_stft_depthwise': lambda f, k: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        u.Quat_depthwise_separable_conv1d(64, f, k, stride=2, quat=True),
        nn.Dropout(p=pDropout),
        u.Quat_depthwise_separable_conv1d(f, f, k, stride=2, quat=True),
        nn.Dropout(p=pDropout),
        u.Quat_depthwise_separable_conv1d(f, 1, k, stride=2, quat=True),
        _pool()),

    # NOTE(review): as written this is identical to 'stft_depthwise_32kHz'
    # (same STFT(512, 256) arguments); the "bighop" in the key is not
    # reflected in the code -- confirm whether that is intended.
    'stft_depthwise_32kHz_bighop': lambda nfeat: nn.Sequential(
        *_log_mel(32000, 2000, 16000),
        *_dw_k11_body(nfeat, batch_norm=False)),

    'stft_depthwise_32kHz': lambda nfeat: nn.Sequential(
        *_log_mel(32000, 2000, 16000),
        *_dw_k11_body(nfeat, batch_norm=False)),

    # Same as 'stft_depthwise_ksize' except the fourth MelFilter argument is
    # 500 instead of 2000.
    'stft_depthwise_ksize_FB': lambda nfeat, kernel: nn.Sequential(
        *_log_mel(50000, 500, 25000),
        *_dw_body(nfeat, kernel)),

    'stft_depthwise_ksize': lambda nfeat, kernel: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        *_dw_body(nfeat, kernel)),

    'stft_depthwise_ksize_CARIMAM': lambda nfeat, kernel: nn.Sequential(
        *_log_mel(128000, 2000, 128000 // 2),
        *_dw_body(nfeat, kernel)),

    # 'stft_depthwise_ksize' with a BatchNorm1d(64) right after the front end.
    'stft_depthwise_ksize_specBN': lambda nfeat, kernel: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        nn.BatchNorm1d(64),
        *_dw_body(nfeat, kernel)),

    # 'stft_depthwise_ksize' with u.SpecNorm() right after the front end.
    'stft_depthwise_ksize_specNorm': lambda nfeat, kernel: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        u.SpecNorm(),
        *_dw_body(nfeat, kernel)),

    'stft_depthwise': lambda nfeat: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        *_dw_k11_body(nfeat, batch_norm=True)),

    # NOTE(review): every convolution here is built with BN=True even though
    # the key says "noBN" -- confirm what the BN flag of
    # u.Quat_depthwise_separable_conv1d does and whether this is intended.
    'stft_depthwise_noBN_noDO': lambda nfeat: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        u.Quat_depthwise_separable_conv1d(64, nfeat, 11, stride=2, BN=True),
        u.Quat_depthwise_separable_conv1d(nfeat, nfeat, 11, stride=2, BN=True),
        u.Quat_depthwise_separable_conv1d(nfeat, 1, 11, stride=2, BN=True),
        _pool()),

    'sincnet_dw': lambda nfeat, kernel: nn.Sequential(
        sincnet.SincConv_fast(64, 512, sample_rate=50000),
        *_dw_body(nfeat, kernel)),

    # Built at import time (not a factory).
    'sincnet': nn.Sequential(
        sincnet.SincConv_fast(64, 512, sample_rate=50000),
        *_conv_body(64, 128, 7, nn.LeakyReLU)),

    # No MelFilter/Log1p: the first convolution takes 257 input channels
    # straight from STFT(512, 256).
    'stft_nomel': lambda nfeat: nn.Sequential(
        STFT(512, 256),
        *_conv_body(257, nfeat, 11, nn.ReLU)),

    # Built at import time (not a factory).
    'stft_pcen': nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        PCENLayer(64),
        *_conv_body(64, 512, 11, nn.LeakyReLU)),

    'stft': lambda nfeat: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        *_conv_body(64, nfeat, 11, nn.ReLU)),

    # Same layer positions as 'stft', with each Conv1d replaced by a quantized
    # ConvReLU1d and the BatchNorm/ReLU slots filled by nn.Identity.
    # The dropout rate now comes from pDropout like every other entry (it was
    # a hard-coded 0.25, the same value).
    'qat_stft': lambda nfeat: nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d(
            64, nfeat, kernel_size=11, stride=2),
        nn.Identity(),
        nn.Identity(),
        nn.Dropout(p=pDropout),
        nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d(
            nfeat, nfeat, kernel_size=11, stride=2),
        nn.Identity(),
        nn.Identity(),
        nn.Dropout(p=pDropout),
        nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d(
            nfeat, 1, kernel_size=11, stride=2),
        _pool()),

    # Built at import time (not a factory): five BinarizeConv1d layers, the
    # first four each followed by BatchNorm1d(2048) and an in-place Hardtanh.
    'bool_stft': nn.Sequential(
        *_log_mel(50000, 2000, 25000),
        BinarizeConv1d(64, 2048, kernel_size=11, stride=2),
        nn.BatchNorm1d(2048),
        nn.Hardtanh(inplace=True),
        BinarizeConv1d(2048, 2048, kernel_size=11, stride=2),
        nn.BatchNorm1d(2048),
        nn.Hardtanh(inplace=True),
        BinarizeConv1d(2048, 2048, kernel_size=11, stride=2),
        nn.BatchNorm1d(2048),
        nn.Hardtanh(inplace=True),
        BinarizeConv1d(2048, 2048, kernel_size=11, stride=2),
        nn.BatchNorm1d(2048),
        nn.Hardtanh(inplace=True),
        BinarizeConv1d(2048, 1, kernel_size=11, stride=2),
        _pool()),

    # ------------------------------------------------------------------
    # Entries keyed by version-like strings.  All of them are built at import
    # time.  '0.7' to '0.10' start with u.GammaSpec; '0.0' to '0.6' start
    # directly with a Conv1d taking 64 (or 32) input channels.
    # ------------------------------------------------------------------
    '0.10': nn.Sequential(
        u.GammaSpec(50000, 512, 256, 64, 2000),
        *_conv_body(64, 512, 11, nn.LeakyReLU)),

    '0.9': nn.Sequential(
        u.GammaSpec(50000, 512, 128, 64, 2000),
        *_two_conv_head(64, 512, 1)),

    # Unlike its neighbours, each convolution here has its own kernel size
    # and stride (17/5, 11/2, 5/1).
    '0.8': nn.Sequential(
        u.GammaSpec(50000, 512, 256, 64, 2000),
        nn.Conv1d(64, 512, kernel_size=17, stride=5),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(),
        nn.Conv1d(512, 512, kernel_size=11, stride=2),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(),
        nn.Conv1d(512, 1, kernel_size=5, stride=1),
        _pool()),

    # '0.10' without the dropout layers.
    '0.7': nn.Sequential(
        u.GammaSpec(50000, 512, 256, 64, 2000),
        *_conv_body(64, 512, 11, nn.LeakyReLU, dropout=False)),

    '0.6': nn.Sequential(
        *_conv_body(64, 512, 17, nn.LeakyReLU, dropout=False)),

    '0.5': nn.Sequential(*_two_conv_head(64, 512, 11)),

    '0.4': nn.Sequential(*_two_conv_head(64, 256, 1)),

    '0.3': nn.Sequential(*_two_conv_head(64, 512, 1)),

    '0.2': nn.Sequential(*_two_conv_head(32, 1024, 1)),

    # Four kernel-size-1 convolutions: 32 -> 128 -> 1024 -> 1024 -> 1.
    '0.0': nn.Sequential(
        nn.Conv1d(32, 128, kernel_size=1),
        nn.BatchNorm1d(128),
        nn.LeakyReLU(),
        nn.Dropout(p=pDropout),
        nn.Conv1d(128, 1024, kernel_size=1),
        nn.BatchNorm1d(1024),
        nn.LeakyReLU(),
        nn.Dropout(p=pDropout),
        nn.Conv1d(1024, 1024, kernel_size=1),
        nn.BatchNorm1d(1024),
        nn.LeakyReLU(),
        nn.Dropout(p=pDropout),
        nn.Conv1d(1024, 1, kernel_size=1),
        _pool()),
}
