import solara
import pandas as pd
import joblib
import plotly.express as px
from io import BytesIO
from typing import Optional, cast
from solara.components.file_drop import FileInfo
import googlemaps
from time import sleep
from dotenv import load_dotenv
import os
# Populate os.environ from a .env file before any setting is read below.
load_dotenv()
# Key for the Google Maps Geocoding API; None when the variable is not set.
GOOGLE_MAPS_API_KEY = os.environ.get('GOOGLE_MAPS_API_KEY')
# Shared Google Maps client used by the geocoding helper.
# NOTE(review): googlemaps.Client appears to validate the key when it is
# constructed, so a missing/invalid GOOGLE_MAPS_API_KEY likely makes the app
# fail at import time — confirm.
gmaps = googlemaps.Client(key=GOOGLE_MAPS_API_KEY)
# Model artefacts, loaded once at import time; the relative paths resolve
# against the current working directory. joblib.load unpickles the files, so
# these paths must only ever point at trusted artefacts.
pipeline = joblib.load('best_model_pipeline.pkl')
label_encoder = joblib.load('label_encoder.pkl')
class State:
    """Reactive values shared by the components of this page.

    Each attribute is a module-level ``solara.reactive`` value that the
    components read and write through ``.value``.
    """
    # NOTE(review): solara may derive internal keys for reactive variables from
    # their creation order; keep the declaration order stable.
    # Predictions table for display (set by ProcessButton; read by Page and EDA).
    df = solara.reactive(cast(Optional[pd.DataFrame], None))
    # Raw bytes of the uploaded file (set by FileDropDemo; parsed by ProcessButton).
    content = solara.reactive(b"")
    # Name of the uploaded file, as reported by FileDrop.
    filename = solara.reactive("")
    # Size of the uploaded file as reported by FileDrop (presumably bytes — confirm).
    size = solara.reactive(0)
    # True once the current upload has been processed; reset on every new upload.
    # Disables the process button and gates the result tabs.
    processed = solara.reactive(False)
    # Latest-year 'AUTOMOTRIZ' rows of the upload, kept by ProcessButton.
    df_uploaded = solara.reactive(cast(Optional[pd.DataFrame], None))
    # .xlsx bytes of the predictions, offered by DownloadButton.
    download_data = solara.reactive(b"")
@solara.component
def FileDropDemo():
    """Drop zone for the Excel workbook to analyse.

    The dropped file's name, size and bytes are stored in ``State`` and
    ``State.processed`` is reset so the process button is enabled again.
    A confirmation is rendered once a file has been loaded.
    """
    def on_file(f: FileInfo):
        State.filename.value = f["name"]
        State.size.value = f["size"]
        State.content.value = f["file_obj"].read()
        State.processed.value = False

    solara.FileDrop(
        label="Arrastre y cargue un archivo de excel aquí.",
        on_file=on_file,
        lazy=True,
    )
    # An element created inside an event handler (as on_file used to do) is
    # never added to the page; render the confirmation here instead, driven
    # by the state that on_file writes.
    if State.filename.value:
        solara.Success("Archivo cargado satisfactoriamente")
# --- Prediction workflow used by ProcessButton --------------------------------

# Columns of the uploaded sheet that are handed to the model.
_SELECTED_COLUMNS = ['IdTBLEstablecimientos', 'EstablecimientoFormal', 'HorarioDiurno',
                     'Localidad', 'UPZ', 'Barrio', 'Sede', 'TipoEstablecimiento',
                     'ActividadEconomica', 'UltimaFechaVisita', 'Ultimo_Concepto']

# Historical inspection workbook and the yearly sheets concatenated from it.
_HISTORY_PATH = 'mant_aut_asp_trab_2017_2022_V3.xlsx'
_HISTORY_SHEETS = ['2017', '2018', '2019', '2020', '2021', '2022']

# Per-establishment aggregation of the history. The key order fixes the order
# of the feature columns given to the model, so do not reorder it.
_HISTORY_AGG = {
    'numero_trab_asp': 'max',
    '7.5': 'last',
    '7.11': 'last',
    '7.6': 'last',
    '7.4': 'last',
    '4.5': 'last',
    '4.11': 'last',
    '7.7': 'last',
    '4.3': 'last',
    '5.4': 'last',
    '4.13': 'last',
    '4.9': 'last',
    '4.6': 'last',
    '3.8': 'last',
    '4.12': 'last',
    '5.6': 'last',
    '7.10': 'last',
    '5.5': 'last',
}

# Score of each inspection-item answer; any other answer scores 0.
_CATEGORY_MAPPING = {
    '1. Cumple': 1,
    '5. Terminado': 0,
    '2. No cumple': -1,
}

# A_c factor for each value of 'Ultimo_Concepto'; any other value gives 0.
_CONCEPT_FACTOR = {
    'Concepto Favorable con Req': 0.5,
    'Desfavorable': 1.0,
    'Concepto Favorable': 0,
}

# Weights of the two item groups, and (item, weight) pairs inside each group.
_A_1 = 0.7
_A_2 = 0.3
_GROUP_1_ITEMS = (('7.6', 0.5), ('4.11', 0.3), ('4.5', 0.13), ('5.4', 0.07))
_GROUP_2_ITEMS = tuple(
    (item, 0.0769)
    for item in ('7.5', '7.11', '7.4', '7.7', '4.3', '4.13', '4.9',
                 '4.6', '3.8', '4.12', '5.6', '7.10', '5.5')
)


def _calculate_mlindex(row):
    """Return the mlindex of one merged row.

    mlindex = A_c * (A_1 * sum(score * weight, group 1)
                     + A_2 * sum(score * weight, group 2))
    The items are summed in a fixed order so results are reproducible.
    """
    a_c = _CONCEPT_FACTOR.get(row['Ultimo_Concepto'], 0)
    group_1 = sum(_CATEGORY_MAPPING.get(row[item], 0) * weight for item, weight in _GROUP_1_ITEMS)
    group_2 = sum(_CATEGORY_MAPPING.get(row[item], 0) * weight for item, weight in _GROUP_2_ITEMS)
    return a_c * (_A_1 * group_1 + _A_2 * group_2)


def _build_predictions(content):
    """Turn the uploaded workbook bytes into per-establishment predictions.

    Returns:
        ``(automotive, predictions)``: the latest-year rows of the upload whose
        'Intervención' contains 'AUTOMOTRIZ', and a table with
        IdTBLEstablecimientos, predicciones and numero_trab_asp plus
        NombreComercial, Localidad, Barrio, UltimaFechaVisita and
        DireccionComercial for display.

    Raises:
        ValueError: if no row is left to predict after filtering and merging.
        Any error raised while reading the workbooks or by the model.
    """
    uploaded = pd.read_excel(BytesIO(content))
    # Keep only visits from the most recent year present in the file.
    uploaded['UltimaFechaVisita'] = pd.to_datetime(uploaded['UltimaFechaVisita'], errors='coerce')
    latest_year = uploaded['UltimaFechaVisita'].dt.year.max()
    uploaded = uploaded[uploaded['UltimaFechaVisita'].dt.year == latest_year]
    # Keep only automotive workshops.
    automotive = uploaded[uploaded['Intervención'].str.contains('AUTOMOTRIZ', na=False)].copy()

    # Rename to the column name the model expects. Non-inplace: renaming a
    # column selection in place triggers pandas' SettingWithCopyWarning.
    features = automotive[_SELECTED_COLUMNS].rename(columns={'HorarioDiurno': 'HorarioDuirno'})

    history = pd.concat(pd.read_excel(_HISTORY_PATH, sheet_name=_HISTORY_SHEETS), ignore_index=True)
    history_agg = history.groupby('IdTBLEstablecimientos').agg(_HISTORY_AGG).reset_index()

    merged = pd.merge(features, history_agg, on='IdTBLEstablecimientos', how='left')
    # Rows with any missing value (e.g. no history for the establishment) are dropped.
    merged = merged.dropna().reset_index(drop=True)
    if merged.empty:
        raise ValueError("no quedan registros con historial completo para predecir")

    merged['mlindex'] = merged.apply(_calculate_mlindex, axis=1)
    model_input = merged.drop('Ultimo_Concepto', axis=1, errors='ignore')
    merged['predicciones'] = label_encoder.inverse_transform(pipeline.predict(model_input))

    # Display details come from the same AUTOMOTRIZ rows that were predicted,
    # one row per establishment, so the left merge cannot multiply rows.
    details = automotive[['IdTBLEstablecimientos', 'NombreComercial', 'Localidad', 'Barrio',
                          'UltimaFechaVisita', 'DireccionComercial']].drop_duplicates('IdTBLEstablecimientos')
    predictions = merged[['IdTBLEstablecimientos', 'predicciones', 'numero_trab_asp']].merge(
        details,
        on='IdTBLEstablecimientos',
        how='left',
    )
    return automotive, predictions


def _to_excel_bytes(frame):
    """Serialise *frame* to .xlsx bytes (without the index) using openpyxl."""
    buffer = BytesIO()
    frame.to_excel(buffer, index=False, engine='openpyxl')
    return buffer.getvalue()


@solara.component
def ProcessButton():
    """Button that runs the prediction workflow on the uploaded file.

    On success it stores the results in ``State.df_uploaded``, ``State.df`` and
    ``State.download_data``, then sets ``State.processed`` (which disables the
    button until a new file is loaded). A success or error message is rendered
    below the button for the file it refers to.
    """
    # (kind, message, content) of the last run. `content` is the uploaded bytes
    # object the message refers to; a new upload stores a new object, which
    # hides the stale message.
    status, set_status = solara.use_state(cast(Optional[tuple], None))

    def process_file():
        content = State.content.value
        if not content:
            return
        try:
            automotive, predictions = _build_predictions(content)
            download = _to_excel_bytes(predictions)
        except Exception as e:
            # UI boundary: report any failure to the user. The message must be
            # rendered from state; elements created here would never display.
            set_status(("error", f"Error procesando archivo: {e}", content))
            return
        State.df_uploaded.value = automotive
        State.df.value = predictions
        State.download_data.value = download
        State.processed.value = True
        set_status(("success", "Archivo procesado y predicciones efectuadas satisfactoriamente", content))

    solara.Button("Procesar archivo", on_click=process_file, disabled=State.processed.value)
    if status is not None and status[2] is State.content.value:
        kind, message, _ = status
        if kind == "success":
            solara.Success(message)
        else:
            solara.Error(message)
# Process-wide memo of normalised address -> (lat, lng). It outlives a single
# call so re-rendering the map does not send the same addresses to the
# Geocoding API again. Addresses the API could not resolve are stored as
# (None, None). Lookups that raised are not stored, so they can be retried on
# a later call.
_GEOCODE_CACHE = {}


def geocode_address(df, address_column='DireccionComercial', *, client=None, pause=0.1):
    """Add ``latitude`` and ``longitude`` columns to *df* by geocoding an address column.

    Each distinct non-blank address is queried at most once, as
    ``"<address>, Bogotá, Colombia"``. Rows whose address is missing or blank,
    is not resolved by the API, or whose lookup raised get ``None`` coordinates.

    Args:
        df: DataFrame to annotate. The two columns are assigned in place and
            the same object is returned.
        address_column: Name of the column that holds street addresses.
        client: Object with a ``geocode(query)`` method returning a
            Google-Maps-style result list. Defaults to the module's ``gmaps``.
        pause: Seconds to sleep after each API request, to stay under the
            rate limit. ``0`` disables the pause.

    Returns:
        *df*, with ``latitude`` and ``longitude`` columns.
    """
    if client is None:
        client = gmaps
    # Addresses whose lookup raised during this call; not retried within it.
    failed = set()

    def lookup(address):
        if pd.isna(address):
            return None, None
        text = str(address).strip()
        if not text:
            return None, None
        if text in _GEOCODE_CACHE:
            return _GEOCODE_CACHE[text]
        if text in failed:
            return None, None
        try:
            results = client.geocode(f"{text}, Bogotá, Colombia")
            if results:
                location = results[0]['geometry']['location']
                coords = (location['lat'], location['lng'])
            else:
                coords = (None, None)
        except Exception as e:
            print(f"Error geocoding {address}: {e}")
            failed.add(text)
            return None, None
        finally:
            # Throttle after every request, successful or not.
            if pause:
                sleep(pause)
        _GEOCODE_CACHE[text] = coords
        return coords

    # Build plain lists rather than zip(*...): unpacking zip of an empty
    # column raised ValueError on an empty frame.
    coordinates = [lookup(address) for address in df[address_column]]
    df['latitude'] = [lat for lat, _ in coordinates]
    df['longitude'] = [lng for _, lng in coordinates]
    return df
@solara.component
def DownloadButton():
    """Offer the processed predictions as ``predicciones.xlsx``.

    Renders nothing until ``State.download_data`` holds the workbook bytes.
    """
    payload = State.download_data.value
    if not payload:
        return
    solara.FileDownload(payload, filename="predicciones.xlsx", label="Descargar Predicciones")
@solara.component
def EDA():
    """Exploratory charts for the predictions stored in ``State.df``.

    Shows one tab per chart: the histogram of predictions and box plots of
    numero_trab_asp by Barrio and by Localidad, coloured by prediction.
    Renders nothing while no predictions are available.
    """
    frame = State.df.value
    if frame is None:
        return
    # (tab title, figure factory) pairs, rendered in this order. Each figure
    # is built inside its own tab context.
    charts = (
        ("Histograma de Predicciones",
         lambda: px.histogram(frame, x="predicciones", title="Distribución de Predicciones")),
        ("Boxplots por Barrio",
         lambda: px.box(frame, y="numero_trab_asp", x="Barrio", color="predicciones",
                        title="Boxplot de numero_trab_asp por Barrio y Predicciones")),
        ("Boxplots por Localidad",
         lambda: px.box(frame, y="numero_trab_asp", x="Localidad", color="predicciones",
                        title="Boxplot de numero_trab_asp por Localidad y Predicciones")),
    )
    with solara.lab.Tabs():
        for tab_title, make_figure in charts:
            with solara.lab.Tab(tab_title):
                solara.FigurePlotly(make_figure())
@solara.component
def Page():
    """Top-level page: upload an Excel file, run the predictions and explore them.

    The upload zone and the process button are always rendered. Once
    ``State.processed`` is true and ``State.df`` holds predictions, three tabs
    are shown: the prediction table with a download button, a map of the
    geocoded establishments coloured by prediction, and the EDA charts.
    """
    FileDropDemo()
    ProcessButton()
    df = State.df.value
    processed = State.processed.value

    def geocode_predictions():
        # Predictions with latitude/longitude columns, or None when there is
        # nothing to geocode.
        if not processed or df is None:
            return None
        with_prediction = df[df['predicciones'].notna()]
        if with_prediction.empty:
            return None
        # geocode_address assigns columns to the frame it is given; pass an
        # explicit copy instead of a boolean-mask slice (SettingWithCopyWarning).
        return geocode_address(with_prediction.copy(), address_column='DireccionComercial')

    # geocode_address makes one Google Maps request per address, so memoise the
    # result per processed DataFrame instead of geocoding on every re-render.
    # The DataFrame is keyed by identity (`id`) because DataFrame `==` is
    # element-wise rather than a single bool. Hooks must run unconditionally,
    # which is why this call sits outside the `if` below.
    geo_df = solara.use_memo(geocode_predictions, dependencies=[id(df), processed])

    with solara.AppBarTitle():
        solara.Text("Cargar y analizar nuevos datos para la predicción de conceptos sanitarios emitidos a talleres de mecánica automotriz")
    if processed and df is not None:
        with solara.lab.Tabs():
            with solara.lab.Tab("Predicciones"):
                solara.Markdown("## Predicciones")
                solara.DataFrame(df)
                DownloadButton()
            with solara.lab.Tab("Visualización Geográfica de Conceptos"):
                solara.Markdown("## Visualización Geográfica de Conceptos")
                mappable = None
                if geo_df is not None and {'latitude', 'longitude'}.issubset(geo_df.columns):
                    # Rows that could not be geocoded have missing coordinates.
                    mappable = geo_df.dropna(subset=['latitude', 'longitude'])
                if mappable is not None and not mappable.empty:
                    # Interactive Plotly map. (The former `size_max` argument was
                    # dropped: it only has an effect together with `size=`.)
                    fig = px.scatter_mapbox(
                        mappable,
                        lat="latitude",
                        lon="longitude",
                        hover_name="NombreComercial",
                        hover_data={
                            "latitude": False,
                            "longitude": False,
                            "Localidad": True,
                            "Barrio": True,
                            "UltimaFechaVisita": True,
                            "predicciones": True
                        },
                        color="predicciones",
                        zoom=10,
                        height=600,
                    )
                    fig.update_layout(mapbox_style="open-street-map")
                    fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})
                    solara.FigurePlotly(fig)
                else:
                    solara.Error("Fallo la geocodificación de direcciones o faltan columnas de Latitud y Longitud")
            with solara.lab.Tab("Exploratory Data Analysis"):
                EDA()
# Build the Page element when the module is executed.
# NOTE(review): `solara run` presumably discovers `Page` by name, so this call
# mainly matters when the file is run as a notebook cell (the element is then
# displayed as the cell's last expression) — confirm before removing.
Page()