# -*- coding: utf-8 -*- 
#Nicolas Enfon - 16/07/14 - LSIS DYNI
#This script computes the .pkl datasets used to train a CNN for the LifeBIRD classification task
from numpy import *
from matplotlib.pyplot import *
from scipy.io.wavfile import read, write
from pylab import specgram, norm
import cPickle, gzip, os, glob, pylab, time
import pywt

# Hard-coded absolute locations on the NAS used by this script.
# Directory holding this code (not referenced in the visible part of the script).
curdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/CODE'
# Directory of the source .wav recordings (not referenced in the visible part of the script).
wavdir = '/NAS3/SABIOD/SITE/AMAZONE_BIRD_LIFECLEF/2014/LIFECLEF2014_BIRDAMAZON_XC_WAV_RN/'
# Data root; the label table 'ICML_plus_ID_v2.csv' is read from here.
datadir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/'
# Input folder of validation scalograms (passed to main() as `folder`).
validdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/VALID/'
# Output folder for the pickled validation batches (passed to main() as `folderout`).
# NOTE(review): the '311_212' suffix presumably mirrors the default `specnb` species filter -- confirm.
validout = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/VALID_SOMESP_311_212/'
# Input folder of training scalograms (passed to main() as `folder`).
traindir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/TRAIN/'
# Output folder for the pickled training batches (passed to main() as `folderout`).
trainout = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/TRAIN_SOMESP_311_212/'

# Read the label table, then derive the recording IDs and the folder listings.
labels = loadtxt(datadir + 'ICML_plus_ID_v2.csv', dtype='str', delimiter=',')
# Third column of the table, with its first 31 and last 4 characters removed.
# NOTE(review): presumably strips a fixed-length path prefix and a 4-character
# file extension so the ID matches the scalogram file-name prefix -- confirm.
ID = [entry[31:-4] for entry in labels[:, 2]]
# File names found in the training / validation scalogram folders.
scalotrain = array(os.listdir(traindir))
scalovalid = array(os.listdir(validdir))
#/!\ CAUTION: these scalograms store the lowest frequencies in the first rows /!\
#--> this may no longer be true for the newer scalograms; check before relying on it

def cut_scalo(scalo, length, second=False, rate=44100):
    """Split a scalogram into consecutive, non-overlapping windows.

    Windows are cut along the second axis (columns) of `scalo`.

    Parameters:
        scalo:  2-D array indexed as [row, column].
        length: window width, in columns (bins), or in seconds when
                `second` is true.
        second: if true, `length` is converted to bins as
                int(length * rate * 2 ** -level), where
                level = int(log2(number of rows)).
        rate:   sampling rate used for that conversion.

    Returns:
        A list of arrays, each `length` columns wide. Trailing columns that
        do not fill a whole window are dropped. If the scalogram is narrower
        than one window it is zero-padded on the right and returned as a
        single window.

    Raises:
        ValueError: if the window width (after conversion) is below one bin.
    """
    if second:
        # The level is only needed for the seconds -> bins conversion.
        level = int(log2(len(scalo)))
        length = int(length * rate * (2 ** (-level)))
    if length < 1:
        # Without this guard a zero width fails with a bare ZeroDivisionError
        # and a negative width silently yields an empty list.
        raise ValueError('window length must be at least one bin, got %r' % (length,))
    width = len(scalo[0])
    if width < length:
        # Too short for a single window: pad with zeros up to `length`.
        padd = zeros((len(scalo), length - width))
        scalo = concatenate((scalo, padd), 1)
        n = 1
    else:
        n = width // length
    return [scalo[:, k * length:(k + 1) * length] for k in range(n)]

def crop_reverse_scalo(scalo, top=6, bottom=2):
    """Flip a scalogram upside down, then drop rows at both ends.

    Parameters:
        scalo:  2-D array indexed as [row, column].
        top:    number of rows removed from the start of the flipped array.
        bottom: number of rows removed from the end of the flipped array;
                0 keeps every trailing row.

    Returns:
        The flipped and cropped rows, with all columns kept.
    """
    flipped = scalo[::-1]
    # A plain `-bottom` stop would be `-0 == 0` when bottom is 0 and select
    # nothing; `None` means "up to the last row" in that case.
    stop = -bottom if bottom else None
    return flipped[top:stop, :]

def load_scalos(scalos, start=0, batchsize=500, specnb=(311,212)):
    """Loads the scalograms between indices [start, stop["""
    signal = []
    i = start
    count = 0
    while count < batchsize:
	print "len(scalos):            ", len(scalos), '     i:    ', i
	print "scalos[i].split('_')[0]: ", scalos[i].split('_')[0]
	print 'len(ID):                  ', len(ID)
	print "ID.index(scalos[i].split('_')[0]): ", ID.index(scalos[i].split('_')[0])
	print 'len(labels):               ', len(labels)
        label = int32(   labels[   ID.index(scalos[i].split('_')[0])   ][0]   )  -  1#so it begins at 0 and theano doesn't bug
	if label in specnb: #to keep only the first species (faster)
	    print 'in load_scalos, i =', i, 'label = ',label
            sg = loadtxt(scalos[i]).astype(float32)
            sg = crop_reverse_scalo(sg).astype(float32)
	    print 'len(sg)', len(sg)
	    print 'sg.shape before reshape', sg.shape
	    sg = sg.reshape(1, len(sg) * len(sg[0]))
	    print 'sg.shape after  reshape', sg.shape
            if count == 0:
                signal = sg
                labs = label
            else:
                signal = vstack((signal, sg))
                labs = hstack((labs, label))
	    count += 1
	    i += 1
	else:
	    i += 1
    return i, (signal, labs)

def main(folder, folderout, batchsize=500, specnb=(311,212)):
    print 'Applying to folder ', folder
    tic = time.clock()
    os.chdir(folder)
    scalos = array(glob.glob('*.scalocut'))
    #---TAKES INTO ACCOUNT THE FACT THAT WE LIMIT THE SPECIES NUMBER----
    number = 0
    for i in xrange(len(scalos)):
	label = int32(   labels[   ID.index(scalos[i].split('_')[0])   ][0]   )  -  1#so it begins at 0 and theano doesn't bug
	if  label in specnb:
	    number += 1
    #asupr:
    #return number // batchsize
    #-------------
    random.shuffle(scalos)
    #nbatches = len(scalos) // batchsize#TODO: change that to get the few last windows?
    nbatches = number // batchsize
    index = 0
    for i in xrange(nbatches - 0):#why -1? TO REMOVE
	print 'Batch nb', i
	os.chdir(folder)
	index, batch = load_scalos(scalos, index, batchsize, specnb)
	os.chdir(folderout)
	f = open('batch_nb_'+str(i)+'.pkl','wb')
	cPickle.dump(batch, f, protocol=cPickle.HIGHEST_PROTOCOL)
        f.close()
    print 'Done in ', time.clock() - tic, 'seconds'
    print 'Left aside scalograms: ', len(scalos) % batchsize

if __name__ == '__main__':
    print 'Run...'
    main(validdir, validout)
    main(traindir, trainout)#this one has a memory error... but it resolved itself!
