# -*- coding: utf-8 -*-
# Nicolas Enfon - 01/04/14 - LSIS DYNI
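"""Visually check that the scalogram allows us to detect particular species.

For each peak detected in the scalogram trend, the trend, the raw waveform
and the spectrogram around it are plotted, and the grid is saved to test.png.
"""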
import readfile
import numpy
from pylab import plot, show, subplot, specgram, savefig, figure
from matplotlib.pyplot import *
from scipy.io.wavfile import read,write
from random import random

# Various parameters
nfft = 128  # size of the FFT window, in samples
overlap = 64  # number of samples shared by consecutive FFT windows
audio_file = 'ANTARES_66496.21_full_concatenated.wav'
dat_url = 'ANTARES_66496.21_03_09_2012_23.19.41_02.20.28_T8_Q20_J60_scalo1_L1.dat_Detector.dat'
threshold = 0.85  # relative margin above the mean for peak detection in the trend
window_signal = 0.02  # half-width of each plotted excerpt, in seconds

# Read the input audio and compute its spectrogram.
sampling_rate, data = read(audio_file)
Pxx, freqs, bins, im = specgram(data, NFFT=nfft, Fs=sampling_rate, noverlap=overlap)
(numBins, numSpectra) = Pxx.shape
Pxx = Pxx[::-1]  # flip the frequency axis to get the low freqs on the bottom

# Convert the raw power to decibels in a single vectorised pass rather than
# an element-by-element Python loop. Zero-power bins still become -inf, as
# with the per-element conversion.
Pxx = 10 * numpy.log10(Pxx)

# Read the detector output (.dat file) produced from the scalogram.
dat_file_raw = readfile.readdat(dat_url)
dat_file = numpy.array(dat_file_raw)

# TODO: 60-point moving average (a 5-point moving sum is used below; see
# smoothing_width).

# Conversions: express the excerpt half-width (window_signal seconds) in the
# index units of each representation. window_ratio is the fraction of the
# whole recording covered by window_signal seconds.
window_ratio = float(sampling_rate) * window_signal / len(data)
window_sampled = window_ratio * len(data)  # in audio samples
window_scattered = window_ratio * len(dat_file[0])  # in .dat columns (presumably scalogram frames)
window_spectro = window_ratio * len(Pxx[0])  # in spectrogram columns

# Binary detector
smoothing_width = 5  # length of the moving-sum smoothing kernel
trend = dat_file.sum(0)  # one value per column of the .dat matrix
smoothed_trend = numpy.convolve(trend, [1] * smoothing_width, mode='same')
mean = smoothed_trend.mean()
# 1 where the smoothed trend exceeds its mean by more than `threshold`
# (relative), 0 elsewhere; kept as a plain list of Python ints.
peaks = (smoothed_trend > mean * (1 + threshold)).astype(int).tolist()

# Plotting: for each of the first `nb` detections, draw three stacked panels
# (smoothed trend, raw waveform and spectrogram) centred on the detection.
nb = 30  # maximum number of detections to plot
min_gap = 10  # minimum distance, in trend samples, between two plotted detections


def _excerpt_bounds(center, half_width, length):
    """Return integer slice bounds (start, end) of a window around *center*.

    The window extends *half_width* on each side of *center* and is clipped
    to [0, length]. Because slices are end-exclusive, the window can reach
    the last element without running past the end of the array.
    """
    start = max(0, int(round(center - half_width)))
    end = min(length, int(round(center + half_width)))
    return start, end


fig = figure(figsize=(100, 10))
n_trend = len(smoothed_trend)
count = 1
memory = -(min_gap + 1)  # index of the last plotted detection
for i in range(n_trend):
    if count > nb:
        break
    # Skip non-detections and detections too close to the previous one.
    if peaks[i] == 0 or i - memory <= min_gap:
        continue

    # Row 1: smoothed trend around the detection.
    subplot(3, nb, count)
    start, end = _excerpt_bounds(i, window_scattered, n_trend)
    plot(smoothed_trend[start:end])

    # Row 2: raw audio around the matching sample. float() avoids Python 2
    # integer division when mapping the trend index onto the audio axis.
    subplot(3, nb, count + nb)
    center = float(i) * len(data) / n_trend
    start, end = _excerpt_bounds(center, window_sampled, len(data))
    plot(data[start:end])

    # Row 3: spectrogram columns around the matching time, bounded by the
    # spectrogram width (not by the audio length).
    subplot(3, nb, 2 * nb + count)
    center = float(i) * len(Pxx[0]) / n_trend
    start, end = _excerpt_bounds(center, window_spectro, len(Pxx[0]))
    imshow(Pxx[:, start:end])

    count += 1
    memory = i

# show()  # uncomment for interactive display instead of saving only
savefig('test.png')
