# -*- coding: utf-8 -*- 
#Nicolas Enfon - 17/07/14 - LSIS DYNI
#This script cuts windows out of the scalograms, in order to train a CNN for the LifeBIRD classification task
from numpy import *
from matplotlib.pyplot import *
from scipy.io.wavfile import read, write
from pylab import specgram, norm
import cPickle, gzip, os, glob, pylab, time
import pywt, time

#Hard-coded absolute locations used by this script
curdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/CODE'
wavdir = '/NAS3/SABIOD/SITE/AMAZONE_BIRD_LIFECLEF/2014/LIFECLEF2014_BIRDAMAZON_XC_WAV_RN/'
datadir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/'
#The validation and training folders live directly under datadir
validdir = datadir + 'VALID/'
traindir = datadir + 'TRAIN/'

#Read the label table (all fields kept as strings) and list the files of both folders
labels = loadtxt(datadir + 'ICML_plus_ID_v2.csv', dtype='str', delimiter=',')
#Third column of the table: drop the first 31 and the last 4 characters of every entry
ID = [entry[31:-4] for entry in labels[:, 2]]
scalotrain = array(os.listdir(traindir))
scalovalid = array(os.listdir(validdir))
#/!\ CAUTION: remember that these scalograms have the lower frequencies in the first rows /!\

def cut_scalo(scalo, length=2000, overlap=1500, second=False, rate=44100, cut=(0,1)):
    """Cut a scalogram into fixed-width, overlapping time windows.

    The rows of `scalo` are reversed and cropped, then windows of `length`
    columns are taken every `length - overlap` columns. If the last regular
    window does not end exactly on the final column (or the scalogram is
    narrower than one window), one extra window is appended whose missing
    columns are filled with zeros.

    Parameters
    ----------
    scalo : 2-D numpy array
        Scalogram, one row per scale and one column per time bin.
    length : int or float
        Window width, in time bins (or in seconds when `second` is True).
    overlap : int or float
        Number of bins (or seconds) shared by two consecutive windows.
        Must be smaller than `length`.
    second : bool
        When True, `length` and `overlap` are given in seconds and are
        converted to bins as ``value * rate * 2 ** -level``, where `level`
        is ``int(log2(number of input rows))``.
    rate : int
        Sampling rate used by the seconds-to-bins conversion.
    cut : (int, int)
        Rows to drop after the flip: ``cut[0]`` from the top and ``cut[1]``
        from the bottom. ``cut[1] == 0`` drops nothing from the bottom.

    Returns
    -------
    list of 2-D numpy arrays
        The windows, in time order, each with `length` columns.

    Raises
    ------
    ValueError
        If the window length (in bins) is not positive or the overlap is
        not smaller than the window length (the cursor would never advance).
    """
    n_rows = len(scalo)

    if second:
        # The decomposition level is derived from the row count *before* cropping.
        level = int(log2(n_rows))
        length = int(length * rate * 2.0 ** (-level))
        overlap = int(overlap * rate * 2.0 ** (-level))

    if length <= 0 or overlap >= length:
        raise ValueError('need length > 0 and overlap < length, got length=%r, overlap=%r'
                         % (length, overlap))

    # Flip the row order, then crop. An explicit stop index is used because
    # the slice [a:-0] would be empty instead of "keep everything".
    row_stop = n_rows - cut[1]
    if row_stop < 0:
        row_stop = 0
    scalo = scalo[::-1][cut[0]:row_stop, :]

    #TODO: cut windows of size N around detected events (energy / energy-variance
    #      detector), padded with 0 for short detections and split in several
    #      windows for long ones. A draft of that detector used to live here.

    n_cols = scalo.shape[1]
    step = length - overlap
    out = []
    start = 0
    ends_flush = False  # True when the last regular window ends on the last column
    while start + length <= n_cols:
        out.append(scalo[:, start:start + length])
        ends_flush = (start + length == n_cols)
        start += step

    if not ends_flush:
        # The next window overruns the data: zero-pad the right edge just enough.
        missing = start + length - n_cols
        padded = concatenate((scalo, zeros((len(scalo), missing))), 1)
        out.append(padded[:, start:start + length])
    return out

def main(folder, start=0, stop=0):
    """Cut every ``*.scalo`` file of `folder` into windows and save them.

    Each scalogram is loaded with `loadtxt`, split by `cut_scalo` (default
    settings) and every window is written with `savetxt` next to its source
    file as ``<name>_part_<k>.fullscalocut``, with k starting at 1.

    Parameters
    ----------
    folder : str
        Directory containing the ``*.scalo`` files; outputs go there too.
    start, stop : int
        Range of files to process, as indices into the alphabetically
        sorted file list. ``stop == 0`` means "up to the last file".

    Notes
    -----
    The file list is sorted because `glob` returns names in arbitrary
    order, which would make `start`/`stop` select different files from one
    run to the next. Files are addressed through their full path, so the
    working directory of the process is left untouched.
    """
    print('Applying function to folder  %s' % (folder,))
    # Wall-clock timer; time.clock() measured CPU time on Unix and no longer
    # exists in recent Python versions.
    tic = time.time()
    scalos = sorted(glob.glob(os.path.join(folder, '*.scalo')))
    if stop == 0:
        stop = len(scalos)
    for path in scalos[start:stop]:
        print('scalo  %s' % os.path.basename(path))
        windows = cut_scalo(loadtxt(path))
        # splitext only strips the final '.scalo', so names that contain
        # other dots keep their full stem and cannot collide.
        stem = os.path.splitext(path)[0]
        for j, window in enumerate(windows):
            print('       part  %s' % j)
            savetxt(stem + '_part_' + str(j + 1) + '.fullscalocut', window)
    print('Done in  %s seconds' % (time.time() - tic))

if __name__ == '__main__':
    #Script entry point: only the validation folder is processed;
    #the training-folder call below is kept disabled.
    print('Run...')
    main(validdir)
#    main(traindir)
