"""Streamlit app: turn a free-text vacation wish into LLM-suggested places,
geocode them with LocationIQ, plot them on a map, and describe the place the
user clicks."""
import asyncio
import os
import time
from typing import Optional

import pandas as pd
import plotly.express as px
import requests
import streamlit as st
import torch
from transformers import pipeline

LOCATIONIQ_SEARCH_URL = "https://us1.locationiq.com/v1/search"
# Seconds to wait for LocationIQ before giving up (requests has no default timeout).
REQUEST_TIMEOUT_S = 10
# Pause after every LocationIQ request to stay under its rate limit.
RATE_LIMIT_DELAY_S = 1
# Generation budget for every LLM call.
MAX_NEW_TOKENS = 128


@st.cache_resource
def model_loader():
    """Load the text-generation pipeline once per server process (CPU, bfloat16)."""
    model_name = "OrangeBottle/place_description_model_16_bit"
    # To compare against the non-fine-tuned base model, use
    # "unsloth/Llama-3.2-3B-Instruct" instead.
    st.write(f"##### Using {model_name}")
    pipe = pipeline(task="text-generation", model=model_name, dtype=torch.bfloat16, device_map="cpu")
    return pipe


@st.cache_data
def text_description_to_locations(description: str) -> list[str]:
    """Ask the LLM for vacation locations matching *description*.

    The model is prompted for a comma-separated list; entries are stripped and
    empty ones (e.g. from a trailing comma or a blank reply) are dropped.
    """
    messages = [
        {"role": "user", "content": f"List 4 popular vacation locations that match the following description: {description}. Return the locations as a comma-separated list. Only return the location names without any additional text. The locations should be geo-searchable cities."},
    ]
    text = llm_request(messages).strip()
    return [loc.strip() for loc in text.split(",") if loc.strip()]


@st.cache_data
def search_location(location_name) -> Optional[dict[str, float]]:
    """Geocode *location_name* with the LocationIQ search API.

    Reads the API key from the ``pk_key`` environment variable.

    Returns:
        ``{'lat': float, 'lon': float}`` for the first match, or ``None`` when
        LocationIQ reports no match (HTTP 404 or an empty result list).

    Raises:
        requests.RequestException: on network errors or non-404 HTTP errors.
            These are deliberately not turned into ``None`` so that
            ``st.cache_data`` does not cache a transient failure as "not found".
        ValueError: if the response body is not valid JSON.
    """
    headers = {"accept": "application/json"}
    params = {
        "key": os.environ.get("pk_key"),
        "q": location_name,
        "format": "json",
    }
    try:
        response = requests.get(
            LOCATIONIQ_SEARCH_URL, params=params, headers=headers, timeout=REQUEST_TIMEOUT_S
        )
    finally:
        time.sleep(RATE_LIMIT_DELAY_S)  # To respect rate limiting, even on failure

    # LocationIQ answers "Unable to geocode" with a 404 and an error object.
    if response.status_code == 404:
        return None
    response.raise_for_status()

    results = response.json()
    if not isinstance(results, list) or not results:
        return None
    first = results[0]
    return {
        'lat': float(first['lat']),
        'lon': float(first['lon']),
    }


@st.cache_data
def describe_location_with_llm(location_name: str) -> str:
    """Ask the LLM for a detailed description of *location_name*."""
    messages = [
        {"role": "user", "content": f"Describe the location: {location_name} in detail."},
    ]
    return llm_request(messages).strip()


@st.cache_data
def llm_request(messages) -> str:
    """Run chat *messages* through the cached pipeline and return the reply text.

    Only ``max_new_tokens`` is passed to the pipeline. Sampling settings such as
    temperature/min_p are not applied here. Because the result is cached by
    ``st.cache_data``, identical messages return the same text anyway.
    """
    pipe = model_loader()
    outputs = pipe(messages, max_new_tokens=MAX_NEW_TOKENS)
    # For chat-style input, generated_text is the full conversation; the last
    # message is the assistant's reply.
    return outputs[0]['generated_text'][-1]["content"]


def _geocode_locations(names: list[str]) -> list[tuple[str, float, float]]:
    """Geocode each name, report misses/errors in the UI, return (name, lat, lon) rows."""
    rows = []
    for name in names:
        try:
            result = search_location(name)
        except (requests.RequestException, ValueError) as err:
            st.write(f"Could not look up {name}: {err}")
            continue
        if result is None:
            st.write(f"Location {name} not found.")
            continue
        rows.append((name, result['lat'], result['lon']))
    return rows


def _build_map_figure(rows: list[tuple[str, float, float]]):
    """Build the OpenStreetMap scatter map with one colored point per location."""
    df_locations = pd.DataFrame(rows, columns=["Location", "Latitude", "Longitude"])
    df_locations['size'] = 3  # Fixed size for all points, can be modified as needed
    fig = px.scatter_map(df_locations, lat="Latitude", lon="Longitude", hover_name="Location",
                         zoom=3, size="size", color="Location")
    # scatter_map draws on the MapLibre "map" subplot, so the style key is
    # map_style (mapbox_style only affects Mapbox subplots and was ignored).
    fig.update_layout(map_style="open-street-map", margin={"r": 0, "t": 0, "l": 0, "b": 0})
    return fig


async def main():
    """Render the page: query box -> recommended places -> map -> description.

    Kept ``async`` for compatibility with the ``asyncio.run(main())`` entry
    point; nothing inside awaits.
    """
    st.write("# Vacation Location Explorer with LLM and Maps")

    search_query = st.text_input("Describe a vacation experience you would like to have:")
    if not search_query:
        st.write("Please enter a vacation experience to get recommendations.")
        return

    st.write(f"Based on your input: '{search_query}', here are some recommended locations:")
    recommended_locations = text_description_to_locations(search_query)
    for location in recommended_locations:
        st.write(f"- {location}")

    locations = _geocode_locations(recommended_locations)
    if not locations:
        st.write("No locations found to display on the map.")
        return

    event = st.plotly_chart(_build_map_figure(locations), on_select="rerun")

    points = event['selection']["points"]
    if not points:
        st.write("Please select a location on the map to get more information.")
        return

    selected = points[0].get("hovertext")
    if not selected:
        st.write("Could not determine which location was selected.")
        return
    st.write(describe_location_with_llm(selected))


# Streamlit executes the script with __name__ == "__main__", so this still runs
# under `streamlit run`, while plain imports of the module stay side-effect free.
if __name__ == "__main__":
    asyncio.run(main())