# -*- coding: utf-8 -*-
# Nicolas Enfon - 01/04/14 - LSIS DYNI
""" This script is meant to visually check that the scalogram allows us to detect
	 some particular species """
import readfile
import numpy
from pylab import plot, show, subplot, specgram, savefig, figure
from matplotlib.pyplot import *
from scipy.io.wavfile import read,write
from random import random

#Input files
audio_file = 'Antares_066458.21.wav' #recording that is displayed
dat_url = 'ANTARES_66458.8._01_09_2012_22.59.54_02.00.41_T4_Q20_J40_scalo1_L1.dat' #.dat file read with readfile.readdat

#Spectrogram settings
nfft = 128 #size of the window for FFT
overlap = 64 #value passed as noverlap to specgram

#Detection and display settings
threshold = 0.7 #for the detection of peaks in the trend (relative excess over the mean)
window_signal = 0.01 #in seconds, displayed on each side of a detected peak

#Loading the audio recording and computing its spectrogram
sampling_rate, data = read(audio_file)
specgram_output = specgram(data, NFFT=nfft, Fs=sampling_rate, noverlap=overlap)
Pxx, freqs, bins, im = specgram_output
numBins, numSpectra = Pxx.shape
#Flip the rows so that the low freqs end up on the bottom when displayed
Pxx = numpy.flipud(Pxx)

#Loading the .dat file and turning it into a numpy array
dat_file_raw = readfile.readdat(dat_url)
dat_file = numpy.array(dat_file_raw)

#Calculating the detector, ie the sum of the energies for each column, smoothed
#with a 5-sample moving sum. The missing 1/5 factor of a true moving average
#does not matter here because the threshold is relative to the mean of the
#smoothed trend itself.
trend = dat_file.sum(0)
smoothed_trend = numpy.convolve(trend, [1,1,1,1,1], mode='same')
mean = smoothed_trend.mean()

#Boolean mask of the columns whose smoothed energy exceeds the mean by more
#than `threshold` (relative excess)
above_threshold = smoothed_trend > mean * (1 + threshold)

nb_columns = len(dat_file[0]) #number of columns of the .dat matrix
#we only keep the index of the peak, for plotting
peaks = [int(i) for i in numpy.flatnonzero(above_threshold)]
peaks2 = [int(flag) for flag in above_threshold] #0/1 flag per column - to be removed
#Same peaks expressed as indices into the audio samples and into the
#spectrogram columns. Floor division is written explicitly so the result is
#the same integer index under Python 2 and Python 3.
peaks_signal = [i * len(data) // nb_columns for i in peaks]
peaks_fft = [i * len(Pxx[0]) // nb_columns for i in peaks]


#Conversions: the display window (window_signal seconds on each side of a peak)
#expressed as a fraction of the whole recording, then in each of the three
#index spaces used for plotting
nb_samples = len(data)
window_ratio = float(sampling_rate) * window_signal / nb_samples
window_spectro = window_ratio * len(Pxx[0]) #in spectrogram columns
window_scattered = window_ratio * len(dat_file[0]) #in .dat columns
window_sampled = window_ratio * nb_samples #in audio samples

#Converting the raw data into log data (10 * log10), in a single vectorized
#numpy operation instead of a Python loop over every cell
Pxx = 10 * numpy.log10(Pxx)

#Plotting: one column of subplots per detected peak, with three rows (input
#signal, energy detector, spectrogram), each zoomed on a window around the peak
fig = figure(figsize=(120,10))
nb = 10 #number of detected peaks we want to see
ax = [ [None] * nb for _ in range(3) ] #ax[row][column], filled as the subplots are created
i = 0 #index of the candidate peak in peaks / peaks_signal / peaks_fft
count = 0 #number of columns already plotted
peaks_memory = -99 #position of the last plotted peak (start value is far below any index)
#Also stop when we run out of detected peaks, instead of raising an IndexError
while count < nb and i < len(peaks):
    if peaks[i] - peaks_memory < 10:
        i += 1
        continue #to avoid plotting several times the same peak

    #Plotting the input signal
    ax[0][count] = subplot(3, nb, count + 1)
    #Select the window that will be displayed, clamped to the signal bounds
    start_sampled = max( int( round( peaks_signal[i] - window_sampled ) ), 0 )
    end_sampled = min( int( round( peaks_signal[i] + window_sampled ) ), len(data) - 1 )
    reduced_data = list( data[ start_sampled : end_sampled ] )
    plot( range(len(reduced_data)), reduced_data )
    title('Input signal - ' + str(i))
    xlabel('Time (s)')
    #One tick every 5 ms. The tick positions (in samples relative to the start
    #of the window) are set explicitly so that each time label sits on the
    #sample it describes; setting only the labels would attach them to
    #whatever ticks matplotlib picked automatically.
    ticks = numpy.arange(float(start_sampled)/sampling_rate, float(end_sampled)/sampling_rate, 0.005)
    ax[0][count].set_xticks(ticks * sampling_rate - start_sampled)
    ax[0][count].set_xticklabels(['%.3f' % t for t in ticks])
    xlim(0, len(reduced_data))

    #Plotting the scattering energy detector
    ax[1][count] = subplot(3, nb, count + nb + 1)
    start_scattered = max( int( round( peaks[i] - window_scattered ) ), 0 )
    end_scattered = min( int( round( peaks[i] + window_scattered ) ), len(dat_file[0]) - 1 )
    reduced_dat_file = list( trend[start_scattered:end_scattered] )
    plot( range(len(reduced_dat_file)), reduced_dat_file, 'g' )
    xlim(0, len(reduced_dat_file))
    title('Energy detector - ' + str(i))

    #Plotting the FFT
    ax[2][count] = subplot(3, nb, count + (2*nb) + 1)
    start_spectro = max( int( round( peaks_fft[i] - window_spectro ) ), 0 )
    end_spectro = min( int( round( peaks_fft[i] + window_spectro ) ), len(Pxx[0]) - 1 )
    reduced_Pxx = Pxx[:,start_spectro:end_spectro]
    #NOTE(review): extent labels the x axis with numBins and the y axis with
    #numSpectra, which looks swapped - left unchanged, to be confirmed
    imshow(reduced_Pxx, interpolation=None, aspect='auto', extent=(0,numBins,0,numSpectra))
    title('FFT - ' + str(i))

    #Remember the peak that was just plotted BEFORE moving on to the next one
    peaks_memory = peaks[i]
    i += 1
    count += 1

savefig('a_supprimer_aussi_2.png')
#show()    

