# -*- coding: utf-8 -*- 
#Nicolas Enfon - 01/05/14 - LSIS DYNI
import numpy as np
from matplotlib.pyplot import *
import cPickle, gzip
from time import clock

def _square_side(n_pix):
    '''Return the side length of a square image made of n_pix pixels.

    Raises ValueError when n_pix is not a perfect square.
    '''
    from math import sqrt
    side = int(round(sqrt(n_pix)))  # int, so it can be used for indexing/slicing
    if side * side != n_pix:
        raise ValueError('images must be square, got %d pixels per image' % n_pix)
    return side


def _noise_pixels(noisy_x, rate, white):
    '''Replace each pixel independently with probability rate (in place on noisy_x).'''
    mask = np.random.random(noisy_x.shape) < rate
    if white:
        noisy_x[mask] = 0.  # 0 is what the original author called a "white" pixel - TOCHECK
    else:
        noisy_x[mask] = np.random.random(np.count_nonzero(mask))  # uniform in [0, 1)
    return noisy_x


def _noise_patches(noisy_x, rate, white, patch):
    '''Overwrite square patch x patch blocks of pixels in every image.

    Each pixel becomes the bottom-right corner ("anchor") of a patch with
    probability rate / patch**2, so that roughly `rate` of the area is noised.
    Patches are clipped at the top/left image borders; later patches
    overwrite earlier ones.
    '''
    side = _square_side(noisy_x.shape[1])
    images = noisy_x.reshape(len(noisy_x), side, side)
    anchor_rate = float(rate) / patch ** 2  # float() avoids integer division on Python 2
    for image in images:
        # np.nonzero yields anchors in row-major order, like a pixel-by-pixel scan.
        rows, cols = np.nonzero(np.random.random((side, side)) < anchor_rate)
        for r, c in zip(rows, cols):
            block = image[max(0, r - patch + 1):r + 1, max(0, c - patch + 1):c + 1]
            if white:
                block[...] = 0.
            else:
                block[...] = np.random.random(block.shape)
    # Reshape back explicitly so the result is correct whether reshape gave a view or a copy.
    return images.reshape(noisy_x.shape)


def _reduce_pixels(noisy_x, reduce):
    '''Drop random pixels so each side x side image keeps (side - reduce)**2 pixels.

    The surviving pixels keep their original relative order. Each image gets
    its own random selection.
    '''
    side = _square_side(noisy_x.shape[1])
    if reduce > side:
        raise ValueError('reduce (%d) cannot exceed the image side (%d)' % (reduce, side))
    keep = (side - reduce) ** 2
    reduced = np.empty((len(noisy_x), keep), dtype=np.float32)
    for i, row in enumerate(noisy_x):
        # A uniform subset without replacement, sorted to preserve pixel order.
        kept = np.sort(np.random.choice(row.size, keep, replace=False))
        reduced[i] = row[kept]
    return reduced


def addnoise(set_x, rate, white, patch, reduce):
    '''Adds uniformly random noise or white noise, at a given rate, in the picture (mnist) dataset.

    Parameters
    ----------
    set_x : 2-D array-like
        One flattened image per row. Patch and reduce modes require square
        images (e.g. 784 = 28 * 28 pixels). set_x itself is never modified.
    rate : float
        Without patches, the probability that each pixel is noised. With
        patches, each pixel anchors a patch with probability rate / patch**2.
    white : str or bool
        'True' (as passed from the command line) or True sets noised pixels
        to 0.; any other value draws uniform random values in [0, 1).
    patch : int
        0 noises individual pixels. > 0 writes patch x patch squares whose
        bottom-right corner is the anchor pixel, clipped at the image border.
    reduce : int
        0 adds noise. > 0 adds NO noise and instead removes random pixels so
        each side x side image keeps (side - reduce)**2 pixels.

    Returns
    -------
    numpy.ndarray of float32, shape (len(set_x), n_pixels_out).

    Raises
    ------
    ValueError
        If patch or reduce is negative, set_x is not 2-D, a square image is
        required but not given, or reduce exceeds the image side.
    '''
    if patch < 0 or reduce < 0:
        raise ValueError('patch and reduce must be >= 0, got patch=%r, reduce=%r' % (patch, reduce))
    white = white in (True, 'True')
    noisy_x = np.array(set_x, dtype=np.float32)  # always a copy, so set_x stays intact
    if noisy_x.ndim != 2:
        raise ValueError('set_x must be 2-D (one flattened image per row)')
    if reduce > 0:
        return _reduce_pixels(noisy_x, reduce)
    if patch == 0:
        return _noise_pixels(noisy_x, rate, white)
    return _noise_patches(noisy_x, rate, white, patch)

def main(rate, white, first, second, third, patch, verbose, reduce,
         datasetname='mnist.pkl.gz'):
    '''Main function; adds noise to the training and/or validation and/or test sets.

    Loads ``datasetname``, a gzip-compressed pickle unpacked below as
    ``((train_x, train_y), (valid_x, valid_y), (test_x, test_y))``. Each
    selected ``*_x`` set is passed through ``addnoise``. The result, with
    labels unchanged, is pickled to ``<base>_noisy.pkl.gz`` in the current
    directory.

    rate, white, patch, reduce: forwarded unchanged to ``addnoise``.
    first, second, third: 'True' (or True) to noise the training, validation
        and test set respectively; any other value leaves that set as is.
    verbose: print progress messages when truthy.
    datasetname: input file name. A trailing '.pkl.gz' is replaced by
        '_noisy.pkl.gz' to build the output name.
    '''
    suffix = '.pkl.gz'
    if datasetname.endswith(suffix):
        base = datasetname[:-len(suffix)]
    else:
        base = datasetname
    noisyname = base + '_noisy.pkl.gz'
    tic = clock()  # to measure execution speed

    # Load the clean dataset; 'with' makes sure the input handle is closed.
    # NOTE(review): unpickling runs arbitrary code - only use trusted dataset files.
    with gzip.open(datasetname, 'rb') as source:
        sets = cPickle.load(source)
    training_set_x, training_set_y = sets[0]
    valid_set_x, valid_set_y = sets[1]
    test_set_x, test_set_y = sets[2]

    # Noise the selected sets; unselected ones are written back unchanged.
    noisy_xs = []
    for flag, label, set_x in ((first, 'training', training_set_x),
                               (second, 'validation', valid_set_x),
                               (third, 'test', test_set_x)):
        if flag in (True, 'True'):
            if verbose:
                print('Adding noise to the %s set...' % label)
            set_x = addnoise(set_x, rate, white, patch, reduce)
        noisy_xs.append(set_x)

    # Write and save the noisy dataset. The output is opened only now, so a
    # failure while noising no longer leaves a truncated file behind.
    if verbose:
        print('Saving the noisy dataset...')
    # careful here: gzip couldn't open lists saved in cPickle, hence tuples
    noisyset = ((noisy_xs[0], training_set_y),
                (noisy_xs[1], valid_set_y),
                (noisy_xs[2], test_set_y))
    with gzip.open(noisyname, 'wb') as noisy:
        # Binary protocol: far smaller and faster than the default text protocol 0.
        cPickle.dump(noisyset, noisy, cPickle.HIGHEST_PROTOCOL)

    if verbose:
        print('----------------------------')
        print('Noise added succesfully in ' + str(clock() - tic) + ' seconds')
        print('The noisy dataset is saved under the name ' + noisyname)
        print('----------------------------')

if __name__ == '__main__':
    # To get all the options, run python <this_file.py> --help
    from optparse import OptionParser
    parser = OptionParser()
    parser.add_option('-v', '--verbose', action='store_true', dest='verbose', default=True,
                      help='Print progress messages (default)')
    parser.add_option('-q', '--quiet', action='store_false', dest='verbose',
                      help='Do not print progress messages')
    parser.add_option('-r', '--rate', type=float, help='Noise rate in the image', default=0.3)
    parser.add_option('-w', '--white', help='If True, adds white patches on the image. Otherwise, random color patches', default='True')
    parser.add_option('-1', '--first', help='Adds noise to the 1st dataset (train)', default='True')
    parser.add_option('-2', '--second', help='Adds noise to the 2nd dataset (validation)', default='True')
    parser.add_option('-3', '--third', help='Adds noise to the 3rd dataset (test)', default='True')
    parser.add_option('-p', '--patch', type=int,
                      help='Side of the square noise patches; 0 noises individual pixels instead',
                      default=4)
    parser.add_option('-R', '--reduce', type=int,
                      help='Reduces the image dimension to (N-n)*(N-n) by removing random pixels; '
                           'any value > 0 disables noise, use 0 to add noise',
                      default=2)
    (options, args) = parser.parse_args()
    main(options.rate, options.white, options.first, options.second, options.third,
         options.patch, options.verbose, options.reduce)
