# -*- coding: utf-8 -*-
# Nicolas Enfon - 01/04/14 - LSIS DYNI
import readfile
import numpy
from pylab import plot, show, subplot, specgram, savefig, figure
from matplotlib.pyplot import *
from scipy.io.wavfile import read,write
from random import random
from math import floor, ceil
import cPickle

#Various parameters
audio_file = 'zc06_204a22410-24210.wav'
dat_url = 'zc06_204a22410-24210.wav_T8_Q20_J60.dat'
log_url = 'zc06_204a22410-24210.Good.log'
maybe_url = 'zc06_204a22410-24210.Maybe.log'
other_url = 'zc06_204a22410-24210.Other.log'

#Reading of the input audio data
sampling_rate, data = read(audio_file)

#Makes the trend used by the detector
def make_trend(dat_file, startline=0, endline=None, conv_size=35):
    if endline is None:
        endline = len(dat_file)
    trend = dat_file[startline:endline,:].sum(0)
    smoothed_trend = numpy.convolve(trend, numpy.ones(conv_size), mode='same')
    return smoothed_trend, smoothed_trend.min(), smoothed_trend.max()
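
# Usage sketch (values borrowed from main() below, where bands 18..68 and a 51-bin
# window are used):
#   trend, lo, hi = make_trend(dat_file, startline=18, endline=68, conv_size=51)
# Note that numpy.ones(conv_size) is an unnormalized boxcar, so the smoothed trend
# is (up to edge effects) conv_size times a moving average; this only rescales the
# detection thresholds, not the shape of the curve.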

#Simple detector
def detector(smoothed_trend, threshold=0.25):
    ''' Detects the energy peaks above a threshold, and saves their position '''
    peaks = []
    for i in xrange(len(smoothed_trend)):
        if smoothed_trend[i] > threshold:
            peaks.append(i)
    return peaks
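
# Equivalent vectorized form (a sketch, same result as the loop above for a 1-D
# numpy array; the indices come back as numpy integers):
#   peaks = list(numpy.where(smoothed_trend > threshold)[0])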

#Scaled to the data size
def scaled_detector(len_data, len_dat_file, peaks, size=150):
    ''' Scales the peaks and makes 0/1 gates that match the audio data size.
        size is in bins, i.e. in seconds * sampling_rate(data) '''
    peaks = numpy.array(peaks)
    peaks = peaks * len_data / float(len_dat_file)#CAUTION: len_dat_file = len(dat_file[0])
    for i in xrange(len(peaks)):
        peaks[i] = int( round( peaks[i] ) )#caution: the numpy array keeps float dtype
    peaks_gate = numpy.zeros( len_data )
    for i in xrange( len( peaks ) ):
        deb = int( peaks[i] - size )
        if deb < 0:
            deb = 0
        fin = int( peaks[i] + size )
        if fin > len_data:
            fin = len_data
        peaks_gate[ deb : fin ] = 1
    return peaks_gate
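
# Toy example (hypothetical numbers): with len_data=10, len_dat_file=5, peaks=[2]
# and size=1, peak index 2 maps to audio sample 4, and the gate is set to 1 on
# samples 3..4 (slice [3:5]):
#   scaled_detector(10, 5, [2], size=1)  ->  [0,0,0,1,1,0,0,0,0,0]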

#Scale the given labels (start/end times in seconds) into a 0/1 gate over the audio samples
def scalelog2(log, data, sampling_rate, size=10):
    log = log * sampling_rate
    for i in xrange(len(log)):
        log[i][0] = int( round( log[i][0] ) )
        log[i][1] = int( round( log[i][1] ) )
    log_gate = numpy.zeros( len( data ) )
    for i in xrange( len( log ) ):
        deb = int( log[i][0] - size )
        if deb < 0:
            deb = 0
        fin = int( log[i][1] + size )
        if fin > len( data ):#TODO: change >= to > on visual4
            fin = len( data )
        if deb > fin:
            print
            print 'Error: deb > fin!'
            print
        log_gate[ deb : fin ] = 1
    return log_gate
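
# Toy example (hypothetical numbers, assuming a two-column start/end log in seconds):
# with sampling_rate=4, size=1 and a single label [2.0, 3.0], the scaled bounds are
# 8 and 12, so deb=7, fin=13 and the gate is 1 on samples 7..12 of the recording.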

#Making the gates
log = numpy.loadtxt(log_url, skiprows=1)
log_gate = scalelog2(log, data, sampling_rate, size=850)
maybe = numpy.loadtxt(maybe_url, skiprows=1)
maybe_gate = scalelog2(maybe, data, sampling_rate, size=850)
other = numpy.loadtxt(other_url, skiprows=1)
other_gate = scalelog2(other, data, sampling_rate, size=850)
#Reading the .dat
dat_file = numpy.loadtxt(dat_url, delimiter=',')

#We cut the data to speed up the computation
cut = 0.004#fraction of the recording that is kept (0.4%)
cut_len = int( floor( cut * len(log_gate) ) )
log_gate_cut = log_gate[:cut_len].astype(int)
maybe_gate_cut = maybe_gate[:cut_len].astype(int)
other_gate_cut = other_gate[:cut_len].astype(int)
all_cut = log_gate_cut + maybe_gate_cut + other_gate_cut
#Mixing the 3 logs into a single 0/1 gate (overlapping labels are clipped back to 1
#so that the bitwise operations below stay valid)
log_gate_cut = numpy.clip(all_cut, 0, 1)

dat_file_cut = dat_file[:, :int( ceil( cut * len(dat_file[0]) ) )]
data_cut = data[:cut_len]

def main(startline, endline):

    smoothed_trend_cut, trend_min, trend_max = make_trend(dat_file_cut, startline, endline, conv_size=51)
    #Trend over the "avoid" band slice (rows -6 to 88), used to veto detections
    smoothed_trend_avoid, avoid_min, avoid_max = make_trend(dat_file_cut, -6, 88, conv_size=51)
    #About 30 evenly spaced thresholds spanning the trend's dynamic range
    sample = numpy.arange(trend_min, trend_max, (trend_max - trend_min) / 30)
    TPR = []#True Positive Rates
    FPR = []#False Positive Rates
    for threshold in sample:
        print 'threshold: ', threshold
        peaks = detector(smoothed_trend_cut, threshold)
        peaks_gate = scaled_detector( len(data), len(dat_file[0]), peaks, size=68)
        peaks_gate_cut = peaks_gate[:cut_len].astype(int)

        peaks_avoid = detector(smoothed_trend_avoid, 3.5)
        peaks_avoid_gate = scaled_detector( len(data), len(dat_file[0]), peaks_avoid, size=68)
        peaks_avoid_cut = peaks_avoid_gate[:cut_len].astype(int)

        #Keep the bins above the detection threshold that are NOT flagged by the avoid detector
        peaks_gate_cut = peaks_gate_cut & ( peaks_avoid_cut ^ numpy.ones(cut_len).astype(int) )
	#True Positive
#	tp = list( numpy.diff( peaks_gate_cut * log_gate_cut ) ).count(1)
#	if (peaks_gate_cut * log_gate_cut)[0] == 1:
#	    tp += 1
#	print 'tp: ', tp
 	#False Negative
#	fn_list = list(   ( peaks_gate_cut ^ numpy.ones(cut_len).astype(int) ) & log_gate_cut   )
#	fn = list(  numpy.diff( fn_list )  ).count(1)
#	if fn_list[0] == 1:
#	    fn += 1
#	print 'fn: ', fn
	#True Positive Rate
#	tpr = float(tp) / (tp + fn)
#	print 'tpr: ', tpr
#	TPR.append(tpr)


	#False Positive
	#fp = 0
	#i = 0
	#while i < len(peaks_gate_cut):
	#    if peaks_gate_cut[i] == 0:
#		i += 1
#	    else:
#		absence = 1
#		while peaks_gate_cut[i] == 1:
#		    if log_gate[i] == 1:
#			absence = 0
#		    i +=1
#		fp += absence
	#print 'fp: ', fp
	#True negative - this ones fits well with the bin definition
	#tn_list = list(  ( peaks_gate_cut ^ numpy.ones(cut_len).astype(int) ) & ( log_gate_cut ^ numpy.ones(cut_len).astype(int) ) )
	#tn = tn_list.count(1)
	#print 'tn: ', tn	
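
        # How the bin-wise confusion counts below are obtained with bitwise operations
        # on the 0/1 gates (peaks_gate_cut = detector output, log_gate_cut = labels):
        #   TP bins: detector & label
        #   FP bins: (detector ^ label) & detector        (detector = 1, label = 0)
        #   FN bins: (detector ^ label) & ~detector       (detector = 0, label = 1)
        #   TN bins: ~detector & ~label
        # ~x is emulated as x ^ numpy.ones(cut_len).astype(int).
        # Then TPR = TP / (TP + FN) and FPR = FP / (FP + TN).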

        #False Positive Rate
        fp_list = list( ( peaks_gate_cut ^ log_gate_cut ) & peaks_gate_cut )
        tn_list = list( ( peaks_gate_cut ^ numpy.ones(cut_len).astype(int) ) & ( log_gate_cut ^ numpy.ones(cut_len).astype(int) ) )
        fpr = float( fp_list.count(1) ) / (fp_list.count(1) + tn_list.count(1))
        print 'fpr: ', fpr
        FPR.append(fpr)

        #True Positive Rate from bin-wise TP and FN counts
        tp_list = list(peaks_gate_cut & log_gate_cut)
        fn_list = list( (peaks_gate_cut ^ log_gate_cut) & (peaks_gate_cut ^ numpy.ones(cut_len).astype(int)) )
        tpr = float(tp_list.count(1)) / ( tp_list.count(1) + fn_list.count(1) )
        print 'tpr: ', tpr
        TPR.append(tpr)


    return TPR, FPR
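
# Usage sketch: main(18, 68) sweeps the ~30 thresholds over scalogram bands 18..68
# and returns one (TPR, FPR) point per threshold; the loop below currently runs it
# for that single band range only.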

f = open('roc_data6.pkl','wb')
TPRglob = []
FPRglob = []
for i in numpy.arange(0,1,1):
    print 'bands ', 18, 'to', 68
    TPR, FPR = main(18,68)
    TPRglob.append(TPR)
    FPRglob.append(FPR)
cPickle.dump(TPRglob, f)
cPickle.dump(FPRglob, f)
f.close()
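
#The two cPickle.dump calls above must be matched by two cPickle.load calls in the
#same order; that is what the plotting block below relies on.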

#plotting
f = open('roc_data6.pkl','rb')
TPRg = cPickle.load(f)
FPRg = cPickle.load(f)
#Add the ROC endpoints (0,0) and (1,1), then reverse so that FPR is increasing
for i in xrange(len(TPRg)):
    TPRg[i].append(0)
    FPRg[i].append(0)
    TPRg[i].insert(0,1)
    FPRg[i].insert(0,1)
for i in xrange(len(TPRg)):
    TPRg[i].reverse()
    FPRg[i].reverse()
#AUC by the trapezoidal rule
AUC = []
for j in xrange(len(TPRg)):
    auc = 0
    for i in xrange(len(TPRg[j])-1):
        auc += ( FPRg[j][i+1] - FPRg[j][i] ) * ( TPRg[j][i+1] + TPRg[j][i] ) / 2.
    print 'auc (trapezoidal rule): ', auc
    AUC.append(auc)
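
#Equivalent check (a sketch): once the points are ordered by increasing FPR, the
#same trapezoidal area is given by numpy.trapz(TPRg[j], FPRg[j]) inside the loop.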

figure(figsize=(10,10))
for i in xrange(len(TPRg)):
    subplot(1,len(TPRg),i+1)
    plot(FPRg[i],TPRg[i])
    plot(numpy.arange(0,1,0.01),numpy.arange(0,1,0.01),'r:')#chance diagonal
    title('ROC with scalo bands '+str(18)+' to '+str(68)+' -- AUC ='+str(AUC[i]))
savefig('roc_tpr_bins_fpr_bins.18s.without_last_6.png')

f.close()
