# -*- coding: utf-8 -*-
# Nicolas Enfon - 01/04/14 - LSIS DYNI
import readfile
import numpy
from pylab import plot, show, subplot, specgram, savefig, figure
from matplotlib.pyplot import *
from scipy.io.wavfile import read,write
from random import random
import cPickle

# ---- Parameters -------------------------------------------------------------
# Input recording and the .dat matrix computed from it.
audio_file = 'ANTARES_66496.21_full_concatenated.wav'
dat_url = 'ANTARES_66496.21_03_09_2012_23.19.41_02.20.28_T8_Q20_J60_scalo1_L1.dat'

# Peak-detection margins. A column of a band's smoothed trend is flagged as a
# peak when it exceeds the band's mean by more than this fraction
# (0.2 -> more than 20% above the mean).
threshold1 = 0.2  # band 1 (dat_file1)
threshold2 = 0.2  # band 2 (dat_file2)

# Reading of the input audio data: sample rate and raw samples.
sampling_rate, data = read(audio_file)

# Reading of the .dat file. It is parsed once and both row bands are sliced
# from the same array (it used to be parsed twice). readfile.readdat is a
# project helper; presumably it returns a 2-D row x column table -- TODO
# confirm. Nothing below writes into these arrays, so sharing one base
# array between the two bands is safe.
dat_file_raw1 = readfile.readdat(dat_url)
dat_file_raw2 = dat_file_raw1
dat_full = numpy.array(dat_file_raw1)

# Band 1.
# NOTE(review): [60:, :] keeps rows 60 and up. If the intent was "keep only
# the first 60 rows (low freq)", this should be [:60, :]. Confirm against
# the .dat row layout. As written it overlaps band 2 at row 60.
dat_file1 = dat_full[60:, :]

# Band 2: rows 35..60 inclusive.
dat_file2 = dat_full[35:61, :]

# Binary detector.
# For each band: sum the rows into one trend value per column. Smooth the
# trend with a 35-wide moving sum (convolution with a box of ones,
# mode='same'). Then record the column indices whose smoothed value exceeds
# the band's mean by more than the band's threshold fraction. Only the indices
# are kept; the gating step below maps them onto the audio samples.
trend1 = dat_file1.sum(0)
smoothed_trend1 = numpy.convolve(trend1, numpy.ones(35), mode='same')
mean1 = smoothed_trend1.mean()
peaks1 = numpy.flatnonzero(smoothed_trend1 > mean1 * (1 + threshold1)).tolist()

trend2 = dat_file2.sum(0)
smoothed_trend2 = numpy.convolve(trend2, numpy.ones(35), mode='same')
mean2 = smoothed_trend2.mean()
peaks2 = numpy.flatnonzero(smoothed_trend2 > mean2 * (1 + threshold2)).tolist()

# Scaling so that it matches data's size: map each band's peak column indices
# onto sample positions in `data`, then build a 0/1 gate of len(data) that is
# 1 in a window around every peak.
def _scale_and_gate(peak_cols, n_cols, n_samples, half_width):
    """Map peak column indices to sample indices and build a 0/1 gate.

    peak_cols  -- column indices of the detected peaks in a .dat band
    n_cols     -- number of columns in that band
    n_samples  -- length of the audio signal
    half_width -- gate half-width, in samples, around each peak

    Returns (peak_samples, gate):
      peak_samples -- float array of the rounded sample positions
      gate         -- float array of length n_samples, set to 1 on
                      [p - half_width, p + half_width) for every peak p,
                      clipped to [0, n_samples)
    """
    scaled = numpy.array(peak_cols) * n_samples / float(n_cols)
    # Python's round() per value; the array keeps a float dtype.
    peak_samples = numpy.array([int(round(p)) for p in scaled], dtype=float)
    gate = numpy.zeros(n_samples)
    for p in peak_samples:
        # Clamp at 0: a negative start would wrap around in a slice.
        start = max(int(p - half_width), 0)
        # Clamp to n_samples, not n_samples - 1: the slice end is exclusive,
        # so clamping to n_samples - 1 could never gate the final sample.
        stop = min(int(p + half_width), n_samples)
        gate[start:stop] = 1
    return peak_samples, gate


size1 = 1200  # gate half-width (samples) for band 1
peaks1, peaks1_gate = _scale_and_gate(peaks1, len(dat_file1[0]), len(data), size1)

size2 = 800  # gate half-width (samples) for band 2
peaks2, peaks2_gate = _scale_and_gate(peaks2, len(dat_file2[0]), len(data), size2)

# Computing the energies: per-sample energy (squared amplitude).
# Cast to float64 first. scipy.io.wavfile.read returns integer PCM (e.g.
# int16) for integer WAV files, and squaring in that dtype silently wraps
# around (300**2 = 90000 does not fit in int16), which corrupts the energies.
data_sq = data.astype(numpy.float64) ** 2

def snr(peaks_gate, log=True, signal_sq=None):
    """Return the ratio of mean energy inside a gate to mean energy outside it.

    The "signal" energy is the mean of `signal_sq` over the samples where
    `peaks_gate` equals 1. The "noise" energy is the mean over all other
    samples.

    Parameters:
      peaks_gate -- 1-D array of 0/1 values, same length as `signal_sq`
      log        -- if true (default), return 10 * log10(ratio), i.e. dB;
                    otherwise return the plain power ratio
      signal_sq  -- 1-D array of squared samples; defaults to the module-level
                    `data_sq`, so existing calls behave as before

    If the gate is all ones or all zeros, one mean is over an empty set and
    the result is nan (numpy emits a RuntimeWarning). A silent background
    gives inf.
    """
    if signal_sq is None:
        signal_sq = data_sq
    signal_sq = numpy.asarray(signal_sq, dtype=float)
    inside = numpy.asarray(peaks_gate) == 1
    # Mean energy inside the gate vs. over the complementary samples.
    energy_in = signal_sq[inside].mean()
    energy_out = signal_sq[~inside].mean()
    ratio = energy_in / energy_out
    if log:
        return 10 * numpy.log10(ratio)
    return ratio

logsnr1 = snr(peaks1_gate)
logsnr2 = snr(peaks2_gate)
snr1 = snr(peaks1_gate, log=False)
snr2 = snr(peaks2_gate, log=False)

print 'log1:', logsnr1
print 'log2:', logsnr2
print 'snr1:', snr1
print 'snr2:', snr2

file = open('peaks.pkl','w')
cPickle.dump(snr1, file)
cPickle.dump(snr2, file)
file.close()
#file.write(logsnr1)
#file.close()

#figure()
#subplot(3,1,1)
#plot(data)
#subplot(3,1,2)
#plot(peaks1_gate)
#subplot(3,1,3)
#plot(peaks2_gate)
#savefig('test_snr.png')
