import random
import os
from pydub import AudioSegment
from pydub.generators import WhiteNoise
import shutil

def weighted_choice(classes, probabilities):
    """Return one element of *classes*, drawn according to *probabilities*.

    The probabilities are treated as relative weights, so they do not have
    to sum exactly to 1. Tables that were rounded after normalisation (or
    never normalised) are still sampled in their intended proportions. The
    previous hand-rolled loop instead sent any missing mass to
    ``classes[0]`` and starved the last classes when the sum exceeded 1.
    A single draw is taken from the module-level ``random`` generator, so
    ``random.seed`` keeps runs reproducible.

    Args:
        classes: Candidate values (may contain ``None``).
        probabilities: One non-negative weight per element of *classes*.

    Returns:
        The selected element of *classes*.

    Raises:
        ValueError: If *classes* and *probabilities* differ in length (a
            mismatch used to be silently truncated), or, on Python 3.9+,
            if every weight is zero.
    """
    return random.choices(classes, weights=probabilities, k=1)[0]

## Statistical probability of external events (to change) ##
# Alternative weighting scheme kept for reference: every class gets a weight
# from a coarse scale (SS > S > A > B > C > D > E), then the list is normalised.
# SS = 50
# S = 50
# A = 16
# B = 8
# C = 3
# D = 2
# E = 1
# Sun_Noise_classe = ["insects", "crickets", "frog", "dog", "footsteps",
#                      "cat", "cow", "church_bells", "engine", "hen", "sea_waves",
#                     "car_horn", "chainsaw", "train", "pig", "rooster", "sheep",
#                      "airplane", "fireworks", "helicopter", "sneezing", "crackling_fire"]
# Sun_Noise_proba = [A, A, A, B, B, 
#                     C, C, C, C, C, C,
#                     D, D, D, D, D, D,
#                     E,E,E,E,E]
# Sun_Noise_proba = _normalized(Sun_Noise_proba)
# Rain_Noise_classe = [None, "thunderstorm", "frog", 
#                      "cat", "cow", "church_bells", "engine", "hen", "sea_waves", 
#                      "train", "car_horn", "airplane", "helicopter", "sneezing"]
# Rain_Noise_proba = [SS, S, A,
#                      C, C, B, B, C, B, 
#                      D, D, E, E, E]
# Rain_Noise_proba = _normalized(Rain_Noise_proba)


def _normalized(weights):
    """Return *weights* rescaled so that they sum to 1.

    The results are deliberately not rounded. Rounding after normalisation
    makes the table drift away from a total of 1 (the rain table below
    summed to 1.002 with 3-decimal rounding), which skews the draw.
    """
    total = sum(weights)
    return [w / total for w in weights]


def _check_table(name, classes, probas):
    """Raise ValueError at start-up if a class list and its probability list differ in length."""
    if len(classes) != len(probas):
        raise ValueError(f"{name}: {len(classes)} classes but {len(probas)} probabilities")


# External events used when no rain was drawn.
Sun_Noise_classe = [
    "insects", "crickets", "frog", "dog", "footsteps",
    "cat", "cow", "church_bells", "engine", "hen", "sea_waves",
    "car_horn", "chainsaw", "train", "pig", "rooster", "sheep",
    "airplane", "fireworks", "helicopter", "sneezing", "crackling_fire",
]
Sun_Noise_proba = _normalized([
    0.18, 0.16, 0.16, 0.08, 0.1,
    0.03, 0.03, 0.03, 0.03, 0.03, 0.02,
    0.01, 0.01, 0.02, 0.02, 0.02, 0.02,
    0.01, 0.01, 0.01, 0.01, 0.01,
])

# External events used when rain was drawn. ``None`` means the slot receives
# no event clip (the mixer skips falsy classes).
Rain_Noise_classe = [
    None, "thunderstorm", "frog",
    "cat", "cow", "church_bells", "engine", "hen", "sea_waves",
    "train", "car_horn", "airplane", "helicopter", "sneezing",
]
Rain_Noise_proba = _normalized([
    0.32, 0.32, 0.1,
    0.02, 0.02, 0.03, 0.03, 0.02, 0.02,
    0.02, 0.02, 0.01, 0.01, 0.01,
])

_check_table("Sun_Noise", Sun_Noise_classe, Sun_Noise_proba)
_check_table("Rain_Noise", Rain_Noise_classe, Rain_Noise_proba)

## Static environmental background ##
# Rain: 3 intensity levels (A, B, C) at 5 % each -> 15 % chance of rain overall.
Rain_file = [None, "rain-A_1.wav", "rain-A_2.wav", "rain-B_1.wav", "rain-C_1.wav", "rain-C_2.wav"]
Rain_proba = [0.85, 0.025, 0.025, 0.05, 0.025, 0.025]
# Wind is only drawn for files without rain: 30 % of those (A 5 %, B 15 %, C 10 %).
Wind_file = [None, "wind-A_1.wav", "wind-B_1.wav", "wind-B_2.wav", "wind-B_3.wav", "wind-C_1.wav", "wind-C_2.wav"]
Wind_proba = [0.7, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05]

## External event fill rate ##
# Probability that a ~5 s slot receives an external event (1 = very noisy, 0 = no events).
External_Event_presence_proba = 0.65

## Random jump ##
# Each ~5 s event slot is shifted by a random offset in [-2, +2] s.
Random_Jump_ms_max = 2000
# Level range (dB) of the white-noise floor; it also receives the SNR and random noise gain.
white_noise_db_range = [-50, -20]

## Event buffer ##
# Maximum number of distinct external-event classes used in one output file.
max_buffer_size_per_file = 3

## Random gain range (dB) ##
average_SNR_dB = 0  # Expected (mean) SNR of the dataset; every noise layer is attenuated by this many dB.
# When activated, a gain drawn uniformly in [rdm_gain_range, 0] dB is applied separately
# to the bird recording and to the noise layers (on top of the SNR attenuation).
rdm_gain_range = -20
rdm_gain_range_activate = False

# Fade-in / fade-out applied to each external event clip.
fade_time_ms = 500

#########################################################################################################################
## Paths (to change) ##
# Built with os.path.join so they work on every OS: a raw "a\b\c" string is a
# single file name on Linux/macOS.
Clean_dir_To_Mixate = os.path.join("ToDo_Path", "File_cleaned_normalized", "Bird")  # one sub-folder per bird class
Clear_dir_Noise = os.path.join("ToDo_Path", "File_cleaned_normalized", "Noise")  # one sub-folder per noise class, plus "rain" and "wind"
Output_dir = os.path.join("ToDo_Path", "Result_Mixed_SNR_0")
target_classes = {"TuMe", "CaCa", "CoCo", "AlAr", "CyCa", "CoPa", "ErRu", "FrCo", "HiRu", "LuMe"}  # bird classes to process (to change if needed)
#########################################################################################################################

### Mixer ###
# Lazily filled cache: noise class name -> full paths of its usable clips, so
# each class folder is listed at most once per run instead of once per event.
_noise_files_by_class = {}


def _list_noise_files(noise_class):
    """Return the full paths of the .wav/.mp3 clips available for *noise_class*.

    Clips are looked up in ``Clear_dir_Noise/<noise_class>`` and the listing
    is cached per class. A missing class folder is reported once, in the
    same style as a missing annotation file. It is then treated as having
    no clips, instead of aborting the whole batch with FileNotFoundError.
    """
    if noise_class not in _noise_files_by_class:
        folder = os.path.join(Clear_dir_Noise, noise_class)
        if os.path.isdir(folder):
            _noise_files_by_class[noise_class] = [
                os.path.join(folder, name)
                for name in os.listdir(folder)
                if os.path.isfile(os.path.join(folder, name))
                and name.lower().endswith((".wav", ".mp3"))
            ]
        else:
            print(f"Erreur : folder {folder} not exist !")
            _noise_files_by_class[noise_class] = []
    return _noise_files_by_class[noise_class]


def _fit_to_duration(clip, duration_ms):
    """Return *clip* cut or padded to exactly *duration_ms* milliseconds.

    A clip shorter than the target is centred in silence; a longer one is
    cropped at a uniformly random offset.
    """
    clip_ms = len(clip)
    if clip_ms < duration_ms:
        start = int((duration_ms - clip_ms) / 2)
        return AudioSegment.silent(duration=duration_ms).overlay(clip, position=start)
    start = random.randint(0, clip_ms - duration_ms)
    return clip[start:start + duration_ms]


def _apply_noise_gain(segment, gain_db_noise):
    """Attenuate a noise layer by the target SNR, then by the per-file random noise gain if enabled."""
    segment = segment.apply_gain(-average_SNR_dB)
    if rdm_gain_range_activate:
        segment = segment.apply_gain(gain_db_noise)
    return segment


def _pick_ambient_background(duration_ms, gain_db_noise):
    """Draw the ambient background (rain, otherwise possibly wind) for one file.

    Rain is drawn first. A wind file is drawn only when no rain was drawn.
    The chosen clip is fitted to *duration_ms* and attenuated like the other
    noise layers.

    Returns:
        ``(rain_file, background)``: *rain_file* is the drawn rain file name
        (None when dry); *background* is the fitted, attenuated segment, or
        None when neither rain nor wind was drawn.
    """
    rain_file = weighted_choice(Rain_file, Rain_proba)
    if rain_file is not None:
        subfolder, chosen = "rain", rain_file
    else:
        subfolder, chosen = "wind", weighted_choice(Wind_file, Wind_proba)
        if not chosen:
            return rain_file, None
    print(f"Environnement : {chosen}")
    clip = AudioSegment.from_file(os.path.join(Clear_dir_Noise, subfolder, chosen))
    return rain_file, _apply_noise_gain(_fit_to_duration(clip, duration_ms), gain_db_noise)


def _add_external_events(noise_bed, duration_ms, is_rainy):
    """Overlay randomly chosen external sound events onto *noise_bed*.

    Event slots are scheduled every 5 s plus a random jitter in
    [-Random_Jump_ms_max, +Random_Jump_ms_max]; the jitter accumulates
    from slot to slot. Each slot holds an event with probability
    External_Event_presence_proba. A file uses at most
    max_buffer_size_per_file distinct classes. New classes are drawn from
    the rain or sun table until the buffer is full. After that, events are
    drawn only from the buffered classes, weighted by their renormalised
    table probabilities.

    Returns:
        ``(noise_bed, buffer_classes)``: the bed with events overlaid and the
        list of classes that entered the buffer.
    """
    if is_rainy:
        classes, probas = Rain_Noise_classe, Rain_Noise_proba
    else:
        classes, probas = Sun_Noise_classe, Sun_Noise_proba
    buffer_classes = []
    buffer_probas = []
    time_ms = 0
    while time_ms < duration_ms:
        time_ms += random.randint(-Random_Jump_ms_max, Random_Jump_ms_max)
        if time_ms >= duration_ms:
            break  # jittered past the end: an overlay here would be silently dropped
        # pydub reads a negative overlay position as an offset from the END of
        # the segment, so an early negative jitter is clamped to the start.
        event_ms = max(time_ms, 0)
        if random.random() > External_Event_presence_proba:
            print(f"Time : {event_ms} | Classe : Bruit")
        else:
            if len(buffer_classes) < max_buffer_size_per_file:
                noise_classe = weighted_choice(classes, probas)
                if noise_classe not in buffer_classes:
                    buffer_classes.append(noise_classe)
                    buffer_probas.append(probas[classes.index(noise_classe)])
                    if len(buffer_classes) == max_buffer_size_per_file:
                        # Exact renormalisation (no rounding) so the buffer sums to 1.
                        total = sum(buffer_probas)
                        buffer_probas = [p / total for p in buffer_probas]
            else:
                noise_classe = weighted_choice(buffer_classes, buffer_probas)

            print(f"Time : {event_ms} | Classe : {noise_classe}")
            if noise_classe:  # None = slot left without an event clip
                noise_files = _list_noise_files(noise_classe)
                if noise_files:
                    noise_audio = AudioSegment.from_file(random.choice(noise_files))
                    noise_audio = noise_audio.fade_in(fade_time_ms).fade_out(fade_time_ms)
                    noise_bed = noise_bed.overlay(noise_audio, position=event_ms)
        time_ms += 5000
    return noise_bed, buffer_classes


os.makedirs(Output_dir, exist_ok=True)
for root, _dirs, files in os.walk(Clean_dir_To_Mixate):
    class_name = os.path.basename(root)
    if class_name not in target_classes:
        continue  # skip folders that are not selected bird classes

    for file in files:
        if not file.lower().endswith((".mp3", ".wav")):
            continue

        ### Build file names ###
        input_path = os.path.join(root, file)  # clean recording that receives the noise
        annotation_path = os.path.splitext(input_path)[0] + "_Annotation.xml"
        if not os.path.exists(annotation_path):
            print(f"Erreur : file {annotation_path} not exist !")
            continue
        # NOTE(review): outputs are written flat into Output_dir, so files with the
        # same name in different class folders overwrite each other.
        output_path_name = os.path.join(Output_dir, file)
        output_annotation_path = os.path.join(Output_dir, os.path.splitext(file)[0] + "_Annotation.xml")

        ### Load the target recording ###
        audio = AudioSegment.from_file(input_path)
        if rdm_gain_range_activate:
            audio_ampli = audio.apply_gain(random.uniform(rdm_gain_range, 0))
        else:
            audio_ampli = audio
        duree_ms = len(audio)
        print(f"Traitement de : {input_path}")
        print(f"Durée de l'audio : {duree_ms} ms ({duree_ms / 1000:.2f} s)")

        # One random gain shared by every noise layer of this file.
        gain_db_noise = random.uniform(rdm_gain_range, 0) if rdm_gain_range_activate else None

        ### Ambient background (rain or wind) ###
        rain_file, background = _pick_ambient_background(duree_ms, gain_db_noise)
        if background is None:
            audio_with_ambiant_noise = audio_ampli
        else:
            audio_with_ambiant_noise = audio_ampli.overlay(background)

        ### White-noise floor (recording quality) ###
        adding_noise_sound = WhiteNoise().to_audio_segment(duration=len(audio_with_ambiant_noise))
        adding_noise_sound = adding_noise_sound.apply_gain(random.uniform(white_noise_db_range[0], white_noise_db_range[1]))

        ### External events ###
        adding_noise_sound, current_buffer_class = _add_external_events(adding_noise_sound, duree_ms, bool(rain_file))

        ### Final mix ###
        adding_noise_sound = _apply_noise_gain(adding_noise_sound, gain_db_noise)
        final_audio = audio_with_ambiant_noise.overlay(adding_noise_sound)
        # export() returns the open output file; close it deterministically.
        with final_audio.export(output_path_name, format="wav"):
            pass
        print(f"Audio {output_path_name} crée avec {current_buffer_class}!")
        ### Annotation: copied unchanged next to the mixed audio ###
        shutil.copy(annotation_path, output_annotation_path)
