# -*- coding: utf-8 -*- 
#Nicolas Enfon - 16/07/14 - LSIS DYNI
#This script computes the .pkl datasets used to train a CNN for the LifeBIRD classification task
from numpy import *
from matplotlib.pyplot import *
from scipy.io.wavfile import read, write
from pylab import specgram, norm
import cPickle, gzip, os, glob, pylab, time
import pywt

# Absolute locations of the code, the source recordings, the scalogram
# folders and the scratch folder where the pickled batches are written.
curdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/CODE'
wavdir = '/NAS3/SABIOD/SITE/AMAZONE_BIRD_LIFECLEF/2014/LIFECLEF2014_BIRDAMAZON_XC_WAV_RN/'
datadir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/'
validdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/VALID/'
traindir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/TRAIN/'
tempdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/RESULTS_V1/TEMP'

# Label table: every cell is read as a string, one row per recording.
labels = loadtxt(datadir + 'ICML_plus_ID_v2.csv', dtype='str', delimiter=',')
# Recording identifiers: third column with its first 31 and last 4 characters
# removed (presumably a fixed path prefix and the file extension -- TODO confirm).
ID = [entry[31:-4] for entry in labels[:, 2]]

# Names of the files found in the training and validation folders.
scalotrain = array(os.listdir(traindir))
scalovalid = array(os.listdir(validdir))
# /!\ CAUTION: remember that these scalograms have the lower frequencies in the first rows /!\

def cut_scalo(scalo, length, second=False, rate=44100):
    """Cut a scalogram into consecutive, non-overlapping windows of equal width.

    Parameters
    ----------
    scalo : 2-D array
        Scalogram to cut; windows are taken along the second axis (columns).
    length : int or float
        Window width, in columns (bins) by default, or in seconds when
        ``second`` is true.
    second : bool
        If true, ``length`` is converted from seconds to bins with
        ``length * rate * 2 ** -level`` where
        ``level = int(log2(number of rows))``.
    rate : int
        Sampling rate used for the seconds-to-bins conversion.

    Returns
    -------
    list of 2-D arrays
        Views of ``scalo``, each ``length`` columns wide. Trailing columns
        that do not fill a whole window are dropped. A scalogram narrower
        than one window is zero-padded on the right (keeping its dtype) and
        returned as a single window.

    Raises
    ------
    ValueError
        If the window width is smaller than one bin.
    """
    scalo = asarray(scalo)
    if second:  # length is given in seconds: convert it to bins
        level = int(log2(len(scalo)))
        length = length * rate * (2 ** (-level))
    # A fractional width cannot be used for slicing, so truncate to whole bins.
    length = int(length)
    if length < 1:
        raise ValueError('window length must be at least one bin, got %s' % length)

    n_rows, n_cols = scalo.shape[0], scalo.shape[1]
    if n_cols < length:
        # Too short for even one window: pad with zeros up to exactly one.
        # The padding uses the input dtype so the result is not upcast.
        padding = zeros((n_rows, length - n_cols), dtype=scalo.dtype)
        scalo = concatenate((scalo, padding), 1)
        n_cols = length

    n_windows = n_cols // length
    return [scalo[:, k * length:(k + 1) * length] for k in range(n_windows)]

def crop_reverse_scalo(scalo, top=6, bottom=2):
    """Flip a scalogram upside down, then drop its first and last rows.

    Parameters
    ----------
    scalo : 2-D array
        Scalogram; its row order is reversed before cropping.
    top : int
        Number of rows removed from the top of the flipped scalogram.
    bottom : int
        Number of rows removed from the bottom of the flipped scalogram.
        ``0`` keeps every bottom row.

    Returns
    -------
    2-D array
        The flipped and cropped scalogram (empty along the first axis when
        ``top + bottom`` covers all the rows).
    """
    flipped = scalo[::-1]
    # Use an explicit stop index: a slice ending at ``-bottom`` would be
    # empty for bottom == 0 instead of keeping all the bottom rows.
    stop = len(flipped) - bottom
    if stop < 0:
        stop = 0
    return flipped[top:stop, :]

def load_scalos(scalos, start=0, stop=0):
    """Load the scalograms with indices in [start, stop[ and their labels.

    Each file is read with ``loadtxt``, flipped and cropped with
    ``crop_reverse_scalo`` and flattened into one float32 row. The label is
    looked up in the module-level ``labels`` table through ``ID``, using the
    part of the file name that precedes the first underscore.

    Parameters
    ----------
    scalos : sequence of str
        Paths (or names relative to the current directory) of the files.
    start : int
        Index of the first file to load.
    stop : int
        Index one past the last file to load; ``0`` means "up to the end".

    Returns
    -------
    (signal, labs) : tuple of arrays
        ``signal`` has one flattened scalogram per row; ``labs`` is always a
        1-D array holding the matching zero-based labels. Both are empty
        when the index range contains no file.

    Raises
    ------
    ValueError
        If a file name prefix is not present in ``ID``.
    """
    if stop == 0:
        stop = len(scalos)

    rows = []
    labs = []
    for i in range(start, stop):
        print('in load_scalos, i = %s' % i)
        sg = loadtxt(scalos[i]).astype(float32)
        sg = crop_reverse_scalo(sg).astype(float32)
        print('len(sg) %s' % len(sg))
        print('sg.shape before reshape %s' % (sg.shape,))
        sg = sg.reshape(1, sg.size)
        print('sg.shape after  reshape %s' % (sg.shape,))
        recording = scalos[i].split('_')[0]
        # Labels start at 1 in the table; shift so that they begin at 0
        # and theano doesn't bug.
        labs.append(int32(labels[ID.index(recording)][0]) - 1)
        rows.append(sg)

    if not rows:
        # Nothing in range: return empty arrays rather than failing on an
        # unbound variable.
        return (zeros((0, 0), dtype=float32), zeros(0, dtype=int32))
    # Stack once at the end instead of re-copying the arrays on every file.
    return (vstack(rows), hstack(labs))

def main(folder, batchsize=10, max_batches=2):
    """Pickle shuffled batches of the ``*.scalocut`` files found in ``folder``.

    The working directory is changed to ``folder`` (file names are handled
    relative to it) and each batch returned by ``load_scalos`` is dumped to
    ``batch_nb_<i>.pkl`` inside ``tempdir``.

    Parameters
    ----------
    folder : str
        Directory containing the ``*.scalocut`` files.
    batchsize : int
        Number of scalograms per batch.
    max_batches : int or None
        Upper bound on the number of batches written (default 2, the value
        previously hard-coded); ``None`` processes every full batch.
    """
    print('Applying to folder  %s' % folder)
    # Wall-clock time: time.clock() reports CPU time on Unix and no longer
    # exists in recent Python versions.
    tic = time.time()
    os.chdir(folder)
    scalos = array(glob.glob('*.scalocut'))
    random.shuffle(scalos)
    nbatches = len(scalos) // batchsize#TODO: change that to get the few last windows
    # Never ask for more full batches than the folder can provide.
    if max_batches is not None and max_batches < nbatches:
        nbatches = max_batches
    for i in range(nbatches):
        print('Batch nb %s' % i)
        batch = load_scalos(scalos, i * batchsize, (i + 1) * batchsize)
        print('Batch shape: %s %s' % (batch[0].shape, batch))
        # Write with an explicit path so the working directory stays on
        # `folder`, and let `with` close the file even if the dump fails.
        target = os.path.join(tempdir, 'batch_nb_' + str(i) + '.pkl')
        with open(target, 'wb') as f:
            cPickle.dump(batch, f)
    print('Done in  %s seconds' % (time.time() - tic))
    print('Left aside scalograms:  %s' % (len(scalos) % batchsize))

if __name__ == '__main__':
    # Script entry point: only the validation folder is processed;
    # the training-set call below is currently disabled.
    print('Run...')
    main(validdir)
    #main(traindir)
