from torch import nn

# NOTE(review): pDropout is not used by the networks below; every Dropout there
# hard-codes p=0.5. Kept because code outside this file may read it.
pDropout = 0.25


def _conv_unit(conv, norm, pool, c_in, c_out, kernel, pool_kernel=None):
    """Return the modules of one convolution stage, in order.

    The stage is a bias-free ``conv(c_in, c_out, kernel)`` (a batch norm follows,
    which makes a conv bias redundant), ``norm(c_out)``, an optional
    ``pool(pool_kernel)`` when ``pool_kernel`` is given, and a LeakyReLU with
    negative slope 0.01.
    """
    stage = [conv(c_in, c_out, kernel, bias=False), norm(c_out)]
    if pool_kernel is not None:
        stage.append(pool(pool_kernel))
    stage.append(nn.LeakyReLU(0.01))
    return stage


def _build_head(conv, norm, pool, global_pool, trunk, neck, classifier, global_size):
    """Assemble one of the networks stored in ``get`` as an ``nn.Sequential``.

    Args:
        conv, norm, pool, global_pool: dimension-specific layer classes
            (e.g. ``nn.Conv1d``, ``nn.BatchNorm1d``, ``nn.MaxPool1d``,
            ``nn.AdaptiveMaxPool1d``).
        trunk: stage specs ``(c_in, c_out, kernel[, pool_kernel])`` stacked
            back to back.
        neck: stage specs of the same form, each preceded by ``Dropout(p=0.5)``.
        classifier: ``(c_in, c_out)`` of the final bias-free 1x1 conv, which
            is preceded by one more ``Dropout(p=0.5)``.
        global_size: ``output_size`` passed to ``global_pool``.

    Modules are created strictly in sequence order, so the weight initialisation
    draws from the RNG in layer order.
    """
    layers = []
    for spec in trunk:
        layers.extend(_conv_unit(conv, norm, pool, *spec))
    for spec in neck:
        layers.append(nn.Dropout(p=.5))
        layers.extend(_conv_unit(conv, norm, pool, *spec))
    c_in, c_out = classifier
    layers.append(nn.Dropout(p=.5))
    # No batch norm follows this conv, yet it is bias-free like all the others.
    layers.append(conv(c_in, c_out, 1, bias=False))
    layers.append(global_pool(output_size=global_size))
    return nn.Sequential(*layers)


# Each network is instantiated once, at import time: every lookup of get[key]
# returns that same module object (and therefore the same weights).
get = {
    # 1-D network. The first Conv1d fixes the input channels at 64, so the input
    # is presumably (batch, 64, length); length must be >= 103 for the kernel-9
    # conv to fit. The output is (batch, 1, 1).
    'main1d': _build_head(
        nn.Conv1d, nn.BatchNorm1d, nn.MaxPool1d, nn.AdaptiveMaxPool1d,
        trunk=[
            (64, 128, 3),
            (128, 128, 3, 3),
            (128, 128, 3),
            (128, 128, 3),
            (128, 128, 3, 3),
        ],
        neck=[
            (128, 256, 9),  # for 80 bands
            (256, 64, 1),
        ],
        classifier=(64, 1),
        global_size=1,
    ),
    # 2-D network on single-channel input, presumably (batch, 1, H, W); needs
    # H >= 64 and W >= 103 for every kernel to fit. The output is (batch, 7, 1, 1).
    'main': _build_head(
        nn.Conv2d, nn.BatchNorm2d, nn.MaxPool2d, nn.AdaptiveMaxPool2d,
        trunk=[
            (1, 32, 3),
            (32, 32, 3, 3),
            (32, 32, 3),
            (32, 32, 3),
            (32, 64, (16, 3), (1, 3)),
        ],
        neck=[
            (64, 256, (1, 9)),  # for 80 bands: kernel width 9 instead of 3
            (256, 64, 1),
        ],
        # Output channels = number of labels (the author's note also lists 10, 11).
        classifier=(64, 7),
        global_size=(1, 1),
    ),
}
