from shiny import App, render, reactive, ui
from shinywidgets import output_widget, render_widget, reactive_read, register_widget
import ipyleaflet as L
from ipywidgets import Layout, HTML
import numpy as np
import datetime
from pathlib import Path
import pandas as pd

# Title displayed at the top of the page.
project_name = "CARI'MAP"

# Candidate base-map tile providers, keyed by display name.
# NOTE(review): this dict is not referenced by the code visible in this file; the map
# is created with L.basemaps.Esri.WorldStreetMap directly.
# NOTE(review): Stamen tiles are now distributed via Stadia; confirm the Stamen
# providers below still resolve with the installed ipyleaflet/xyzservices.
basemaps = {
    "Esri.WorldStreetMap": L.basemaps.Esri.WorldStreetMap,
    "Esri.NatGeoWorldMap": L.basemaps.Esri.NatGeoWorldMap,
    "OpenStreetMap": L.basemaps.OpenStreetMap.Mapnik,
    "Satellite": L.basemaps.Gaode.Satellite,
    "Stamen.Toner": L.basemaps.Stamen.Toner,
    "Stamen.Terrain": L.basemaps.Stamen.Terrain,
    "Stamen.Watercolor": L.basemaps.Stamen.Watercolor,
}

# Recording stations: name -> [latitude, longitude], used as marker positions on the
# leaflet map and for mapping a marker click back to a station name.
# NOTE(review): GUYANNE_GEPOG has exactly the same coordinates as MART_StANNE
# (looks like a copy-paste); with the nearest-station lookup, clicks on that marker
# resolve to MART_StANNE. Confirm the real location of the Guyane station.
stations_location = dict(
    ANGUILLA=[18.27632, -62.98334],
    BON=[12.01654, -68.23889],
    StBARTH=[17.93567, -62.85825],
    BERMUDE=[32.30808, -64.68840],
    ARUBA=[12.56996, -70.06936],
    JAM=[18.47769, -77.42101],
    GUA_BREACH=[16.27275, -61.82557],
    GUA_SF=[16.21611, -61.22273],
    GUA_AB=[16.49772, -61.49633],
    BAHAMAS=[23.08987, -74.83888],
    StEUS=[17.46173, -62.98735],
    StMARTIN=[18.13021, -62.95031],
    MART_PRECH=[14.78176, -61.22047],
    MART_StANNE=[14.40284, -60.82717],
    GUYANNE_GEPOG=[14.40284, -60.82717],
)


# Detection table. The app reads the columns 'date' (datetime-like, accessed via .dt),
# 'species' and 'station'.
# The file is resolved relative to this module (like style.css) so the app does not
# depend on the current working directory.
# NOTE(review): pd.read_pickle can execute arbitrary code; only load trusted files.
df = pd.read_pickle(Path(__file__).parent / "example_clean.pkl")

# Species offered in the selector; all are selected by default.
choices = ("Humpback", "Dolphin", "Sperm Whale")
# Circle colour per species on the map. Each species needs its own colour,
# otherwise overlapping circles cannot be told apart.
species_colors = {
    "Humpback": 'purple',
    "Dolphin": 'green',
    'Sperm Whale': 'blue',
}

# Initial map view.
zoom_level = 4
center_map = [22.00, -65.000]

# Year-slider bounds, taken from the data.
min_date = df['date'].min()
max_date = df['date'].max()

# Month-slider bounds. Only the month of the selected date is used by the server,
# so the year here is arbitrary.
min_month = datetime.date(2023, 1, 15)
max_month = datetime.date(2023, 12, 15)


# Page layout: title, a sidebar with species and date selectors, and the map plus
# text outputs in the main panel.
# NOTE(review): "center_item" / "heigh_layout" are presumably defined in style.css.
app_ui = ui.page_fluid(
    ui.head_content(
        # Inline the stylesheet that sits next to this file.
        ui.tags.style((Path(__file__).parent / "style.css").read_text()),
    ),
    ui.div(
        ui.h3(project_name),
        class_="center_item",
    ),
    ui.layout_sidebar(
        ui.panel_sidebar(
            ui.div(
                ui.input_selectize("species", "Species", choices, selected=choices, multiple=True),
                ui.row(
                    ui.column(
                        6,
                        ui.p(ui.input_action_button("b_all", "Select All")),
                        class_="center_item",
                    ),
                    ui.column(
                        6,
                        ui.p(ui.input_action_button("b_none", "Clear All")),
                        class_="center_item",
                    ),
                ),
                class_="heigh_layout",
            ),

            ui.div(
                # step=1 so every year between the first and last detection can be
                # chosen; a step equal to the whole range only allows the end points
                # and is 0 (invalid) when the data covers a single year.
                ui.input_slider(
                    "year", "Year",
                    min=min_date.year, max=max_date.year, value=min_date.year,
                    sep='', step=1,
                ),
                # Date slider displayed as month names; the server only uses .month.
                ui.input_slider("month", "Month", min=min_month, max=max_month, value=min_month, time_format="%B"),
                class_="heigh_layout",
            ),

        ),

        ui.panel_main(
            output_widget("map"),
            ui.output_text_verbatim("info"),
            ui.output_text_verbatim("select_species"),
            ui.output_text_verbatim("select_date"),
        ),
    )
)


def _nearest_station(coordinates):
    """Return the name of the station in ``stations_location`` closest to *coordinates*.

    Uses plain Euclidean distance on the [lat, lon] pairs, which is enough to map a
    marker click back to its station. On ties the first station in dict order wins.
    Returns "" when there are no stations.
    """
    target = np.asarray(coordinates, dtype=float)
    best_name = ""
    best_dist = np.inf
    for name, location in stations_location.items():
        dist = np.linalg.norm(target - np.asarray(location, dtype=float))
        if dist < best_dist:
            best_dist = dist
            best_name = name
    return best_name


def _filter_year_month(frame, year, month):
    """Return the rows of *frame* whose ``date`` column falls in *month* of *year*."""
    dates = frame['date'].dt
    return frame[(dates.year == year) & (dates.month == month)]


def server(input, output, session):
    """Per-session server logic.

    Keeps the selected species and (year, month) in reactive values, draws one
    marker per station plus detection circles for the selected month on the leaflet
    map, and fills the "info" output when a station marker is clicked.

    Args:
        input, output, session: provided by Shiny for each session.
    """
    # Species currently shown on the map.
    selected_species = reactive.Value(choices)
    # (year, month) chosen with the sliders, seeded with the sliders' initial values
    # so readers never see an empty value before the slider events fire.
    selected_slides = reactive.Value((min_date.year, min_month.month))
    # Text of the "info" output (details of the last clicked station).
    info_msg = reactive.Value("")
    # NOTE(review): the UI declares an output "select_date" that has no renderer here.

    leaflet_map = L.Map(
        basemap=L.basemaps.Esri.WorldStreetMap,
        center=center_map,
        zoom=zoom_level,
        scroll_wheel_zoom=True,
        layout=Layout(width='100%', height='600px'),
    )
    leaflet_map.add_control(L.leaflet.ScaleControl(position="topright"))
    register_widget("map", leaflet_map)

    @reactive.Effect
    @reactive.event(input.b_none)
    def _():
        selected_species.set(())
        ui.update_selectize("species", selected=())

    @reactive.Effect
    @reactive.event(input.b_all)
    def _():
        selected_species.set(choices)
        ui.update_selectize("species", selected=choices)

    @reactive.Effect
    @reactive.event(input.species)
    def _():
        selected_species.set(input.species())
        info_msg.set(' ')

    @output
    @render.text
    def select_species():
        return f"Selected species : {str(selected_species())}"

    # One handler for both sliders: either change updates the (year, month) pair
    # and clears the station info.
    @reactive.Effect
    @reactive.event(input.year, input.month)
    def _():
        selected_slides.set((input.year(), input.month().month))
        info_msg.set(' ')

    @output
    @render.text
    def info():
        return str(info_msg())

    def disp_info(**kwargs):
        """Marker click callback: list detections per species at the clicked station.

        Called by ipyleaflet rather than by Shiny, so reactive values are read inside
        ``reactive.isolate()``: Shiny raises when a reactive value is read outside a
        reactive context, and a click handler should not take dependencies anyway.
        """
        with reactive.isolate():
            selec_year, selec_month = selected_slides()

        selec_station = _nearest_station(kwargs["coordinates"])

        tmp = _filter_year_month(df, selec_year, selec_month)
        tmp = tmp[tmp['station'] == selec_station]

        message = [f"Information for the station : {selec_station}"]
        message.append(f"Date : {selec_month:02}-{selec_year}")
        message.append("Detection : ")
        for spec, grp in tmp.groupby("species"):
            message.append(f"\t- {spec} = {len(grp)}")

        info_msg.set('\n'.join(message))

    # Station markers do not depend on any input, so they are created once.
    icon = L.Icon(icon_url='http://sabiod.lis-lab.fr/pub/CARIMAP/icon.png', icon_size=[10, 15])
    for location in stations_location.values():
        marker = L.Marker(location=location, icon=icon, draggable=False)
        marker.on_click(disp_info)
        leaflet_map.add_layer(marker)

    # All detection circles live in a single group that is replaced on every update,
    # instead of removing map layers by index.
    circle_group = L.LayerGroup()
    leaflet_map.add_layer(circle_group)

    @reactive.Effect
    def update_map():
        """Redraw one circle per (species, station) with detections in the selected month."""
        selec_year, selec_month = selected_slides()
        selec_spec = selected_species()

        month_df = _filter_year_month(df, selec_year, selec_month)
        circles = []
        for elem, grp in month_df.groupby("species"):
            if elem not in selec_spec:
                continue
            for s_name, _s_grp in grp.groupby("station"):
                location = stations_location.get(s_name)
                if location is None:
                    # Station present in the data but without known coordinates.
                    continue
                circle = L.Circle(opacity=1)
                circle.location = location
                # TODO(review): the radius is random rather than derived from the
                # data (e.g. len(_s_grp)); confirm the intended scaling.
                circle.radius = size_cicle(np.random.randint(10, 1000), len(df))
                circle.color = species_colors[elem]
                circle.fill_color = species_colors[elem]
                circles.append(circle)

        circle_group.layers = tuple(circles)



def size_cicle(val, max_val, *, min_size=10, max_size=1000000, scale=100):
    """Return an integer circle radius for *val* relative to *max_val*.

    The ratio ``val / max_val`` is mapped linearly onto ``[0, max_size]``, floored
    at ``min_size``, then multiplied by ``scale``. With the defaults the result is
    at least 1000 and equals 100_000_000 when ``val == max_val``.

    A non-positive *max_val* has no meaningful ratio, so the minimum radius is
    returned instead of raising ZeroDivisionError.
    """
    if max_val <= 0:
        return int(scale * min_size)
    return int(scale * max(min_size, (val / max_val) * max_size))

# Shiny application object: combines the UI definition with the per-session server function.
app = App(app_ui, server)
