import argparse
import sys
import glob
import numpy as np
import pandas as pd
from tqdm import tqdm
import librosa as lr
import torch
import git

import matplotlib.pyplot as plt
import soundfile as sf
import scipy.signal as sg
import os

from models import DetectorErbs


def norm(x, eps=1e-10, axis=None):
    """Standardize ``x``: subtract its mean and divide by its standard deviation.

    Statistics are computed over ``axis`` (over all elements when ``axis`` is
    None) with dimensions kept, so they broadcast back against ``x``. ``eps``
    is added to the standard deviation so constant input does not divide by 0.
    """
    mu = x.mean(axis, keepdims=True)
    sigma = x.std(axis, keepdims=True)
    return (x - mu) / (sigma + eps)

SR = 96000  # Target sampling rate (Hz): windows are resampled to it and the high-pass filter is designed for it.

class AudioSignalForward:
    """Map-style dataset yielding magnitude spectrograms of fixed-length audio windows.

    At construction, every file matched by ``audio_folder`` is inspected with
    ``sf.info``. Each file that has more than ``target_channel`` channels is cut
    into consecutive, non-overlapping ``sig_dur``-second windows; only complete
    windows are kept, including one that ends exactly at the end of the file.
    Each item is one window of channel ``target_channel``. It is resampled to
    ``SR`` when needed, optionally high-pass filtered, standardized with
    ``norm``, and turned into a magnitude STFT.

    Args:
        audio_folder: glob pattern (passed to ``glob.glob``) selecting the files.
        target_channel: index of the channel to extract.
        hp: if True, apply a 6th-order Butterworth high-pass at 2500 Hz
            (designed for ``SR``) to every window.
        sig_dur: window length in seconds; also the step between windows.
    """

    def __init__(self, audio_folder, target_channel, hp=False, sig_dur=5.):
        super().__init__()

        self.samples = []
        self.hp = hp
        # Designed for SR, so it is only valid on signals already resampled to SR.
        self.hp_sos = sg.butter(6, 2500, 'hp', output='sos', fs=SR)
        self.sig_dur = sig_dur
        self.target_channel = target_channel

        self.nfft = 4096
        self.window = sg.windows.hann(self.nfft)

        for file in tqdm(sorted(glob.glob(audio_folder)), desc='Dataset initialization'):
            try:
                info = sf.info(file)
            except Exception as e:
                # Best effort: skip unreadable or non-audio matches, but report them.
                tqdm.write(f'Skipping {file}: {e}')
                continue

            if info.channels <= self.target_channel:
                continue
            # Number of complete windows. A window ending exactly at the end of
            # the file is included.
            n_windows = int(info.duration // self.sig_dur)
            self.samples.extend(
                {'file': file, 'offset': offset, 'sr': info.samplerate}
                for offset in np.arange(n_windows) * float(self.sig_dur)
            )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(spectrogram, sample)`` for window ``idx``, or None if reading fails.

        ``spectrogram`` is the float32 magnitude STFT with a leading axis of
        size 1 added. ``sample`` is the window descriptor dict
        (``file``, ``offset``, ``sr``).
        """
        sample = self.samples[idx]
        start = int(sample['offset'] * sample['sr'])
        stop = int((sample['offset'] + self.sig_dur) * sample['sr'])
        try:
            sig, sr = sf.read(sample['file'], start=start, stop=stop, always_2d=True)
        except Exception as e:
            print(e)
            print('Failed loading ' + sample['file'])
            # NOTE: torch's default DataLoader collate cannot batch None;
            # callers need a collate_fn that drops these items.
            return None

        sig = sig[:, self.target_channel]
        if sr != SR:
            sig = sg.resample(sig, int(len(sig) * SR / sr))
        if self.hp:
            # NOTE(review): zero-phase filtering, applied before standardization.
            # Confirm this matches the preprocessing used to train the detector.
            sig = sg.sosfiltfilt(self.hp_sos, sig)

        sig = norm(sig)
        # hop_length=938 is NOT nfft // 4 (= 1024). With center=True it gives
        # 1 + 480000 // 938 = 512 frames for a 5 s window at 96 kHz.
        tf_sig = np.abs(lr.stft(
            sig,
            n_fft=self.nfft,
            hop_length=938,
            window=self.window,
            center=True,
        ))
        return tf_sig.astype(np.float32)[None], sample


# Glob of the recordings processed by main(); kept as a module constant so the
# path is easy to find and change.
AUDIO_GLOB = "/nfs/NAS6/SABIOD/SITE/CARIMAM/DATA/LOT2/JAM_20210406_20320510/*0429*"


def _current_git_version():
    """Return ``"<branch>_<sha>"`` for the git repository that contains this file."""
    repo = git.Repo(
        os.path.dirname(os.path.abspath(__file__)),
        search_parent_directories=True,
    )
    return repo.head.ref.name + '_' + repo.head.object.hexsha


def _collate_skip_none(batch):
    """Collate a batch after dropping the items the dataset returned as None.

    Returns None when every item of the batch was None.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return torch.utils.data.dataloader.default_collate(batch)


def _pickle_path(output_file):
    """Return the pickle path paired with ``output_file``: same stem, ``.pkl`` suffix.

    Never returns ``output_file`` itself, so the pickle cannot overwrite the CSV.
    """
    stem, ext = os.path.splitext(output_file)
    return output_file + '.pkl' if ext == '.pkl' else stem + '.pkl'


def main(args):
    """Run ``DetectorErbs`` on every window of the recordings matched by ``AUDIO_GLOB``.

    Reads ``args.model_name``, ``args.device``, ``args.bs``, ``args.nb_workers``
    and ``args.output_file``. Results are written to ``args.output_file`` as CSV.
    They are also written as a pickle next to it; the pickle keeps the
    per-window probability arrays as arrays.

    Returns:
        0 on success. Returns 1 if no audio window is found, if the checkpoint
        was produced by a different git branch/commit than the current code,
        or if the weights do not fit the model.
    """
    dst = AudioSignalForward(AUDIO_GLOB, target_channel=0, hp=True)
    if len(dst) == 0:
        print(f"No audio window found for {AUDIO_GLOB}", file=sys.stderr)
        return 1

    model = DetectorErbs()

    # Load the checkpoint once, onto the CPU. This also works on machines
    # without the GPU it was saved from; the model is moved to `device` later.
    checkpoint = torch.load(args.model_name, map_location='cpu')

    current_version = _current_git_version()
    model_version = checkpoint["git_info"]
    # Explicit check: `assert(cond, msg)` asserts a non-empty tuple and never fails.
    if current_version != model_version:
        print("model and code version are not the same: "
              f"model={model_version}, code={current_version}", file=sys.stderr)
        return 1

    try:
        model.load_state_dict(checkpoint["model"])
    except RuntimeError as e:
        print("Error Load detector weights", file=sys.stderr)
        print(e, file=sys.stderr)
        return 1
    print("load model weight succesfully")

    model.eval()
    device = 'cuda:' + args.device if torch.cuda.is_available() else 'cpu'
    model.to(device)

    # prefetch_factor may only be passed when loading with worker processes.
    loader_kwargs = {'prefetch_factor': 4} if args.nb_workers > 0 else {}
    dl = torch.utils.data.DataLoader(
        dst,
        batch_size=args.bs,
        num_workers=args.nb_workers,
        collate_fn=_collate_skip_none,
        **loader_kwargs)

    files, offsets, preds, contexts = [], [], [], []
    with torch.no_grad():
        for batch in tqdm(dl, desc='Model inference'):
            if batch is None:  # every window of this batch failed to load
                continue
            sig, meta = batch
            batch_preds, batch_context = model(sig.to(device))
            preds.extend(torch.nn.functional.softmax(batch_preds, dim=1).cpu().numpy())
            contexts.extend(torch.nn.functional.softmax(batch_context, dim=1).cpu().numpy())
            files.extend(meta['file'])
            offsets.extend(meta['offset'].numpy())

    df_res = pd.DataFrame({
        'filename': files,
        'offset': offsets,
        'prediction': preds,
        'contexts': contexts,
    })
    # A window is flagged positive when its class-1 probability exceeds 0.5.
    df_res['pos'] = df_res['prediction'].apply(lambda x: x[1] > 0.5)

    output_file = args.output_file
    print(f"Saving results into {output_file}")
    df_res.to_csv(output_file, index=False)
    df_res.to_pickle(_pickle_path(output_file))

    return 0

if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Boolean flag, not a frequency. main() does not read it yet.
    parser.add_argument("--hp", action='store_true',
                        help="Apply a high-pass filter (currently ignored: main() always uses hp=True)")
    parser.add_argument("--device", type=str, default='0', help='Gpu device')
    parser.add_argument("--model_name", type=str, default='/nfs/NAS6/mahe/src/click_detector_erbs/runs/cstm_name=VERIF_V10:hp=True:lr=0.000500:ne=30:wd=0.050000:bs=128:aug=False/ckpt_000969.pth', help="name of the model")
    parser.add_argument("--nb_workers", type=int, default=12,
                        help="Number of workers")
    parser.add_argument("--bs", type=int, default=32, help="Batch size")
    parser.add_argument("--output_file", type=str, default='output_file.csv',
                        help="name of the output file")

    # NOTE: this exits before analyse_results (defined further down) is created;
    # the command-line entry point does not use it.
    sys.exit(main(parser.parse_args()))




def analyse_results(df, fpr_estim=0.02, output_dir=''):
    """Plot the class-1 probability over time for files with many positive windows.

    For each ``filename`` group of ``df``, the windows flagged ``pos`` are
    counted. If the count reaches ``ceil(fpr_estim * n_windows)``, column 1 of
    the stacked ``prediction`` arrays is plotted and saved as
    ``<file stem>.pdf`` in ``output_dir``.

    Args:
        df: results frame with ``filename``, ``prediction`` and ``pos`` columns.
            ``prediction`` must hold array-likes, as in the pickle written by
            ``main``; the CSV stores them as strings.
        fpr_estim: fraction of a file's windows that must be positive for it
            to be plotted.
        output_dir: directory for the PDFs. The default ``''`` writes to the
            current directory.

    Returns:
        List of the PDF paths written.
    """
    saved = []
    for filename, grp in tqdm(df.groupby('filename')):
        thd_pos = int(np.ceil(fpr_estim * len(grp)))
        if grp['pos'].sum() < thd_pos:
            continue
        pred_summary = np.vstack(grp['prediction'])
        fig = plt.figure()
        plt.plot(pred_summary[:, 1])
        plt.title(filename)
        # splitext handles any extension case (.WAV, .wav, ...); the previous
        # .replace(".WAV", ".pdf") left lowercase names with a non-PDF suffix.
        stem = os.path.splitext(os.path.basename(filename))[0]
        out_path = os.path.join(output_dir, stem + '.pdf')
        fig.savefig(out_path)
        plt.close(fig)  # otherwise one open figure accumulates per plotted file
        saved.append(out_path)
    return saved
