# -*- coding: utf-8 -*- 
#Nicolas Enfon - LSIS DYNI - NortekMED
from pylab import *
from time import *
from dwt_1 import scalogram
from scipy.io.wavfile import *
import os, glob, pywt

'''Plot the scalograms made with the different wavelet families in order to 
see which ones are the best'''


def plot_scalograms(wavelist, plotdir, sounddir=None, soundfiles=None, level=6, makedirs=True,
                    spectrogram_only=True):
    '''Plot time-frequency figures of sound recordings and save them as PNG files.

    Two modes are available, selected by @makedirs:

    * makedirs=True (directory mode): every sub-directory of @sounddir is
      scanned for .wav files. For each recording a figure folder
      <plotdir>/<sub-directory>/<recording name>/ is created and a spectrogram
      is saved in it. When @spectrogram_only is False, one scalogram per
      wavelet of @wavelist is saved in the same folder as well.
    * makedirs=False (flat mode): one scalogram per (sound file, wavelet)
      pair is saved directly in @plotdir.

    Multi-channel recordings are reduced to their first track. Every figure is
    closed as soon as it has been saved so that memory use stays bounded when
    many recordings are processed. The current working directory is not
    changed.

    @wavelist: list of the wavelet names to be used
    @plotdir: directory where the figures will be saved
    @sounddir: directory containing one sub-directory of recordings per
               species (required in directory mode)
    @soundfiles: list of the soundfiles (required in flat mode)
    @level: level of decomposition passed to scalogram()
    @makedirs: True for directory mode, False for flat mode
    @spectrogram_only: directory mode only; when True (default) only the
                       spectrogram of each recording is plotted
    @raise ValueError: if the input required by the selected mode is missing'''

    if makedirs:
        if sounddir is None:
            raise ValueError('sounddir is required when makedirs is True')
        _plot_sound_tree(wavelist, plotdir, sounddir, level, spectrogram_only)
        return

    if soundfiles is None:
        raise ValueError('soundfiles is required when makedirs is False')
    _ensure_dir(plotdir)
    for soundfile in soundfiles:
        sample_rate, signal = read(soundfile)
        signal = _first_channel(signal)
        stem = os.path.splitext(os.path.basename(soundfile))[0]
        for wvlt in wavelist:
            heading = 'Soundfile ' + soundfile + '  Wavelet ' + wvlt + '  level ' + str(level)
            out_path = os.path.join(plotdir, wvlt + '_on_' + stem + '.png')
            _save_scalogram(signal, sample_rate, wvlt, level, heading, out_path, False)


def _plot_sound_tree(wavelist, plotdir, sounddir, level, spectrogram_only):
    '''Directory mode of plot_scalograms: walk sounddir/<species>/<file>.wav.'''
    print('listing sound directories')
    species_dirs = sorted(entry for entry in os.listdir(sounddir)
                          if os.path.isdir(os.path.join(sounddir, entry)))

    print('making output directories')
    for species in species_dirs:
        _ensure_dir(os.path.join(plotdir, species))

    for species in species_dirs:
        in_dir = os.path.join(sounddir, species)
        # Keep the wav recordings only: scipy's reader cannot decode anything
        # else (mp3 copies, notes, sub-folders, ...).
        wav_names = sorted(name for name in os.listdir(in_dir)
                           if name.lower().endswith('.wav'))

        for name in wav_names:
            sample_rate, signal = read(os.path.join(in_dir, name))
            signal = _first_channel(signal)
            stem = os.path.splitext(name)[0]

            # Each recording gets its own figure folder; it must exist before
            # savefig() is called.
            record_dir = os.path.join(plotdir, species, stem)
            _ensure_dir(record_dir)

            duration = round(len(signal) / float(sample_rate), 2)
            heading = ('Soundfile: ' + name + ' sample rate: ' + str(sample_rate)
                       + ' total length: ' + str(duration) + ' seconds')

            if not spectrogram_only:
                for wvlt in wavelist:
                    print('file: %s wavelet: %s' % (name, wvlt))
                    _save_scalogram(
                        signal, sample_rate, wvlt, level,
                        heading + '   Wavelet: ' + wvlt + '  level: ' + str(level),
                        os.path.join(record_dir, wvlt + '_on_' + stem + '.png'),
                        True)

            # Spectrogram, to have an idea of the frequencies
            print('file: %s spectrogram' % name)
            _save_spectrogram(signal, sample_rate, heading,
                              os.path.join(record_dir, 'spectrogram_on_' + stem + '.png'))


def _ensure_dir(path):
    '''Create the directory path (and its parents) unless it already exists.'''
    if not os.path.isdir(path):
        os.makedirs(path)


def _first_channel(signal):
    '''Return signal unchanged if it is mono, else only its first track.'''
    if len(signal.shape) != 1:
        return signal[:, 0]
    return signal


def _tick_positions(count):
    '''Return about ten evenly spaced integer indices in [0, count).'''
    # The step is clamped to 1: a zero step makes arange() fail when there
    # are fewer than ten rows/columns.
    return arange(0, count, max(1, count // 10))


def _save_spectrogram(signal, sample_rate, heading, out_path):
    '''Compute the spectrogram of signal in dB and save it to out_path.'''
    # mlab.specgram only computes the spectrogram; the pylab version would
    # also draw it on an extra figure that is never used.
    Pxx, freqs, bins = mlab.specgram(signal, NFFT=256, Fs=sample_rate,
                                     detrend=mlab.detrend_none,
                                     window=mlab.window_hanning, noverlap=128)

    # Digital silence gives a power of exactly 0, i.e. -inf dB, which ruins
    # the colour scale; clip to the smallest strictly positive power instead.
    positive = Pxx[Pxx > 0]
    floor = positive.min() if positive.size else 1.0
    Pxx = 10 * log10(maximum(Pxx, floor))

    fig = figure(figsize=(30, 15))
    try:
        imshow(Pxx, origin='lower', aspect='auto')
        title(heading)
        # Label the axes with the times/frequencies returned by specgram for
        # the selected columns/rows.
        cols = _tick_positions(len(bins))
        xticks(cols, around(bins[cols], 1))
        rows = _tick_positions(len(freqs))
        yticks(rows, around(freqs[rows], 1))
        xlabel('Time (s)')
        ylabel('Frequency (Hz)')
        savefig(out_path)
    finally:
        # Figures stay alive until explicitly closed; release it right away.
        close(fig)


def _save_scalogram(signal, sample_rate, wvlt, level, heading, out_path, time_axis):
    '''Compute the scalogram of signal with wavelet wvlt and save it to out_path.

    When time_axis is True the x axis is labelled in seconds.'''
    # NOTE(review): scalogram() comes from dwt_1; it is assumed to return a
    # 2-D array whose columns follow time -- confirm against dwt_1.
    scalo = scalogram(signal, wvlt, level=level)
    fig = figure(figsize=(30, 15))
    try:
        imshow(scalo, interpolation='nearest', aspect='auto')
        title(heading)
        if time_axis:
            n_cols = len(scalo[0])
            if n_cols:
                cols = _tick_positions(n_cols)
                seconds = cols * (len(signal) / float(sample_rate) / n_cols)
                xticks(cols, around(seconds, 1))
            xlabel('Time (s)')
            ylabel('Scalogram coefficients')
        savefig(out_path)
    finally:
        close(fig)



if __name__ == "__main__":
    # Print the start date, then time the whole run.
    print(asctime())
    started = time()

    # Input tree (one sub-directory per species) and output tree for figures.
    sounddir = '/NAS3/SABIOD/SITE/NORTEKMED/SAMPLE_SOUNDS/'
    plotdir = '/NAS3/SABIOD/public_data/NRL/WAVELET_FAMILIES/'
    # Single-recording example; defined here but not passed to the call below.
    soundfiles = ['/NAS3/SABIOD/SITE/COERULEOALBA/S_coeruleoalba_whistles.wav']

    # Every wavelet name known to PyWavelets.
    wavelist = pywt.wavelist()

    plot_scalograms(wavelist, plotdir, sounddir=sounddir)

    elapsed = time() - started
    print('Done in %s seconds' % elapsed)
