# -*- coding: utf-8 -*- 
#Nicolas Enfon - 17/07/14 - LSIS DYNI
#This script cuts windows out of the scalograms in order to train a CNN for the LifeBIRD classification task
from numpy import *
from matplotlib.pyplot import *
from scipy.io.wavfile import read, write
from pylab import specgram, norm
import cPickle, gzip, os, glob, pylab, time
import pywt

# Absolute directories on the NAS share used by this script.
# curdir: not referenced by the code visible in this file; presumably the
# folder holding this script -- confirm before relying on it.
curdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/CODE'
# wavdir: not referenced by the code visible in this file; presumably the
# original WAV recordings -- confirm.
wavdir = '/NAS3/SABIOD/SITE/AMAZONE_BIRD_LIFECLEF/2014/LIFECLEF2014_BIRDAMAZON_XC_WAV_RN/'
# datadir: contains the label file 'ICML_plus_ID_v2.csv' loaded below.
datadir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/'
# validdir / traindir: folders whose '*.scalo' files are cut into windows by
# main(); the resulting '*.scalocut' files are written into the same folders.
validdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/VALID/'
traindir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/TRAIN/'

# Load the label table and the names of the available scalogram files.
# NOTE: these statements run at import time and read from the NAS directories.
labels = loadtxt(datadir+'ICML_plus_ID_v2.csv',dtype='str',delimiter=',')
# Third CSV column with its first 31 and last 4 characters removed
# (presumably a fixed-length path prefix and a 4-character extension such as
# '.wav' -- confirm against the CSV content).
ID = [entry[31:-4] for entry in labels[:, 2]]
scalotrain = array(os.listdir(traindir))
scalovalid = array(os.listdir(validdir))
# /!\ CAUTION: these scalograms store the lower frequencies in the first rows /!\

def cut_scalo(scalo, length=2000, overlap=1500, second=False, rate=44100):
    """Cut a scalogram into overlapping windows of equal width.

    Windows are taken along the column axis and keep every row. Consecutive
    windows start ``length - overlap`` columns apart. If the last full window
    does not end exactly on the last column, one extra window is appended
    that holds the remaining columns followed by zero padding.

    Parameters
    ----------
    scalo : 2-D numpy array indexed as ``scalo[row, column]``.
    length : window width, in columns, or in seconds when `second` is true.
    overlap : number of columns (or seconds when `second` is true) shared by
        two consecutive windows. Must be smaller than `length`.
    second : when true, `length` and `overlap` are converted to columns as
        ``int(value * rate * 2 ** -level)`` with
        ``level = int(log2(number of rows))``.
    rate : sampling rate used by the seconds-to-columns conversion.

    Returns
    -------
    list of 2-D arrays, each with as many rows as `scalo` and `length`
    columns. Unpadded windows are views on `scalo`.

    Raises
    ------
    ValueError
        If `overlap` is not smaller than `length` (after conversion): the
        window would never advance.
    """
    if second:
        # Convert durations given in seconds into a number of columns.
        level = int(log2(len(scalo)))
        scale = 2.0 ** (-level)
        length = int(length * rate * scale)
        overlap = int(overlap * rate * scale)

    step = length - overlap
    if step <= 0:
        # A non-positive stride would make the windowing loop endless.
        raise ValueError(
            'overlap (%s) must be smaller than length (%s)' % (overlap, length))

    rows = len(scalo)
    width = len(scalo[0])

    # Start columns of every window that fits entirely in the scalogram.
    starts = list(range(0, width - length + 1, step))
    out = [scalo[:, s:s + length] for s in starts]

    # If the last full window stops before the final column (or no window
    # fits at all), add one more window padded with zeros on the right.
    if not starts or starts[-1] + length != width:
        tail = scalo[:, len(starts) * step:]
        missing = length - tail.shape[1]
        # Only the tail is padded; no need to copy the whole scalogram.
        out.append(concatenate((tail, zeros((rows, missing))), 1))
    #TODO: check if zero padding is a problem for the CNN
    return out

def main(folder, start=0, stop=0):
    """Cut the '*.scalo' files of `folder` into windows saved as text files.

    The working directory is changed to `folder` (and left there). Each
    scalogram is loaded with ``loadtxt``, split with ``cut_scalo`` using its
    default parameters, and every window is written with ``savetxt`` to
    ``<name without extension>_part_<k>.scalocut`` in `folder`, where ``k``
    starts at 1.

    Parameters
    ----------
    folder : directory containing the '*.scalo' files.
    start, stop : bounds of the slice of the sorted file list to process,
        so that a folder can be handled in several chunks. ``stop == 0``
        means "up to the last file"; a `stop` beyond the number of files is
        clamped.
    """
    print('Applying function to folder  %s' % (folder,))
    # NOTE: process-wide side effect, kept because the relative file names
    # used for reading and writing depend on it.
    os.chdir(folder)
    # glob returns names in arbitrary order; sorting makes the start/stop
    # chunks reproducible from one run to the next.
    scalos = sorted(glob.glob('*.scalo'))
    if stop == 0:
        stop = len(scalos)
    for name in scalos[start:stop]:
        print('scalo  %s' % (name,))
        windows = cut_scalo(loadtxt(name))
        # splitext drops only the final extension, so names that contain
        # additional dots stay distinct.
        stem = os.path.splitext(name)[0]
        for j, window in enumerate(windows):
            print('       part  %s' % (j,))
            savetxt(stem + '_part_' + str(j + 1) + '.scalocut', window)


if __name__ == '__main__':
    # Script entry point: cut the VALID folder first, then the TRAIN folder.
    print('Run...')
    for target in (validdir, traindir):
        main(target)
