from pydub import AudioSegment, silence
import math
import os
import numpy as np
import subprocess
import xml.etree.ElementTree as ET
import shutil  
from scipy.signal import butter, filtfilt

#### Bird classes (extend this list if more species are added) ####
name_dir = [ # Class code - Scientific name - French common name
    "TuMe",  # Turdus merula            - Merle noir
    "ErRu",  # Erithacus rubecula       - Rougegorge familier
    "PaMa",  # Parus major              - Mésange charbonnière
    "CyCa",  # Cyanistes caeruleus      - Mésange bleue
    "FrCo",  # Fringilla coelebs        - Pinson des arbres
    "CaCa",  # Carduelis carduelis      - Chardonneret élégant
    "PaDo",  # Passer domesticus        - Moineau domestique
    "PiPi",  # Pica pica                - Pie bavarde
    "CoCo",  # Corvus corone            - Corneille noire
    "CoPa",  # Columba palumbus         - Pigeon ramier
    "StDe",  # Streptopelia decaocto    - Tourterelle turque
    "HiRu",  # Hirundo rustica          - Hirondelle rustique
    "TuPh",  # Turdus philomelos        - Grive musicienne
    "TrTr",  # Troglodytes troglodytes  - Troglodyte mignon
    "SyAt",  # Sylvia atricapilla       - Fauvette à tête noire
    "MoAl",  # Motacilla alba           - Bergeronnette grise
    "AlAr",  # Alauda arvensis          - Alouette des champs
    "StVu",  # Sturnus vulgaris         - Étourneau sansonnet
    "LuMe",  # Luscinia megarhynchos    - Rossignol philomèle
    "ReRe"   # Regulus regulus          - Roitelet huppé
]

#### Noise classes (environmental / non-bird sounds) ####
noise_dir = ["airplane", "car_horn", "cat", "chainsaw", "church_bell", "cow", "crackling_fire", "crickets", "dog", "engine", "fireworks", "footsteps","frog", "helicopter", "hen", "insects", "pig", 
             "rooster", "sea_waves", "sheep", "sneezing", "thunderstorm", "train", "rain", "wind"]

# Valid MP3 sample rates (Hz) per MPEG version, indexed by the 2-bit
# sample-rate field of the frame header (index 3 is reserved).
SAMPLE_RATES = {
    'MPEG1': [44100, 48000, 32000],
    'MPEG2': [22050, 24000, 16000],
    'MPEG2.5': [11025, 12000, 8000]
}

# Target RMS level for normalization, as a linear fraction of full scale
# (not dBFS).
TARGET_RMS = 0.05
TARGET_SAMPLE_RATE = 48000  # Hz — every processed file is resampled to this
SILENCE_THRESH = -60  # dBFS — below this a span counts as silence
MIN_SILENCE_LEN = 200  # ms — minimum span length to be treated as silence

def get_mp3_sample_rates(mp3_file):
    """Scan the raw bytes of an MP3 file for frame sync headers and collect
    every sample rate those headers advertise.

    Args:
        mp3_file: path to the MP3 file to inspect.

    Returns:
        A set of the distinct sample rates (Hz) found in the file's frame
        headers (empty if no valid header is seen).
    """
    with open(mp3_file, 'rb') as fh:
        raw = fh.read()

    rates = set()
    pos = 0
    limit = len(raw) - 4
    while pos < limit:
        # Frame sync: 0xFF followed by a byte whose top three bits are set.
        if raw[pos] != 0xFF or (raw[pos + 1] & 0xE0) != 0xE0:
            pos += 1
            continue

        version_bits = (raw[pos + 1] >> 3) & 0x03
        rate_index = (raw[pos + 2] >> 2) & 0x03

        version = {0b11: 'MPEG1', 0b10: 'MPEG2', 0b00: 'MPEG2.5'}.get(version_bits)
        # Reserved version (0b01) or reserved rate index (0b11): not a
        # real header, resume the byte-by-byte scan.
        if version is None or rate_index == 0b11:
            pos += 1
            continue

        if version in SAMPLE_RATES:
            rates.add(SAMPLE_RATES[version][rate_index])
        pos += 4

    return rates

def reencode_mp3(input_file, target_rate):
    """Re-encode an audio file to mono 16-bit WAV at `target_rate` Hz via ffmpeg.

    Args:
        input_file: path to the source file (typically .mp3).
        target_rate: desired output sample rate in Hz.

    Returns:
        The path of the converted "*_converted.wav" file on success,
        None if ffmpeg fails (any partial output file is removed).
    """
    # Bug fix: the old `input_file.replace(".mp3", ...)` matched ".mp3"
    # anywhere in the path and left temp_file == input_file for non-.mp3
    # inputs (the error path would then delete the source file).
    # splitext only touches the real extension.
    temp_file = os.path.splitext(input_file)[0] + "_converted.wav"
    try:
        subprocess.run([
            "ffmpeg", "-y", "-i", input_file,
            "-ar", str(target_rate),
            "-ac", "1",            # mono
            "-sample_fmt", "s16",  # 16-bit PCM
            temp_file
        ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
        print(f"✅ Réencodé en WAV : {temp_file}")
        return temp_file
    except subprocess.CalledProcessError as e:
        print(f"❌ FFmpeg error : {e}")
        # Clean up any partial output so a failed run leaves no residue.
        if os.path.exists(temp_file):
            os.remove(temp_file)
        return None

def compute_rms(audio_segment):
    """Return the RMS level of an audio segment as a fraction of full scale.

    Samples are scaled to [-1, 1) according to the segment's sample width
    (16-bit and 32-bit PCM handled; other widths are used as-is).

    Args:
        audio_segment: object exposing `get_array_of_samples()` and
            `sample_width` (e.g. a pydub AudioSegment).

    Returns:
        The RMS value (float); 0.0 for an empty segment.
    """
    samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32)
    if samples.size == 0:
        # Robustness fix: np.mean of an empty array is NaN (with a
        # RuntimeWarning); treat an empty segment as silent instead.
        return 0.0
    if audio_segment.sample_width == 2:  # 16-bit PCM
        samples /= 32768
    elif audio_segment.sample_width == 4:  # 32-bit PCM
        samples /= 2147483648
    return np.sqrt(np.mean(samples**2))

def dbfs_to_rms(dbfs):
    """Convert a level in dBFS to its equivalent linear RMS amplitude."""
    return math.pow(10.0, dbfs / 20.0)

def parse_annotations(xml_path):
    """Parse an annotation XML file into a list of labelled regions.

    Args:
        xml_path: path to an XML file with <Annotation> elements containing
            Start_s, End_s, FqMin_Hz and FqMax_Hz children.

    Returns:
        A list of (start_s, end_s, fq_min_hz, fq_max_hz) tuples; times are
        floats, frequencies ints. Negative FqMin values (labelling
        mistakes) are clamped to 0.
    """
    root = ET.parse(xml_path).getroot()
    regions = []
    for node in root.findall("Annotation"):
        t0 = float(node.find("Start_s").text)
        t1 = float(node.find("End_s").text)
        lo = int(float(node.find("FqMin_Hz").text))
        hi = int(float(node.find("FqMax_Hz").text))
        regions.append((t0, t1, max(lo, 0), hi))
    return regions

def butter_bandpass_filter(data, lowcut, highcut, fs, order=3):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, data)

def audiosegment_to_ndarray(audio_segment):
    """Convert an audio segment's samples into a float32 numpy array.

    Stereo segments come back as shape (n, 2) — pydub interleaves the two
    channels in its sample buffer — while mono stays a flat (n,) array.
    """
    arr = np.asarray(audio_segment.get_array_of_samples(), dtype=np.float32)
    return arr.reshape((-1, 2)) if audio_segment.channels == 2 else arr

def ndarray_to_audiosegment(samples, sample_width, frame_rate, channels):
    """Build a pydub AudioSegment back from a numpy sample array.

    Values are clipped to the signed integer range of `sample_width` bytes
    (to avoid wrap-around on overflow) and cast to the matching int dtype
    before being packed into raw PCM bytes.
    """
    full_scale = 2 ** (8 * sample_width - 1)
    clipped = np.clip(samples, -full_scale, full_scale - 1)
    int_dtype = {1: np.int8, 2: np.int16, 4: np.int32}[sample_width]
    pcm = clipped.astype(int_dtype)
    if channels == 2:
        # Flatten the (n, 2) stereo array back into an interleaved buffer.
        pcm = pcm.reshape((-1,))
    return AudioSegment(
        pcm.tobytes(),
        frame_rate=frame_rate,
        sample_width=sample_width,
        channels=channels
    )

def bandpass_filter(segment, fq_min, fq_max, order=5):
    """Band-pass filter a pydub segment to the [fq_min, fq_max] Hz range.

    Stereo segments have each channel filtered independently; the result
    keeps the segment's sample rate, width and channel count.
    """
    rate = segment.frame_rate
    data = audiosegment_to_ndarray(segment)

    if segment.channels == 1:
        filtered = butter_bandpass_filter(data, fq_min, fq_max, rate, order)
    else:
        # One pass per channel, re-stacked into the (n, channels) layout.
        per_channel = [
            butter_bandpass_filter(data[:, c], fq_min, fq_max, rate, order)
            for c in range(segment.channels)
        ]
        filtered = np.stack(per_channel, axis=-1)

    return ndarray_to_audiosegment(filtered, segment.sample_width, rate, segment.channels)

def apply_annotations_mask(audio, annotations, margin_s=0.5, fade_ms=500, filter_order=5):
    """Keep only the annotated regions of `audio`, silencing everything else.

    Each annotation is widened by `margin_s` seconds on both sides, band-pass
    filtered to its labelled frequency range, faded in/out over `fade_ms`,
    and overlaid onto a silent canvas of the same duration as the input.

    NOTE(review): AudioSegment.silent produces a mono canvas — presumably
    all inputs here are mono; verify if stereo files can reach this point.
    """
    total_ms = len(audio)
    canvas = AudioSegment.silent(duration=total_ms, frame_rate=audio.frame_rate)

    for start_s, end_s, fq_min, fq_max in annotations:
        begin = max(0, int((start_s - margin_s) * 1000))
        finish = min(total_ms, int((end_s + margin_s) * 1000))

        chunk = bandpass_filter(audio[begin:finish], fq_min, fq_max, order=filter_order)
        chunk = chunk.fade_in(fade_ms).fade_out(fade_ms)
        canvas = canvas.overlay(chunk, position=begin)

    return canvas

# normalize_audio: used for noise files (no annotation masking).
def normalize_audio(filename, output_filename, target_rms=TARGET_RMS):
    """Normalize a recording's loudness and export it as 16-bit mono WAV.

    The audio is converted to mono / TARGET_SAMPLE_RATE / 16-bit, the RMS
    is measured over the non-silent portions only (so long silences do not
    bias the measure), and a uniform gain brings that RMS to `target_rms`.

    Args:
        filename: input audio file (any format pydub/ffmpeg can decode).
        output_filename: destination WAV path.
        target_rms: desired linear RMS level of the non-silent audio.
    """
    audio = AudioSegment.from_file(filename)
    audio = audio.set_channels(1).set_frame_rate(TARGET_SAMPLE_RATE).set_sample_width(2)

    nonsilent = silence.detect_nonsilent(audio, min_silence_len=MIN_SILENCE_LEN, silence_thresh=SILENCE_THRESH)
    if not nonsilent:
        # Bug fix: the f-string had no placeholder, so the offending file
        # was never named in the log.
        print(f"🔇 Aucun son détecté dans {filename}")
        return

    active_audio = sum([audio[start:end] for start, end in nonsilent])
    current_rms = compute_rms(active_audio)

    if current_rms < 1e-5:
        print(f"🛑 RMS nul dans {filename}")
        return

    gain = 20 * math.log10(target_rms / current_rms)
    normalized = audio.apply_gain(gain)

    normalized.export(output_filename, format="wav")
    print(f"✅ Normalisé : {output_filename}")

# normalize_audio_with_annot: used for bird files (annotation-masked).
def normalize_audio_with_annot(filename, output_filename, target_rms=TARGET_RMS, original_annotation_path=None):
    """Mask a bird recording to its annotated regions, normalize, export.

    Only the time/frequency regions listed in the annotation XML are kept
    (see apply_annotations_mask); the masked audio is then RMS-normalized
    to `target_rms`, exported as WAV, and the annotation file is copied
    next to the output.

    Args:
        filename: input audio file.
        output_filename: destination WAV path.
        target_rms: desired linear RMS level of the non-silent audio.
        original_annotation_path: explicit path to the annotation XML;
            when None, "<basename>_Annotation.xml" next to `filename` is used.
    """
    audio = AudioSegment.from_file(filename)

    if original_annotation_path is None:
        # Bug fix: the chained .replace(".mp3","").replace(".wav","") also
        # matched those substrings in the middle of a name; splitext only
        # strips the trailing extension.
        base_name = os.path.splitext(os.path.basename(filename))[0]
        xml_path = os.path.join(os.path.dirname(filename), base_name + "_Annotation.xml")
    else:
        xml_path = original_annotation_path

    if not os.path.exists(xml_path):
        print(f"❌ Fichier XML manquant : {xml_path}")
        return

    annotations = parse_annotations(xml_path)
    masked_audio = apply_annotations_mask(audio, annotations)

    # RMS is measured over non-silent segments only, so the silent gaps
    # introduced by the mask do not bias the gain computation.
    nonsilent = silence.detect_nonsilent(masked_audio, min_silence_len=MIN_SILENCE_LEN, silence_thresh=SILENCE_THRESH)
    if not nonsilent:
        # Bug fix: f-string had no placeholder — name the actual file.
        print(f"🔇 Aucun son détecté dans {filename}")
        return

    active_audio = sum([masked_audio[start:end] for start, end in nonsilent])
    current_rms = compute_rms(active_audio)

    if current_rms < 1e-5:
        print(f"🛑 RMS trop faible (silencieux) : {filename}")
        return

    gain = 20 * math.log10(target_rms / current_rms)
    normalized = masked_audio.apply_gain(gain)

    normalized.export(output_filename, format="wav")
    print(f"✅ Normalisé : {output_filename}")

    # Keep the annotation file alongside the normalized output.
    output_xml_path = os.path.splitext(output_filename)[0] + "_Annotation.xml"
    shutil.copy(xml_path, output_xml_path)
    print(f"📄 Annotation copiée : {output_xml_path}")

def process_file(filepath, output_filename):
    """Process one audio file end to end: optional re-encode, then
    annotation-masked normalization.

    MP3 inputs whose frame headers advertise mixed rates or a rate other
    than TARGET_SAMPLE_RATE are first re-encoded to a temporary WAV, which
    is deleted once processing finishes.

    Args:
        filepath: source .mp3/.wav file (other extensions are ignored).
        output_filename: destination WAV path for the normalized audio.
    """
    ext = os.path.splitext(filepath)[1].lower()
    if ext not in [".mp3", ".wav"]:
        return

    print(f"\n🎧 Traitement de : {filepath}")
    file_to_process = filepath  # default: process the original in place

    if ext == ".mp3":
        sample_rates = get_mp3_sample_rates(filepath)
        # Re-encode when rates are mixed or differ from the target; an
        # empty set (no header found) is left untouched, as before.
        need_reencode = bool(sample_rates) and sample_rates != {TARGET_SAMPLE_RATE}

        if need_reencode:
            print(f"⚠️  Réencodage en {TARGET_SAMPLE_RATE} Hz mono 16-bit")
            temp_wav = reencode_mp3(filepath, TARGET_SAMPLE_RATE)
            if temp_wav and os.path.exists(temp_wav):
                file_to_process = temp_wav
            else:
                print(f"❌ Échec de réencodage : {filepath}")
                return

    # The annotation always sits next to the ORIGINAL file, even when a
    # temporary re-encoded WAV is the one being processed.
    annotation_path = os.path.splitext(filepath)[0] + "_Annotation.xml"
    normalize_audio_with_annot(file_to_process, output_filename, original_annotation_path=annotation_path)

    # Remove the temporary WAV if one was created.
    if file_to_process != filepath and os.path.exists(file_to_process):
        os.remove(file_to_process)
        print(f"fichier {file_to_process} supprimé")

# Run over a directory tree of bird recordings.
if __name__ == "__main__":

    ### Bird
    test_dir = r"Todo_Path\File_origin\BirdOnly"
    cible_dir = r"Todo_Path\File_cleaned_normalized\Bird"
    target_classes = {"AlAr", "CaCa", "CoCo", "CoPa", "CyCa", "ErRu", "FrCo", "HiRu", "LuMe", "TuMe"}  # To change potentially

    # Guard against running with the placeholder path. Bug fix: the
    # comparison literal was NOT a raw string and only matched because
    # Python keeps invalid escapes like "\F" verbatim (a SyntaxWarning
    # since 3.12); use a raw string so the guard is explicit and reliable.
    if test_dir == r"Todo_Path\File_origin\BirdOnly":
        raise Exception("Change your path !")

    os.makedirs(cible_dir, exist_ok=True)

    for root, dirs, files in os.walk(test_dir):
        class_name = os.path.basename(root)
        if class_name not in target_classes:
            continue  # skip classes we are not training on

        # Mirror the input directory layout under cible_dir.
        target_dir = os.path.join(cible_dir, os.path.relpath(root, test_dir))
        os.makedirs(target_dir, exist_ok=True)

        for file in files:
            if not file.lower().endswith((".mp3", ".wav")):
                continue

            input_path = os.path.join(root, file)
            # Bug fix: replace(".mp3", ".wav") missed upper-case extensions
            # (e.g. "X.MP3" passes the lower() filter but kept its name);
            # splitext always yields a .wav output name.
            output_path = os.path.join(target_dir, os.path.splitext(file)[0] + ".wav")
            process_file(input_path, output_path)