import requests
import os
import librosa
import numpy as np
import pandas as pd
import time
from scipy.signal import butter, lfilter
import xml.etree.ElementTree as ET
import soundfile as sf
import subprocess
import matplotlib.pyplot as plt
import shutil

def get_mp3_sample_rates(mp3_file):
    """Collect the sample rates announced by MPEG audio frame headers in a file.

    The whole file is scanned for byte patterns that look like an MPEG audio
    frame header (11-bit sync word followed by valid version, layer, bitrate
    and sample-rate fields). Frame lengths are not computed, so bytes inside
    tags or audio payload that happen to look like a header are counted too;
    the result is a heuristic, not an exact stream description.

    Args:
        mp3_file: Path of the file to inspect.

    Returns:
        set[int]: Every sample rate (Hz) found in a plausible header. Empty
        when no plausible header exists.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Sample-rate tables indexed by the 2-bit version ID of the header.
    # Version ID 0b01 is reserved and therefore deliberately absent.
    SAMPLE_RATES = {
        0b11: (44100, 48000, 32000),  # MPEG1
        0b10: (22050, 24000, 16000),  # MPEG2
        0b00: (11025, 12000, 8000),   # MPEG2.5
    }
    HEADER_SIZE = 4

    sample_rates_found = set()

    with open(mp3_file, 'rb') as f:
        data = f.read()

    # Last offset at which a complete 4-byte header still fits in the file.
    last_start = len(data) - HEADER_SIZE

    # Jump from one 0xFF byte to the next instead of testing every byte.
    i = data.find(b'\xff')
    while 0 <= i <= last_start:
        byte1 = data[i + 1]
        byte2 = data[i + 2]

        rates = SAMPLE_RATES.get((byte1 >> 3) & 0x03)
        layer_bits = (byte1 >> 1) & 0x03
        bitrate_index = (byte2 >> 4) & 0x0F
        sample_rate_index = (byte2 >> 2) & 0x03

        is_header = (
            (byte1 & 0xE0) == 0xE0             # remaining sync bits
            and rates is not None              # version ID not reserved
            and layer_bits != 0b00             # layer not reserved
            and bitrate_index != 0b1111        # bitrate index not invalid
            and sample_rate_index != 0b11      # sample rate not reserved
        )

        if is_header:
            sample_rates_found.add(rates[sample_rate_index])
            # Skip the header itself; the frame length is not computed.
            i += HEADER_SIZE
        else:
            # False sync: resume right after this 0xFF byte.
            i += 1

        i = data.find(b'\xff', i)

    return sample_rates_found

def reencode_mp3(input_file, target_rate):
    """Re-encode an audio file in place as mono at ``target_rate`` using ffmpeg.

    ffmpeg writes to ``temp_<name>`` in the same directory; the original is
    replaced only after ffmpeg exits successfully. On any ffmpeg failure
    (non-zero exit status, or the executable cannot be launched) an error is
    printed and ``input_file`` is left untouched.

    Args:
        input_file: Path of the file to re-encode (overwritten on success).
        target_rate: Output sample rate in Hz, passed to ffmpeg's ``-ar``.

    Returns:
        None.
    """
    # The temporary file lives next to the input so the final replace stays
    # on the same filesystem.
    dir_name = os.path.dirname(input_file)
    base_name = os.path.basename(input_file)
    temp_file = os.path.join(dir_name, "temp_" + base_name)

    try:
        subprocess.run([
            "ffmpeg", "-y", "-i", input_file,
            "-ar", str(target_rate), "-ac", "1",
            temp_file
        ], check=True)
    except subprocess.CalledProcessError as e:
        print(f"❌ Erreur FFmpeg : {e}")
    except OSError as e:
        # ffmpeg is missing or could not be started.
        print(f"❌ Impossible de lancer FFmpeg : {e}")
    else:
        # Single-step overwrite: the original is never deleted before the
        # new file is in place.
        os.replace(temp_file, input_file)
        print(f"✅ Réencodage terminé : {input_file}")
    finally:
        # Remove any leftover temporary output, whatever went wrong.
        if os.path.exists(temp_file):
            os.remove(temp_file)
            
def bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Apply a Butterworth band-pass filter to a signal.

    Args:
        data: Samples to filter (array-like).
        lowcut: Lower band edge in Hz.
        highcut: Upper band edge in Hz. When it reaches or exceeds the
            Nyquist frequency (``fs / 2``) it is clamped to 99 % of Nyquist
            so low-sample-rate recordings can still be filtered.
        fs: Sampling rate of ``data`` in Hz.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered signal, same length as ``data``.

    Raises:
        ValueError: If the band is empty after clamping, i.e. ``lowcut`` is
            not strictly between 0 and the effective upper edge.
    """
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq

    # butter() requires normalized edges strictly inside (0, 1).
    if high >= 1.0:
        high = 0.99
    if not 0 < low < high:
        raise ValueError(
            f"Invalid band {lowcut}-{highcut} Hz for sampling rate {fs} Hz"
        )

    b, a = butter(order, [low, high], btype='band')
    return lfilter(b, a, data)

# Root folder where the files are downloaded (one sub-folder per species).
PATH = r"Insert/Your/Path"

# One row per target species: (French common name, scientific name,
# four-letter folder/file prefix). The three public lists below are derived
# from this table so they always stay aligned by index.
_BIRD_TABLE = (
    ("Merle noir", "Turdus merula", "TuMe"),
    ("Rougegorge familier", "Erithacus rubecula", "ErRu"),
    ("Mésange charbonnière", "Parus major", "PaMa"),
    ("Mésange bleue", "Cyanistes caeruleus", "CyCa"),
    ("Pinson des arbres", "Fringilla coelebs", "FrCo"),
    ("Chardonneret élégant", "Carduelis carduelis", "CaCa"),
    ("Moineau domestique", "Passer domesticus", "PaDo"),
    ("Pie bavarde", "Pica pica", "PiPi"),
    ("Corneille noire", "Corvus corone", "CoCo"),
    ("Pigeon ramier", "Columba palumbus", "CoPa"),
    ("Tourterelle turque", "Streptopelia decaocto", "StDe"),
    ("Hirondelle rustique", "Hirundo rustica", "HiRu"),
    ("Grive musicienne", "Turdus philomelos", "TuPh"),
    ("Troglodyte mignon", "Troglodytes troglodytes", "TrTr"),
    ("Fauvette à tête noire", "Sylvia atricapilla", "SyAt"),
    ("Bergeronnette grise", "Motacilla alba", "MoAl"),
    ("Alouette des champs", "Alauda arvensis", "AlAr"),
    ("Étourneau sansonnet", "Sturnus vulgaris", "StVu"),
    ("Rossignol philomèle", "Luscinia megarhynchos", "LuMe"),
    ("Roitelet huppé", "Regulus regulus", "ReRe"),
)

french_name = [row[0] for row in _BIRD_TABLE]
species = [row[1] for row in _BIRD_TABLE]
name_dir = [row[2] for row in _BIRD_TABLE]

# Detected-call duration to accumulate per species, in seconds (45 minutes).
target_duration = 60 * 45
# Running total of detected-call seconds for the species being processed.
duration_cumulative = 0

# ---------------------------------------------------------------------------
# Download recordings species by species and write one annotation XML per
# recording, until enough detected-call time has been accumulated.
# ---------------------------------------------------------------------------
from xml.dom import minidom

frame_length = 2048    # samples per energy frame and per STFT window
hop_length = 512       # samples between two successive frames
REQUEST_TIMEOUT = 30   # seconds before an HTTP request is abandoned


def _find_active_intervals(signal, sr, frame_length, hop_length):
    """Return ``(starts, ends)`` in seconds for runs of high-energy frames.

    A frame counts as active when its energy in dB exceeds the mean plus half
    a standard deviation of all frame energies of the recording.
    """
    energy = np.array([
        np.sum(np.abs(signal[pos:pos + frame_length] ** 2))
        for pos in range(0, len(signal), hop_length)])
    energy_db = 10 * np.log10(energy + 1e-10)  # offset avoids log(0)
    threshold = np.mean(energy_db) + 0.5 * np.std(energy_db)

    starts, ends = [], []
    in_segment = False
    for idx_frame, active in enumerate(energy_db > threshold):
        if active and not in_segment:
            in_segment = True
            starts.append(idx_frame * hop_length / sr)
        elif not active and in_segment:
            in_segment = False
            ends.append(idx_frame * hop_length / sr)
    if in_segment:
        # The recording ends while still active: close at the last sample.
        ends.append(len(signal) / sr)
    return starts, ends


def _merge_intervals(starts, ends, max_gap=0.5, min_length=0.3):
    """Merge intervals separated by at most ``max_gap`` s, then drop those
    shorter than ``min_length`` s. Returns a list of ``(start, end)``."""
    merged = []
    for s, e in zip(starts, ends):
        if merged and s - merged[-1][1] <= max_gap:
            merged[-1][1] = e
        else:
            merged.append([s, e])
    return [(s, e) for s, e in merged if (e - s) >= min_length]


def _segment_stats(start, end, sr, hop_length, S, freqs, y_filtered):
    """Return ``(freq_min, freq_max, amp_min, amp_max)`` for one segment.

    The segment is extended by 0.2 s at its end. Frequencies are those whose
    mean magnitude (dB) over the segment exceeds mean + 0.5 * std across
    frequency bins; each value is ``None`` when it cannot be determined.
    """
    start_sample = int(start * sr)
    end_sample = int((end + 0.2) * sr)

    # Convert to spectrogram columns and clamp to the available range.
    start_frame = max(0, min(int(start_sample / hop_length), S.shape[1] - 1))
    end_frame = max(0, min(int(end_sample / hop_length), S.shape[1]))
    segment_spec = S[:, start_frame:end_frame]

    # Per-frequency mean over time, then a dynamic threshold across bins.
    freq_energy_db = 10 * np.log10(np.mean(segment_spec, axis=1) + 1e-10)
    freq_threshold = np.mean(freq_energy_db) + 0.5 * np.std(freq_energy_db)
    active_freqs = freqs[freq_energy_db > freq_threshold]

    freq_min = active_freqs.min() if len(active_freqs) > 0 else None
    freq_max = active_freqs.max() if len(active_freqs) > 0 else None

    if end_sample > start_sample:
        chunk = y_filtered[start_sample:end_sample]
        amp_min, amp_max = chunk.min(), chunk.max()
    else:
        amp_min = amp_max = None
    return freq_min, freq_max, amp_min, amp_max


# NOTE(review): the loop starts at index 2, so the first two species are
# skipped, and labelMap.xml is copied from the folder of species 0, which
# must already exist. Looks like a resume point; confirm before changing.
for idx in range(2, len(species), 1):

    save_dir = os.path.join(PATH, name_dir[idx])
    os.makedirs(save_dir, exist_ok=True)

    # Quality "A" recordings only.
    query = species[idx].replace(" ", "+") + "+q:A"
    api_url = f"https://xeno-canto.org/api/2/recordings?query={query}"

    response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()  # fail with the HTTP error, not a JSON one
    data = response.json()

    print(f"Nombre d'enregistrements trouvés pour {species[idx]} : {data['numRecordings']}")
    records_metadata = []  # currently unused; kept for compatibility

    duration_cumulative = 0
    # Copy the label map next to the recordings of this species.
    shutil.copy(os.path.join(PATH, name_dir[0], "labelMap.xml"), os.path.join(save_dir, "labelMap.xml"))
    page = 1
    # Running index across pages, so a later page never overwrites the
    # files of an earlier one.
    file_index = 0
    while duration_cumulative < target_duration:
        print(f"Page {page}")
        for rec in data['recordings']:
            audio_url = rec['file'] if rec['file'].startswith('http') else f"https:{rec['file']}"
            filename = os.path.join(save_dir, f"{name_dir[idx]}_{file_index}.mp3")
            file_index += 1

            # Download; a failed download skips this recording instead of
            # aborting the run or saving an error page as audio.
            try:
                audio_data = requests.get(audio_url, timeout=REQUEST_TIMEOUT)
                audio_data.raise_for_status()
            except requests.RequestException as e:
                print(f"Erreur de téléchargement ({audio_url}): {e}")
                continue
            with open(filename, 'wb') as f:
                f.write(audio_data.content)
            print(f"Téléchargé: {filename}")

            # Re-encode when the header scan reports several sample rates.
            sample_rates = get_mp3_sample_rates(filename)
            if len(sample_rates) > 1:
                max_rate = max(sample_rates)
                print(f"⚠️ Plusieurs sample rates. Réencodage en {max_rate} Hz...")
                reencode_mp3(filename, max_rate)

            try:
                y, sr = librosa.load(filename, sr=None, mono=True)
                duration_audio = librosa.get_duration(y=y, sr=sr)

                # Band-pass, then energy-based activity detection.
                y_filtered = bandpass_filter(y, 2000, 8000, sr)
                starts, ends = _find_active_intervals(
                    y_filtered, sr, frame_length, hop_length)
                segments = _merge_intervals(starts, ends)

                # Magnitude spectrogram used for per-segment frequency bounds.
                S = np.abs(librosa.stft(y_filtered, n_fft=frame_length, hop_length=hop_length))
                freqs = librosa.fft_frequencies(sr=sr, n_fft=frame_length)

                root = ET.Element("Annotations")

                for start, end in segments:
                    if end <= start:
                        print(f"⚠️ Segment vide ou invalide : start={start}, end={end}")
                        continue
                    duration_cumulative += (end - start)

                    freq_min, freq_max, amp_min, amp_max = _segment_stats(
                        start, end, sr, hop_length, S, freqs, y_filtered)

                    annotation = ET.SubElement(root, "Annotation")
                    # str(None) yields "None" for values that are unavailable.
                    for tag, value in (
                        ("Start_s", start),
                        ("End_s", end),
                        ("FqMin_Hz", freq_min),
                        ("FqMax_Hz", freq_max),
                        ("AmplMin_V", amp_min),
                        ("AmplMax_V", amp_max),
                        ("Id", idx),
                    ):
                        ET.SubElement(annotation, tag).text = str(value)

                print(f"Durée fichier: {duration_audio:.2f}s - Durée cumulée: {duration_cumulative/60:.2f} min")

                # Write the annotations next to the audio, pretty-printed.
                xml_filename = os.path.splitext(filename)[0] + "_Annotation.xml"
                xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(indent="    ")
                with open(xml_filename, "w", encoding="utf-8") as f:
                    f.write(xmlstr)

                print(f"Annotations sauvegardées dans {xml_filename}")

            except Exception as e:
                # Per-recording boundary: report and move on to the next file.
                print(f"Erreur lors de la lecture/détection: {e}")

            if duration_cumulative >= target_duration:
                break

        # The response reports the page count in "numPages"; fetch the next
        # page only while one remains and the target is not reached.
        has_more_pages = page < int(data.get('numPages', 1))
        if duration_cumulative < target_duration and has_more_pages:
            page += 1
            api_url = f"https://xeno-canto.org/api/2/recordings?query={query}&page={page}"
            response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
            response.raise_for_status()
            data = response.json()
        else:
            break

print("✅ Téléchargement et extraction terminés.")
