# Last change: "switch to fine tuned model" (Georges, commit 820a16c)
from typing import Optional
import streamlit as st
import pandas as pd
import plotly.express as px
import os
import time
import requests
import torch
from transformers import pipeline
@st.cache_resource
def model_loader():
    """Build the text-generation pipeline used for every LLM request.

    Decorated with ``st.cache_resource`` so the pipeline object is created
    once and reused rather than rebuilt on each call. The name of the
    checkpoint being loaded is echoed to the page.

    Returns:
        A ``transformers`` text-generation pipeline running in bfloat16 on CPU.
    """
    # Swap in "unsloth/Llama-3.2-3B-Instruct" to compare against the
    # non-fine-tuned base model.
    checkpoint = "OrangeBottle/place_description_model_16_bit"
    st.write(f"##### Using {checkpoint}")
    return pipeline(
        task="text-generation",
        model=checkpoint,
        dtype=torch.bfloat16,
        device_map="cpu",
    )
@st.cache_data
def text_description_to_locations(description: str) -> list[str]:
    """Ask the LLM for vacation destinations that match a free-text description.

    Args:
        description: The user's description of the trip they want.

    Returns:
        Location names in the order the model produced them. Blank entries,
        surrounding whitespace/quotes and trailing full stops are removed, so
        the list may be empty if the model returned nothing usable.
    """
    prompt = (
        "List 4 popular vacation locations that match the following "
        f"description: {description}. Return the locations as a "
        "comma-separated list. Only return the location names without any "
        "additional text. The locations should be geo-searchable cities."
    )
    messages = [
        {"role": "user", "content": prompt},
    ]
    raw = llm_request(messages).strip()

    # The prompt asks for a comma-separated list, but chat models sometimes
    # answer one name per line instead; treat newlines as separators too.
    # NOTE(review): an answer like "Paris, France" is still split into two
    # entries because the country half contains the separator.
    locations = []
    for chunk in raw.replace("\n", ",").split(","):
        # Drop wrapping quotes and a trailing full stop the model may add.
        name = chunk.strip().strip("\"'").rstrip(".").strip()
        if name:  # skip empty pieces, e.g. from ",," or an empty reply
            locations.append(name)
    return locations
@st.cache_data
def search_location(location_name: str) -> Optional[dict[str, float]]:
    """Geocode a place name with the LocationIQ forward-search API.

    The API key is read from the ``pk_key`` environment variable.

    Args:
        location_name: Free-text place name to search for.

    Returns:
        ``{"lat": float, "lon": float}`` for the first search hit, or ``None``
        when the service reports no match (HTTP 404, a non-list payload such
        as an error object, or an empty result list).

    Raises:
        requests.RequestException: on network failure, timeout, or any other
            HTTP error status. These are raised rather than returned as
            ``None`` so that ``st.cache_data`` does not memoise a transient
            failure as "not found".
    """
    url = "https://us1.locationiq.com/v1/search"
    headers = {"accept": "application/json"}
    params = {
        "key": os.environ.get("pk_key"),
        "q": location_name,
        "format": "json",
    }
    # Bounded wait so a stalled connection cannot hang the Streamlit run.
    response = requests.get(url, params=params, headers=headers, timeout=10)
    # Throttle after every request that reached the API, hit or miss, to
    # respect rate limiting. Cached calls never get here.
    time.sleep(1)

    # NOTE(review): LocationIQ is documented to answer "no match" with 404.
    if response.status_code == 404:
        return None
    response.raise_for_status()

    results = response.json()
    # A successful search is a JSON array of matches; anything else (e.g. an
    # {"error": ...} object) or an empty array means there is nothing to plot.
    if not isinstance(results, list) or not results:
        return None

    first = results[0]
    return {
        "lat": float(first["lat"]),
        "lon": float(first["lon"]),
    }
@st.cache_data
def describe_location_with_llm(location_name: str) -> str:
    """Return an LLM-written, detailed description of a location.

    The request is a single user turn, and the reply is returned with
    surrounding whitespace stripped. Results are memoised per location name
    by ``st.cache_data``.
    """
    prompt = f"Describe the location: {location_name} in detail."
    reply = llm_request([{"role": "user", "content": prompt}])
    return reply.strip()
@st.cache_data
def llm_request(
    messages,
    *,
    max_new_tokens: int = 128,
    temperature: float = 1.5,
    min_p: float = 0.1,
) -> str:
    """Send a chat-formatted request to the cached model and return its reply.

    Args:
        messages: Chat turns as ``{"role": ..., "content": ...}`` dicts,
            passed straight to the text-generation pipeline.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature forwarded to generation.
        min_p: Min-p sampling cutoff forwarded to generation.

    Returns:
        The ``content`` of the last message in the pipeline's output.

    ``st.cache_data`` keys the cache on all arguments, so repeating an
    identical request is answered from the cache instead of re-running the
    model.
    """
    pipe = model_loader()
    # Sampling settings are forwarded to generation. temperature/min_p only
    # take effect when sampling is enabled, presumably via the checkpoint's
    # generation_config (do_sample) — TODO confirm for this model.
    outputs = pipe(
        messages,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        temperature=temperature,
        min_p=min_p,
    )
    # For chat-style input the pipeline returns the conversation with the new
    # assistant turn appended, so the reply text is in the final message.
    return outputs[0]["generated_text"][-1]["content"]
async def main():
    """Render the vacation explorer page.

    On each Streamlit rerun:
    1. read a free-text vacation description from the user;
    2. ask the LLM for matching locations and list them;
    3. geocode each location and plot the hits on a map;
    4. when the user selects a point, show an LLM description of it.

    Declared ``async`` because the module runs it via ``asyncio.run``; it
    awaits nothing itself.
    """
    st.write("\n# Vacation Location Explorer with LLM and Maps\n")

    search_query = st.text_input("Describe a vacation experience you would like to have:")
    if not search_query:
        st.write("Please enter a vacation experience to get recommendations.")
        return

    st.write(f"Based on your input: '{search_query}', here are some recommended locations:")
    recommended_locations = text_description_to_locations(search_query)
    for location in recommended_locations:
        st.write(f"- {location}")

    locations = _geocode_locations(recommended_locations)
    if not locations:
        st.write("No locations found to display on the map.")
        return

    # on_select="rerun" makes a click on the map trigger a rerun and report
    # the selection back through the returned event.
    event = st.plotly_chart(_build_location_map(locations), on_select="rerun")
    selected_points = event["selection"]["points"]
    if not selected_points:
        st.write("Please select a location on the map to get more information.")
        return

    # hover_name="Location" in the figure is what puts the name in "hovertext".
    description = describe_location_with_llm(selected_points[0]["hovertext"])
    st.write(description)


def _geocode_locations(names):
    """Geocode each name, noting misses on the page; return (name, lat, lon) hits."""
    found = []
    for name in names:
        result = search_location(name)
        if result:
            found.append((name, result["lat"], result["lon"]))
        else:
            st.write(f"Location {name} not found.")
    return found


def _build_location_map(locations):
    """Build a Plotly map with one equal-size, per-location-coloured marker each."""
    df_locations = pd.DataFrame(locations, columns=["Location", "Latitude", "Longitude"])
    df_locations["size"] = 3  # Fixed size for all points, can be modified as needed
    fig = px.scatter_map(
        df_locations,
        lat="Latitude",
        lon="Longitude",
        hover_name="Location",
        zoom=3,
        size="size",
        color="Location",
    )
    # scatter_map draws on a MapLibre `layout.map` subplot, so the basemap is
    # set with `map_style`; `mapbox_style` targets the legacy Mapbox subplot
    # and would be ignored here.
    fig.update_layout(map_style="open-street-map", margin={"r": 0, "t": 0, "l": 0, "b": 0})
    return fig
# Streamlit executes the script as the __main__ module, so the app still runs
# under `streamlit run`, while a plain `import` of this file has no side effects.
if __name__ == "__main__":
    import asyncio

    asyncio.run(main())