# -*- coding: utf-8 -*- 
#Nicolas Enfon - LSIS DYNI - NortekMED
from pylab import *
from time import *
from dwt_1 import *
from scipy.signal import convolve2d, correlate2d, fftconvolve
import theano
from theano.tensor.nnet import conv

# Optional Theano-based 2-D convolution, built and compiled once at import time.
# It is an alternative to the scipy convolutions; detector() below uses scipy's
# fftconvolve rather than this function.
# Both operands are 4-D symbolic tensors, so a 2-D array has to be wrapped as
# [[array]] (presumably (batch, channel, rows, cols) -- confirm against the
# Theano conv2d documentation).
# NOTE(review): the module-level name `filter` shadows the Python builtin filter();
# it is left unchanged because other modules may import these globals.
image = theano.tensor.tensor4()
filter = theano.tensor.tensor4()
# border_mode='full' yields the full convolution, not a 'same'-sized output.
result = conv.conv2d(image, filter, border_mode='full')
theano_conv2d = theano.function([image, filter], result)


def _freq_rows(freq_range, n_rows):
    '''Converts a (start, stop) frequency range into row indices of the scalogram

        Each bound is used as a row number when it is an integer (Python or numpy);
        any other value is treated as a fraction of `n_rows`.

        @freq_range: (start, stop) pair of row numbers or fractions
        @n_rows: number of rows of the scalogram

        returns: a (start_row, stop_row) pair of ints'''
    import numbers

    def to_row(bound):
        if isinstance(bound, numbers.Integral):
            return int(bound)
        return int(bound * n_rows)

    return to_row(freq_range[0]), to_row(freq_range[1])

def detector(signal, pattern, wavelet='db2', level=6, normalized=False, freq_range=(20,40), detector_type='energy', smooth=False):
    '''Computes a DWT (Discrete Wavelet Transform) scalogram and a detector based on its cross-correlation with a given pattern

        @signal: numerized audio signal, passed unchanged to scalogram()
        @pattern: previously made 2-D pattern of whale vocalization
        @wavelet: PyWavelets wavelet type
        @level: level of decomposition. Level=6 gives 2 ^ 6 = 64 rows
        @normalized: normalization of the scalogram (currently not used by this function)
        @freq_range: (start, stop) rows of the correlation map to keep, i.e. where we look
                     for the given species; each bound is a row number if it is an integer,
                     otherwise a fraction of the number of rows
        @detector_type: 'energy' (sum of the selected rows), 'd_energy' (its first
                        difference) or 'd2_energy' (its second difference)
        @smooth: width of a moving-average window (convolution with ones(smooth) / smooth)
                 applied to the detector; False or 0 disables smoothing

        returns: (scalo, detector) -- the scalogram and the 1-D detector array

        raises: ValueError if detector_type is not one of the values listed above'''
    # Validate before the (expensive) scalogram/correlation work.
    if detector_type not in ('energy', 'd_energy', 'd2_energy'):
        raise ValueError("detector_type must be 'energy', 'd_energy' or 'd2_energy', got %r"
                         % (detector_type,))

    print('Computing scalogram ...')
    scalo = scalogram(signal, wavelet=wavelet, level=level)
    print('scalogram shape: %s' % (scalo.shape,))
    print('Computing correlation ...')
    # Flipping the kernel both up/down and left/right turns the convolution into a
    # cross-correlation with `pattern`. mode='same' keeps the output the size of `scalo`,
    # centred on the 'full' result.
    # Alternatives: scipy.signal.correlate2d(scalo, pattern) is much slower, and the
    # module-level theano_conv2d only offers 'valid'/'full' borders, so its output would
    # have to be cropped the same way to match.
    correlated = fftconvolve(scalo, fliplr(pattern[::-1]), mode='same')

    start_row, stop_row = _freq_rows(freq_range, correlated.shape[0])
    # One energy value per column: sum of the correlation over the selected rows.
    energy = correlated[start_row:stop_row, :].sum(0)

    if detector_type == 'energy':
        det = energy
    elif detector_type == 'd_energy':
        # Prepend a 0 to replace the value lost by differentiating, keeping the length.
        det = hstack((0, diff(energy)))
    else:  # 'd2_energy'
        det = hstack((0, 0, diff(energy, n=2)))

    # Smoothing with a moving average.
    if smooth:
        width = int(smooth)
        det = convolve(det, ones(width) / float(width), mode='same')

    return scalo, det

def booleanized(scalo, detector, threshold_type='mean', threshold_coeff=0.5, external=False, above=True):
    '''Transforms a detector array into a boolean array, with True where there is a detection and False otherwise

        @scalo: scalogram (not used by the computation; kept for interface compatibility)
        @detector: detector previously computed on the scalogram
        @threshold_type: 'mean' or 'median' -- the statistic of `detector` the relative
                         threshold is based on
        @threshold_coeff: relative margin; with threshold_coeff == 0.3 and above=True, values
                          more than 30% above the statistic trigger a detection (30% below
                          it when above=False)
        @external: absolute threshold, e.g. previously computed on a broader time window.
                   When truthy it is used as-is and threshold_type / threshold_coeff are
                   ignored; a falsy value (False, None, 0) means "compute from detector"
        @above: detect values > threshold if True, values < threshold otherwise

        returns: a booleanized detector (boolean array of the same shape as `detector`)

        raises: ValueError if no external threshold is given and threshold_type is unknown'''
    if not external:
        if threshold_type == 'mean':
            reference = detector.mean()
        elif threshold_type == 'median':
            reference = median(detector)
        else:
            raise ValueError("threshold_type must be 'mean' or 'median', got %r"
                             % (threshold_type,))
        # The margin scales the statistic itself: with a negative mean/median (e.g. a
        # differentiated detector) a factor > 1 moves the threshold further below zero.
        factor = (1 + threshold_coeff) if above else (1 - threshold_coeff)
        threshold = factor * reference
    else:
        threshold = external

    if above:
        return detector > threshold
    return detector < threshold

def decision(booleanized, threshold=0.2):
    '''Returns the decision for one window based on its per-sample detections

        @booleanized: booleanized detector (one boolean per sample)
        @threshold: fraction (0 to 1) of samples with a detection at or above which we
                    decide the window is actually positive

        returns: a decision (bool); an empty window is considered negative'''
    n_samples = len(booleanized)
    if n_samples == 0:
        # No samples means no evidence of a detection (and avoids dividing by zero).
        return False
    return bool(sum(booleanized) / float(n_samples) >= threshold)

def agregated(decisions, threshold=0.1):
    '''Returns the decision based on decisions made on several windows of signal

        @decisions: decisions made upon several windows of signal (sequence of booleans or
                    0/1 values; a single decision is also accepted)
        @threshold: fraction (0 to 1) of positive decisions at or above which we decide the
                    group is positive as a whole

        returns: a decision (bool); an empty group is considered negative'''
    # Cast to bool and flatten so a single decision (0-d) or nested input also works.
    decisions = array(decisions, dtype=bool).ravel()
    if decisions.size == 0:
        return False
    # Fraction of positive windows.
    return bool(decisions.mean() >= threshold)


if __name__ == '__main__':
    # Smoke test of the whole pipeline on one sample recording, with a random
    # 50x100 integer array standing in for a real vocalization pattern.
    print(asctime())
    tic = time()
    print('Testing ...')
    soundfile = '/NAS3/SABIOD/SITE/NORTEKMED/SAMPLE_SOUNDS/BOTTLENOSE/04._Bottlenose_dolphin.wav'
    # NOTE(review): detector() documents `signal` as a numerized audio signal, but a file
    # path is passed here; this only works if dwt_1.scalogram accepts paths -- confirm.
    scal, detec = detector(soundfile, randint(0, 100, (50, 100)))
    b = booleanized(scal, detec)
    d = decision(b)
    # agregated() expects a collection of per-window decisions, so wrap the single one.
    a = agregated([d])
    print('Window decision: %s, aggregated decision: %s' % (d, a))
    print('Done in  %s seconds (the count can be wrong if on utln servers)' % (time() - tic))


