"""This tutorial introduces the LeNet5 neural network architecture
using Theano.  LeNet5 is a convolutional neural network, good for
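classifying images. This tutorial shows how to build the architecture,
and comes with all the hyper-parameters you need to reproduce the
paper's MNIST results. This variant also inserts a lateral-connections
layer (LatConnexLayer) after the first convolutional pooling layer. That
layer convolves the feature maps with a fixed weighting kernel (a
difference of Gaussians, or a profile read from a file).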


This implementation simplifies the model in the following ways:

 - LeNetConvPool doesn't implement location-specific gain and bias parameters
 - LeNetConvPool doesn't implement pooling by average, it implements pooling
   by max.
 - Digit classification is implemented with a logistic regression rather than
   an RBF network
 - LeNet5 did not use fully-connected convolutions at the second layer

References:
 - Y. LeCun, L. Bottou, Y. Bengio and P. Haffner:
   Gradient-Based Learning Applied to Document
   Recognition, Proceedings of the IEEE, 86(11):2278-2324, November 1998.
   http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf

"""
import cPickle
import gzip
import os
import sys
import time
import getopt
import numpy
import matplotlib
import matplotlib.pyplot

import theano
import theano.tensor as T
from theano.tensor.signal import downsample
from theano.tensor.nnet import conv

from logistic_sgd import LogisticRegression, load_data
from mlp import HiddenLayer


class LeNetConvPoolLayer(object):
    """Convolution + max-pooling layer of a LeNet-style network.

    The input is convolved with a bank of learned filters, max-pooled over
    ``poolsize`` windows (borders ignored), shifted by one learned bias per
    output feature map and passed through tanh. The trainable shared
    variables are listed in ``self.params``.
    """

    def __init__(self, rng, input, filter_shape, image_shape, poolsize=(2, 2)):
        """
        Allocate the layer's shared weights and build its symbolic output.

        :type rng: numpy.random.RandomState
        :param rng: random number generator used to initialize the weights

        :type input: theano.tensor.dtensor4
        :param input: symbolic image tensor, of shape image_shape

        :type filter_shape: tuple or list of length 4
        :param filter_shape: (number of filters, num input feature maps,
                              filter height, filter width)

        :type image_shape: tuple or list of length 4
        :param image_shape: (batch size, num input feature maps,
                             image height, image width)

        :type poolsize: tuple or list of length 2
        :param poolsize: the downsampling (pooling) factor (#rows, #cols)
        """
        # The number of input feature maps must agree between both shapes.
        assert image_shape[1] == filter_shape[1]
        self.input = input

        n_filters = filter_shape[0]
        filter_area = numpy.prod(filter_shape[2:])

        # Each hidden unit receives "input maps * filter height * filter
        # width" inputs; each lower-layer unit receives gradient from
        # "output maps * filter height * filter width / pooling area" units.
        fan_in = numpy.prod(filter_shape[1:])
        fan_out = n_filters * filter_area / numpy.prod(poolsize)

        # Uniform initialization in [-bound, bound], scaled by fan-in + fan-out.
        bound = numpy.sqrt(6. / (fan_in + fan_out))
        initial_weights = numpy.asarray(
            rng.uniform(low=-bound, high=bound, size=filter_shape),
            dtype=theano.config.floatX)
        self.W = theano.shared(initial_weights, borrow=True)

        # One bias per output feature map, starting at zero.
        self.b = theano.shared(
            value=numpy.zeros((n_filters,), dtype=theano.config.floatX),
            borrow=True)

        convolved = conv.conv2d(input=input, filters=self.W,
                                filter_shape=filter_shape,
                                image_shape=image_shape)
        pooled = downsample.max_pool_2d(input=convolved, ds=poolsize,
                                        ignore_border=True)

        # View the bias vector as (1, n_filters, 1, 1) so it broadcasts over
        # the mini-batch and both spatial axes.
        broadcast_bias = self.b.dimshuffle('x', 0, 'x', 'x')
        self.output = T.tanh(pooled + broadcast_bias)

        self.params = [self.W, self.b]



class LatConnexLayer(object):
    """Lateral connections layer of a convolutional network.

    The input tensor is convolved with a fixed square 2D weighting kernel.
    That kernel is either a difference of Gaussians computed from S and W,
    or a kernel built from a profile file. Each output feature map is then
    rescaled to the L2 norm of the matching input feature map. The output
    has the same shape as the input (image_shape).

    The kernel is held in the shared variable ``self.wlat``. This layer
    defines no ``params`` list.
    """

    def DoG(self, image_shape, S, W):
        """
        Compute a square 2D difference-of-Gaussians weighting kernel.

        The kernel side is ``image_shape[3]``. Offsets from the centre are
        measured in units of the kernel side. For even sides, the two middle
        rows/columns both sit at offset 0, which keeps the kernel symmetric.
        For squared offset ``r2`` the raw value is::

            S * exp(-r2 / (2 W^2)) / (W sqrt(2 pi))
              - (1 - S) * exp(-r2 / (2 (1 - W)^2)) / ((1 - W) sqrt(2 pi))

        The raw kernel is then rescaled linearly so that its range is clipped
        to [-1, 1]; a kernel already inside [-1, 1] is returned unchanged. A
        constant kernel is scaled to +/-1, or left as is if it is all zeros.

        S == 0 and W == 0 together select an identity filter (a centred 2D
        delta). Otherwise, W == 0 or W == 1 makes a denominator zero.

        :type  image_shape: tuple or list of length 4
        :param image_shape: (batch size, number of input feature maps,
                             image height, image width)

        :type S:  float
        :param S: selective excitation (weight of the first Gaussian; the
                  second one is weighted by 1 - S)

        :type W:  float
        :param W: standard deviation of the first Gaussian (the second one
                  uses 1 - W)

        :rtype: numpy.ndarray of shape (size, size) and dtype floatX
        """
        size = image_shape[3]

        # For S and W both 0, use a filter with no effect: a centred 2D delta.
        if S == 0 and W == 0:
            dog = numpy.zeros(shape=(size, size),
                              dtype=theano.config.floatX, order='F')
            dog[size // 2][size // 2] = 1.0
            return dog

        # Signed offset of each row/column index from the kernel centre,
        # as a fraction of the kernel side.
        radius = (size - 1) // 2
        even = size % 2 == 0
        rel = []
        for i in range(size):
            offset = i - radius
            if offset > 0 and even:
                offset -= 1
            rel.append(offset / float(size))

        dog = numpy.zeros(shape=(size, size),
                          dtype=theano.config.floatX, order='F')
        for x in range(size):
            for y in range(size):
                dist2 = rel[x] ** 2 + rel[y] ** 2
                dog[x][y] = (S * numpy.exp(-dist2 / (2.0 * W ** 2))
                             / (W * numpy.sqrt(2 * numpy.pi))
                             - (1 - S)
                             * numpy.exp(-dist2 / (2.0 * (1 - W) ** 2))
                             / ((1 - W) * numpy.sqrt(2 * numpy.pi)))
                # TODO check: possibly also normalize by the kernel size and
                # the number of feature maps, e.g.
                # dog[x][y] /= numpy.prod(image_shape[1:])

        dog_min = numpy.amin(dog)
        dog_max = numpy.amax(dog)

        if dog_min == dog_max:
            # Constant kernel: the linear rescaling below would divide by
            # dog_max - dog_min == 0 and turn the whole kernel into NaN.
            if dog_max != 0:
                dog = dog / numpy.fabs(dog_max)
            return dog

        # Map [dog_min, dog_max] linearly onto
        # [max(dog_min, -1), min(dog_max, 1)]; this is the identity when
        # the kernel already lies inside [-1, 1].
        low = -1 if dog_min < -1 else dog_min
        high = 1 if dog_max > 1 else dog_max

        scale = (high - low) / (dog_max - dog_min)
        shift = high - scale * dog_max
        return dog * scale + shift

    def DoG3(self, image_shape, dogPath):
        """
        Build a square 2D weighting kernel from a 1D profile stored in a file.

        The file holds one float per line (blank lines are skipped) and must
        contain exactly ``image_shape[3]`` values. Only the upper half of the
        profile is used. With ``c = size // 2``, the kernel entry at offset
        (dx, dy) from the centre (0 <= dx, dy < c) is set, in each of the
        four quadrants, to ``values[c + int(sqrt((dx**2 + dy**2) // 2))]``.

        The quadrant layout covers the whole kernel only for even sizes. For
        odd sizes, the last row and column stay 0.

        :type  image_shape: tuple or list of length 4
        :param image_shape: (batch size, number of input feature maps,
                             image height, image width)

        :type  dogPath: str
        :param dogPath: path of the profile file

        :rtype: numpy.ndarray of shape (size, size) and dtype floatX
        """
        assert dogPath is not None

        # The with-block guarantees the file is closed.
        with open(dogPath, 'r') as profile_file:
            values = [float(line) for line in profile_file if line.strip()]

        size = image_shape[3]
        assert len(values) == size

        # Zero-initialized so that no entry is left as uninitialized memory.
        dog = numpy.zeros(shape=(size, size),
                          dtype=theano.config.floatX, order='F')
        c = len(values) // 2
        for dx in range(c):
            for dy in range(c):
                # Floor division gives the same result on Python 2 and 3.
                # It does not change int(sqrt(.)), because
                # floor(sqrt(floor(v))) == floor(sqrt(v)).
                value = values[c + int(numpy.sqrt((dx ** 2 + dy ** 2) // 2))]
                dog[c - dx - 1][c - dy - 1] = value
                dog[c - dx - 1][c + dy] = value
                dog[c + dx][c - dy - 1] = value
                dog[c + dx][c + dy] = value

        return dog

    def saveDoG(self, layerNumber, S, W):
        """
        Print the 2D weighting kernel ``self.dog`` (debug helper).

        Saving the kernel as an image is disabled (see the TODO below), so
        this only prints the kernel with 2-digit precision. ``layerNumber``,
        ``S`` and ``W`` are currently unused; they were meant for the image
        file name.

        :type  layerNumber: int
        :param layerNumber: layer number

        :type  S: float
        :param S: first parameter of the weighting kernel

        :type  W: float
        :param W: second parameter of the weighting kernel
        """
        # TODO - gives runtime error on mejean
        # matplotlib.pyplot.imshow(self.dog)
        # matplotlib.pyplot.clim(-1,1) # reference range of plotted values
        # matplotlib.pyplot.colorbar()
        # imageName = "../data/dog%d_S=%.1f_W=%.1f.png" % (layerNumber, S, W)
        # matplotlib.pyplot.savefig(imageName)

        print("DEBUG DoG")
        # Format locally rather than via numpy.set_printoptions, so the
        # process-wide numpy print precision is not changed as a side effect.
        print(numpy.array2string(self.dog, precision=2))

    def __init__(self, input, S, W, image_shape, dogPath=None):
        """
        Convolve ``input`` with the weighting kernel and normalize the result.

        :type  input: theano.tensor.dtensor4
        :param input: symbolic image tensor, of shape image_shape

        :type S:  float
        :param S: selective excitation (see DoG); ignored when dogPath is set

        :type W:  float
        :param W: Gaussian width (see DoG); ignored when dogPath is set

        :type  image_shape: tuple or list of length 4
        :param image_shape: (batch size, number of input feature maps,
                             image height, image width); height must equal
                             width

        :type  dogPath: str or None
        :param dogPath: optional profile file for the kernel (see DoG3);
                        when None, the kernel is computed by DoG from S and W
        """
        # Only square images are supported.
        assert image_shape[2] == image_shape[3]

        self.input = input

        # 2D weighting kernel, the same size as one feature map.
        if dogPath is None:
            self.dog = self.DoG(image_shape, S, W)
        else:
            self.dog = self.DoG3(image_shape, dogPath)

        # Expand the 2D kernel into a 4D filter bank of shape
        # (n_maps, n_maps, size, size) by tiling, because dimshuffle could
        # not broadcast the data along the two leading axes.
        # TODO: find a better solution.
        # NOTE(review): conv2d sums over input feature maps, and every
        # (output, input) pair here gets the same kernel. Every output map
        # is therefore the filtered SUM of all input maps; maps differ only
        # by the per-map rescaling below. If each map should be filtered on
        # its own, the bank should be diagonal -- confirm the intent.
        self.wlat = theano.shared(
            value=numpy.tile(self.dog,
                             (image_shape[1], image_shape[1], 1, 1)),
            borrow=True)
        filter_shape = self.wlat.get_value().shape

        # Full convolution (zero padding) so the kernel is applied across
        # the whole image. The result has shape
        # (image_shape[0], image_shape[1],
        #  image_shape[2] + filter_shape[2] - 1,
        #  image_shape[3] + filter_shape[3] - 1).
        conv_out = conv.conv2d(input=self.input,
                               filters=self.wlat,
                               filter_shape=filter_shape,
                               image_shape=image_shape,
                               border_mode='full')

        # Crop back to the original image_shape.
        begin2 = filter_shape[2] // 2
        begin3 = filter_shape[3] // 2
        end2 = begin2 + image_shape[2]
        end3 = begin3 + image_shape[3]
        self.aux = conv_out[:, :, begin2:end2, begin3:end3]

        # Rescale each output map to the L2 norm (over the two spatial axes)
        # of the matching input map.
        # NOTE(review): an all-zero filtered map would give 0/0 here --
        # confirm this cannot happen for the inputs used.
        self.inputNorm = self.input.norm(L=2, axis=(2, 3))
        self.auxNorm = self.aux.norm(L=2, axis=(2, 3))

        self.scale = self.inputNorm / self.auxNorm
        self.output = self.aux * self.scale.dimshuffle(0, 1, 'x', 'x')

        # Without normalization this would simply be:
        # self.output = self.aux


def evaluate_lenet5_lat_connex(learning_rate=0.1, n_epochs=200,
                               dataset='mnist.pkl.gz',
                               nkerns=(20, 50), batch_size=500,
                               S0=0.0, W0=0.0, S1=0.0, W1=0.0,
                               dogPath=None):
    """Train and evaluate LeNet5 with lateral connections on MNIST.

    The model has these layers, in order:

    - conv-pool layer 0
    - a LatConnexLayer on its output
    - conv-pool layer 1
    - a tanh hidden layer
    - a logistic regression classifier

    It is trained with minibatch SGD and patience-based early stopping.
    Validation and test errors are printed as training progresses.

    :type learning_rate: float
    :param learning_rate: learning rate used (factor for the stochastic
                          gradient)

    :type n_epochs: int
    :param n_epochs: maximal number of epochs to run the optimizer

    :type dataset: string
    :param dataset: path to the dataset used for training /testing (MNIST here)

    :type nkerns: sequence of 2 ints
    :param nkerns: number of kernels on each convolutional layer

    :type batch_size: int
    :param batch_size: minibatch size; the network is built for exactly this
                       many examples per batch

    :type S0: float
    :param S0: S for the weighting kernel of layer 0

    :type W0: float
    :param W0: W for the weighting kernel of layer 0

    :type S1: float
    :param S1: S for the weighting kernel of layer 1 (currently unused: the
               layer-1 lateral connections are disabled)

    :type W1: float
    :param W1: W for the weighting kernel of layer 1 (currently unused)

    :type dogPath: str or None
    :param dogPath: optional profile file for the layer-0 weighting kernel;
                    when None the kernel is computed from S0 and W0
    """
    rng = numpy.random.RandomState(23455)

    datasets = load_data(dataset)

    train_set_x, train_set_y = datasets[0]
    valid_set_x, valid_set_y = datasets[1]
    test_set_x, test_set_y = datasets[2]

    # Number of complete minibatches per split. A trailing partial batch is
    # never used, because the model is built for a fixed batch_size.
    n_train_batches = train_set_x.get_value(borrow=True).shape[0] // batch_size
    n_valid_batches = valid_set_x.get_value(borrow=True).shape[0] // batch_size
    n_test_batches = test_set_x.get_value(borrow=True).shape[0] // batch_size

    # allocate symbolic variables for the data
    index = T.lscalar()  # index to a [mini]batch
    x = T.matrix('x')    # the data is presented as rasterized images
    y = T.ivector('y')   # the labels are presented as 1D vector of [int] labels

    ######################
    # BUILD ACTUAL MODEL #
    ######################
    print('... building the model')

    # Reshape matrix of rasterized images of shape (batch_size, 28*28)
    # to a 4D tensor, compatible with our LeNetConvPoolLayer
    layer0_input = x.reshape((batch_size, 1, 28, 28))

    # First convolutional pooling layer:
    # filtering reduces the image size to (28-5+1, 28-5+1) = (24, 24),
    # maxpooling reduces this further to (24/2, 24/2) = (12, 12);
    # output shape is (batch_size, nkerns[0], 12, 12).
    layer0 = LeNetConvPoolLayer(rng, input=layer0_input,
                                image_shape=(batch_size, 1, 28, 28),
                                filter_shape=(nkerns[0], 1, 5, 5),
                                poolsize=(2, 2))

    # Lateral connections for the first layer: its output is convolved with
    # a lateral-connections weighting kernel; the shape does not change.
    layer0_lat_connex = LatConnexLayer(input=layer0.output, S=S0, W=W0,
                                       image_shape=(batch_size, nkerns[0],
                                                    12, 12),
                                       dogPath=dogPath)
    layer0_lat_connex.saveDoG(0, S0, W0)

    # Second convolutional pooling layer:
    # filtering reduces the image size to (12-5+1, 12-5+1) = (8, 8),
    # maxpooling reduces this further to (8/2, 8/2) = (4, 4);
    # output shape is (batch_size, nkerns[1], 4, 4).
    layer1 = LeNetConvPoolLayer(rng, input=layer0_lat_connex.output,
                                image_shape=(batch_size, nkerns[0], 12, 12),
                                filter_shape=(nkerns[1], nkerns[0], 5, 5),
                                poolsize=(2, 2))

    # Lateral connections for the second layer are disabled; enabling them
    # would use S1/W1:
    # layer1_lat_connex = LatConnexLayer(input=layer1.output, S=S1, W=W1,
    #                     image_shape=(batch_size, nkerns[1], 4, 4))
    # layer1_lat_connex.saveDoG(1, S1, W1)
    # layer2_input = layer1_lat_connex.output.flatten(2)

    # The fully-connected HiddenLayer operates on a 2D matrix of shape
    # (batch_size, nkerns[1] * 4 * 4) of rasterized feature maps.
    layer2_input = layer1.output.flatten(2)

    # construct a fully-connected tanh layer
    layer2 = HiddenLayer(rng, input=layer2_input, n_in=nkerns[1] * 4 * 4,
                         n_out=500, activation=T.tanh)

    # classify the values of the fully-connected layer
    layer3 = LogisticRegression(input=layer2.output, n_in=500, n_out=10)

    # the cost we minimize during training is the NLL of the model
    cost = layer3.negative_log_likelihood(y)

    # functions computing the mistakes made by the model on one minibatch
    test_model = theano.function(
        [index], layer3.errors(y),
        givens={
            x: test_set_x[index * batch_size: (index + 1) * batch_size],
            y: test_set_y[index * batch_size: (index + 1) * batch_size]})

    validate_model = theano.function(
        [index], layer3.errors(y),
        givens={
            x: valid_set_x[index * batch_size: (index + 1) * batch_size],
            y: valid_set_y[index * batch_size: (index + 1) * batch_size]})

    # All trainable parameters (the lateral-connection kernel is fixed).
    params = layer3.params + layer2.params + layer1.params + layer0.params
    grads = T.grad(cost, params)

    # Plain SGD update rule for every (param, grad) pair.
    updates = [(param_i, param_i - learning_rate * grad_i)
               for param_i, grad_i in zip(params, grads)]

    train_model = theano.function(
        [index], cost, updates=updates,
        givens={
            x: train_set_x[index * batch_size: (index + 1) * batch_size],
            y: train_set_y[index * batch_size: (index + 1) * batch_size]})

    ###############
    # TRAIN MODEL #
    ###############
    print('... training')
    # early-stopping parameters
    patience = 10000  # look at this many iterations regardless
    patience_increase = 2  # wait this much longer when a new best is found
    improvement_threshold = 0.995  # a relative improvement of this much is
                                   # considered significant
    # Go through this many minibatches before checking the network on the
    # validation set; in this case we check every epoch.
    validation_frequency = min(n_train_batches, patience // 2)

    best_validation_loss = numpy.inf
    best_iter = 0
    test_score = 0.
    start_time = time.clock()

    epoch = 0
    done_looping = False

    while (epoch < n_epochs) and (not done_looping):
        epoch = epoch + 1
        for minibatch_index in range(n_train_batches):

            iteration = (epoch - 1) * n_train_batches + minibatch_index

            if iteration % 100 == 0:
                print('training @ iter =  %d  at  %s'
                      % (iteration, time.strftime("%Y/%m/%d %H:%M:%S")))
            train_model(minibatch_index)

            if (iteration + 1) % validation_frequency == 0:

                # compute zero-one loss on validation set
                validation_losses = [validate_model(i)
                                     for i in range(n_valid_batches)]
                this_validation_loss = numpy.mean(validation_losses)
                print('epoch %i, minibatch %i/%i, validation error %f %%' %
                      (epoch, minibatch_index + 1, n_train_batches,
                       this_validation_loss * 100.))

                # if we got the best validation score until now
                if this_validation_loss < best_validation_loss:

                    # improve patience if loss improvement is good enough
                    if this_validation_loss < best_validation_loss * \
                       improvement_threshold:
                        patience = max(patience,
                                       iteration * patience_increase)

                    # save best validation score and iteration number
                    best_validation_loss = this_validation_loss
                    best_iter = iteration

                    # test it on the test set
                    test_losses = [test_model(i)
                                   for i in range(n_test_batches)]
                    test_score = numpy.mean(test_losses)
                    print(('     epoch %i, minibatch %i/%i, test error of best '
                           'model %f %%') %
                          (epoch, minibatch_index + 1, n_train_batches,
                           test_score * 100.))

            if patience <= iteration:
                done_looping = True
                break

    end_time = time.clock()
    print('Optimization complete.')
    print('Best validation score of %f %% obtained at iteration %i, '
          'with test performance %f %%' %
          (best_validation_loss * 100., best_iter + 1, test_score * 100.))
    sys.stderr.write('The code for file ' +
                     os.path.split(__file__)[1] +
                     ' ran for %.2fm\n' % ((end_time - start_time) / 60.))

if __name__ == '__main__':
    # Read command line arguments.
    S0 = 0.0  # weighting kernel parameter S for layer 0
    W0 = 0.0  # weighting kernel parameter W for layer 0
    S1 = 0.0  # weighting kernel parameter S for layer 1 (no CLI option)
    W1 = 0.0  # weighting kernel parameter W for layer 1 (no CLI option)
    dogPath = None
    # Default dataset (same as evaluate_lenet5_lat_connex), so the script
    # also runs when --dataset is not given.
    dataset = 'mnist.pkl.gz'

    usage = ("Usage:\n  %s [--S0=<S for layer 0>] [--W0=<W for layer 0>]"
             " [--dogPath=<kernel profile file>] [--dataset=<dataset file>]"
             % sys.argv[0])

    try:
        opts, args = getopt.getopt(sys.argv[1:], 'h',
                                   ['S0=', 'W0=', 'dogPath=', 'dataset='])
    except getopt.GetoptError as err:
        print(str(err))
        print(usage)
        sys.exit(2)

    for opt, arg in opts:
        if opt == '-h':
            print(usage)
            sys.exit()
        elif opt == "--S0":
            S0 = float(arg)
        elif opt == "--W0":
            W0 = float(arg)
        elif opt == "--dogPath":
            dogPath = arg
        elif opt == "--dataset":
            dataset = arg

    print("S0 =  %s" % S0)
    print("W0 =  %s" % W0)
    print("dogPath =  %s" % dogPath)
    print("dataset =  %s" % dataset)

    print("Executing script:\n   %s\n" % ' '.join(sys.argv))

    evaluate_lenet5_lat_connex(S0=S0, W0=W0, S1=S1, W1=W1, dogPath=dogPath,
                               dataset=dataset)


def experiment(state, channel):
    """Run evaluate_lenet5_lat_connex with settings read from ``state``.

    Only ``state.learning_rate`` and ``state.dataset`` are read; every other
    argument keeps its default. ``channel`` is unused. The (state, channel)
    signature presumably matches an external experiment runner (confirm).
    """
    evaluate_lenet5_lat_connex(learning_rate=state.learning_rate,
                               dataset=state.dataset)
