# -*- coding: utf8 -*-
# Nicolas Enfon - LSIS DYNI - 08/07/14
from numpy import *
from matplotlib.pyplot import *
from scipy.io.wavfile import read
from pylab import norm, specgram
import os, glob, time

#----Parameters----
codefolder = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/CODE/'
datafolder = '/NAS3/SABIOD/public_data/SCATT_BIRD_50mostE_perfile_april2014_RB_HG/'
resultfolder = '/NAS3/SABIOD/METHODES/NICOLAS/CNN2D_LIFEBIRD/RESULTS_V1/'
# Root directory holding one sub-folder per species number.
specroot = '/NAS3/SABIOD/public_data/SCATT_BIRD_50mostE_perfile_april2014_RB_HG/'
n = 10  # number of species folders listed in `species`

#----List files and directories----
# Entries of datafolder, split into plain files and sub-directories.
# os.path.join avoids the doubled '/' that `datafolder + '/*'` produced.
mixed = glob.glob(os.path.join(datafolder, '*'))
files = [path for path in mixed if os.path.isfile(path)]
folders = [path for path in mixed if os.path.isdir(path)]
# Paths of species folders 1..n under specroot.
species = [specroot + str(i) for i in range(1, n + 1)]

#----Energy profile of a scalogram----
def _power_to_db(Pxx):
    """Convert a power array to decibels, mapping non-positive entries to 0.

    Entries > 0 become 10*log10(value); entries <= 0 (where the logarithm is
    undefined) are set to 0 instead of producing -inf/nan.
    """
    Pxx = asarray(Pxx, dtype=float)
    db = zeros_like(Pxx)
    positive = Pxx > 0
    db[positive] = 10 * log10(Pxx[positive])
    return db


def energy_profile(datfile, wavfile=None, skip=True, normalize=True, showplot=False, sauve=None):
    """Compute and plot the energy profile of a scalogram (.dat file).

    The scalogram is read from a comma-separated text file. Its rows are
    shown on the 'Wavelet coefficients (frequency)' axis and its columns on
    the 'Time' axis. The energy profile is the sum of each row over time.
    The returned 'variance' profile is actually the per-row standard
    deviation; the name is kept for compatibility.

    A new figure is always created: the scalogram on top, then the two
    profiles, then (optionally) the spectrogram of `wavfile`.

    Parameters
    ----------
    datfile : str
        Path to the comma-separated scalogram file.
    wavfile : str, optional
        Path to a .wav file whose spectrogram (in dB) is added as a third
        subplot.
    skip : bool
        Drop the first row of the scalogram before computing the profiles.
    normalize : bool
        Divide each profile by its Euclidean norm. An all-zero profile is
        left unchanged instead of becoming NaN.
    showplot : bool
        Display the figure without blocking.
    sauve : str, optional
        'here' saves the PNG in the current directory. Any other string is
        used verbatim as a prefix of the PNG path (e.g. a directory ending
        in '/'). None does not save.

    Returns
    -------
    (profile, variance) : tuple of 1-D arrays, one value per scalogram row,
        or (None, None) if the .dat file is empty.
    """
    dat = loadtxt(datfile, delimiter=',')
    if len(dat) == 0:
        print('Problem: the .dat file (scalogram) is empty')
        # Return a pair so that callers can always unpack the result.
        return None, None
    if skip:
        dat = dat[1:, :]
    profile = dat.sum(1)
    variance = dat.std(1)  # per-row standard deviation
    if normalize:
        profile_norm = norm(profile)
        variance_norm = norm(variance)
        if profile_norm > 0:
            profile = profile / profile_norm
        if variance_norm > 0:
            variance = variance / variance_norm

    # Recording identifier used in titles and file names: the 5th
    # '_'-separated field of the base name, or the whole base name when it
    # has fewer fields. Using the base name keeps '_' in directory names
    # from shifting the index.
    basename = os.path.basename(datfile)
    fields = basename.split('_')
    tag = fields[4] if len(fields) > 4 else basename

    subnb = 2  # number of subplots
    if wavfile is not None:
        # mlab.specgram only computes the spectrogram. pylab's specgram would
        # also draw it into whatever figure is current (or open a new one).
        from matplotlib import mlab
        sampling_rate, signal = read(wavfile)
        Pxx = mlab.specgram(signal, Fs=sampling_rate)[0]
        # Flip rows so that low frequencies end up at the bottom of imshow.
        Pxx = _power_to_db(Pxx[::-1])
        subnb = 3

    figure(figsize=(20, 15))
    subplot(subnb, 1, 1)  # scalogram
    imshow(dat, interpolation='nearest', aspect='auto')
    title('Scalogram T4 Q20 J40 of file ' + tag)
    xlabel('Time')
    ylabel('Wavelet coefficients (frequency)')

    subplot(subnb, 1, 2)  # profiles, drawn against the coefficient index
    plot(profile, range(len(profile)), label='energy')
    plot(variance, range(len(variance)), label='variance')
    legend()
    if normalize:
        title('Normalized energy & variance profile')
    else:
        title('Energy & variance profile')
    xlabel('Energy')
    ylabel('Wavelet coefficients (frequency)')
    ax = gca()
    # Invert Y so that row 0 is at the top, matching the scalogram above.
    ax.set_ylim(ax.get_ylim()[::-1])

    if wavfile is not None:
        subplot(subnb, 1, 3)
        imshow(Pxx, aspect='auto')
        title('Spectrogram')

    if sauve == 'here':
        savefig('energy_profile_dat_' + tag + '.png')
    elif sauve is not None:
        savefig(sauve + 'energy_profile_dat_' + tag + '.png')
    if showplot:
        show(block=False)

    return profile, variance

#----Energy profile for a species----
def profile(specnb, showplot=False, sauve=None):
    """Compute and plot the energy profiles of all recordings of a species.

    Changes the working directory to ``specroot + str(specnb)`` and leaves
    it there, so ``sauve='here'`` saves into the species folder. It then
    computes the energy profile of every
    'LIFECLEF2014_BIRDAMAZON_XC_WAV_RN*.dat' scalogram in that folder and
    draws one figure: all individual profiles (top) and a per-coefficient
    boxplot across recordings (bottom).

    Parameters
    ----------
    specnb : int
        Species number, i.e. the name of the sub-folder of `specroot`.
    showplot : bool
        Display the summary figure without blocking.
    sauve : str, optional
        'here' saves the PNG in the (species) working directory. Any other
        string is used verbatim as a prefix of the PNG path. None does not
        save.

    Returns
    -------
    profiles : array
        One row per non-empty recording. Each row is the energy profile
        returned by `energy_profile` (per the original author, it begins
        with the high frequencies). Empty if no usable file was found.
    """
    os.chdir(specroot + str(specnb))
    # Sort so that the subset selected below is reproducible: glob order is
    # filesystem-dependent.
    datfiles = sorted(glob.glob('LIFECLEF2014_BIRDAMAZON_XC_WAV_RN*.dat'))
    if specnb == 465:
        # Only the first 5 recordings are used for species 465.
        # NOTE(review): the original also capped `specnb == 465` at 19 files
        # right after this, which could never take effect. It was probably
        # meant for another species; confirm which.
        datfiles = datfiles[:5]

    profiles = []
    print('Computing energy profile for species nb ' + str(specnb) + ' ...')
    for i, datfile in enumerate(datfiles):
        print('Loading .dat file nb ' + str(i) + ':  ' + datfile + ' ...')
        result = energy_profile(datfile)
        # An empty scalogram may be reported as None or as (None, None).
        if result is None or result[0] is None:
            print('.dat file nb ' + str(i) + ' was empty')
            continue
        # energy_profile opens a figure per file; close it so only the
        # summary figure below remains open.
        close()
        profiles.append(result[0])  # careful, begins with high frequencies
    profiles = array(profiles)

    if len(profiles) == 0:
        print('No non-empty .dat file for species nb ' + str(specnb))
        return profiles

    figure(figsize=(40, 30))
    subplot(211)  # every profile, not aggregated
    for prof in profiles:
        plot(prof, range(len(prof)))
    title('Normalized energy profile of each recording of species #' + str(specnb))
    xlabel('Energy')
    ylabel('Wavelet coefficients (frequency)')
    ax = gca()
    ax.set_ylim(ax.get_ylim()[::-1])  # invert Y so coefficient 0 is at the top

    subplot(212)  # boxplot of each coefficient across recordings
    forboxplot = [profiles[:, j] for j in range(profiles.shape[1])]
    # Positions count down so coefficient 0 is drawn at the top, matching
    # the inverted axis of the upper subplot.
    boxplot(forboxplot, notch=True, sym='', vert=False,
            positions=arange(len(forboxplot), 0, -1))
    title('Boxplot of the energy profile of species #' + str(specnb))
    xlabel('Energy')
    ylabel('Wavelet coefficients (frequency)')

    if sauve == 'here':
        savefig('profile_species_' + str(specnb) + '.png')
    elif sauve is not None:
        savefig(sauve + 'profile_species_' + str(specnb) + '.png')

    if showplot:
        show(block=False)

    return profiles

#----Energy profile for all the species----
if __name__ == '__main__':
    # Batch run: species folders 465..501 (inclusive), one summary PNG each,
    # saved in the species folder (profile() chdirs there).
    first_species, last_species = 465, 501
    # Wall-clock timer. time.clock() measured CPU time on Unix and was
    # removed in Python 3.8.
    tic = time.time()
    for specnb in range(first_species, last_species + 1):
        print('Iteration nb  %d' % specnb)
        print('Time spent:   %s' % (time.time() - tic))
        profile(specnb, sauve='here')
        close('all')  # release every figure before the next species
