# -*- coding: utf-8 -*- 
#Nicolas Enfon - 15/07/14 - LSIS DYNI
#This script computes the scalograms used to train a CNN for the LifeBIRD classification task
from numpy import *
from matplotlib.pyplot import *
from scipy.io.wavfile import read, write
from pylab import specgram, norm
import cPickle, gzip, os, glob, pylab
import pywt

# Code directory of the CNN2D_LIFEBIRD project (not referenced in this section).
curdir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/CODE'
# Directory holding the input .wav recordings; main() lists and reads files from here.
wavdir = '/NAS3/SABIOD/SITE/AMAZONE_BIRD_LIFECLEF/2014/LIFECLEF2014_BIRDAMAZON_XC_WAV_RN/'
# Data directory; it keeps its trailing slash because file names are appended to it directly.
datadir = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/DATA/'
# Output directories for scalograms of 'Test' and 'Train' files (no trailing slash).
validdir = datadir + 'VALID'
traindir = datadir + 'TRAIN'


def scalo(wavfile, maxlevel=10, level=6, wavelet='db2', mode='sym'):
    """Return the wavelet-packet scalogram of a signal.

    Parameters
    ----------
    wavfile : str or array_like
        Either the path of a .wav file (read with scipy.io.wavfile.read) or
        the already-decoded signal samples.
    maxlevel : int
        Maximum decomposition depth passed to pywt.WaveletPacket.
    level : int
        Tree level whose nodes form the rows of the scalogram.
    wavelet : str
        Wavelet name passed to pywt (default 'db2').
    mode : str
        Signal-extension mode passed to pywt (default 'sym', the value that
        was previously hard-coded).

    Returns
    -------
    ndarray of float64
        Absolute values of the coefficients of every node at ``level``,
        one row per node, rows in frequency order.
    """
    # Accept byte strings and unicode strings alike as file paths; the old
    # ``type(...) == str`` test sent unicode paths down the "signal" branch.
    string_types = (str, type(u''))
    if isinstance(wavfile, string_types):
        print('Reading the file %s ...' % wavfile)
        _rate, signal = read(wavfile)
    else:
        # Already-decoded samples: nothing to read, and nothing to announce
        # (the old message printed the whole array as if it were a name).
        signal = wavfile
    # NOTE(review): read() returns a 2-D array for multi-channel files;
    # confirm pywt.WaveletPacket handles that, or mix down to mono first.
    print('Computing the scalogram ...')
    wp = pywt.WaveletPacket(signal, wavelet, mode, maxlevel=maxlevel)
    nodes = wp.get_level(level, order='freq')
    values = abs(array([node.data for node in nodes], dtype='d'))
    print('Scalogram done!')
    return values



def main(start=0, stop=0):
    """Compute and save scalograms for a slice of the files in ``wavdir``.

    The label table ``datadir + 'ICML_plus_ID_v2.csv'`` is read as strings.
    A file ID is taken as characters [31:-4] of column 2 (presumably a file
    path -- confirm), and column 3 is compared against 'Train' / 'Test'.
    Every file in ``wavdir`` whose name yields an ID present in the table is
    read, turned into a scalogram with ``scalo`` and written with ``savetxt``
    as ``<ID>.scalo`` into ``traindir`` ('Train') or ``validdir`` ('Test').

    Parameters
    ----------
    start : int
        Index (in glob listing order) of the first file to process.
    stop : int
        Index one past the last file to process. 0 means "through the last
        file"; values past the end of the listing are clamped to it.
    """
    # ndmin=2 keeps the [:, col] indexing valid for a single-row table.
    labels = loadtxt(datadir + 'ICML_plus_ID_v2.csv', dtype='str',
                     delimiter=',', ndmin=2)
    # ID -> first row mentioning it (what list.index returned before),
    # built once so each lookup is O(1) instead of two linear scans.
    row_of_id = {}
    for row, path in enumerate(labels[:, 2]):
        row_of_id.setdefault(path[31:-4], row)

    # Split label (column 3) -> output directory.
    outdirs = {'Train': traindir, 'Test': validdir}

    # Same names and order as glob.glob('*') run from inside wavdir, without
    # changing the process working directory.
    wavfiles = [os.path.basename(p)
                for p in glob.glob(os.path.join(wavdir, '*'))]

    if stop == 0 or stop > len(wavfiles):
        stop = len(wavfiles)

    # Counters cover the whole run (they used to be reset on every iteration,
    # so the final summary only described the last file).
    inlist = 0
    outlist = 0
    for i in range(start, stop):
        name = wavfiles[i]
        print('Iteration %d file %s' % (i, name))
        file_id = name[31:-4]
        row = row_of_id.get(file_id)
        if row is None:
            outlist += 1
            print('/!\\ Error: not in list at index %d for file %s' % (i, name))
            continue
        inlist += 1

        split = labels[row, 3]
        outdir = outdirs.get(split)
        if outdir is None:
            # Previously skipped silently.
            print('/!\\ Warning: unknown split %s at index %d for file %s'
                  % (split, i, name))
            continue

        # Read the audio only for files that will actually be processed.
        _rate, signal = read(os.path.join(wavdir, name))
        scal = scalo(signal)
        savetxt(os.path.join(outdir, file_id + '.scalo'), scal)

    print('Done!')
    print('inlist %d outlist %d' % (inlist, outlist))
