Spaces:
Sleeping
Sleeping
File size: 4,690 Bytes
f5c40cc 32b38b5 f5c40cc a291df5 32b38b5 f5c40cc 820a16c 32b38b5 f5c40cc a291df5 f5c40cc 8e79565 f5c40cc 3279af5 f5c40cc 8e79565 f5c40cc 8048b06 f5c40cc 8048b06 f5c40cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from typing import Optional
import streamlit as st
import pandas as pd
import plotly.express as px
import os
import time
import requests
import torch
from transformers import pipeline
@st.cache_resource
def model_loader():
    """Build and return the text-generation pipeline used by the app.

    The function is wrapped in ``st.cache_resource``. It writes the name of
    the checkpoint to the page and constructs a ``transformers`` pipeline in
    bfloat16 with ``device_map="cpu"``.

    Returns:
        The pipeline object created by ``transformers.pipeline``.
    """
    # Alternative checkpoint (non fine-tuned base model):
    #   "unsloth/Llama-3.2-3B-Instruct"
    checkpoint = "OrangeBottle/place_description_model_16_bit"
    st.write(f"##### Using {checkpoint}")
    return pipeline(
        task="text-generation",
        model=checkpoint,
        dtype=torch.bfloat16,
        device_map="cpu",
    )
@st.cache_data
def text_description_to_locations(description: str) -> list[str]:
    """Ask the LLM for vacation locations matching a free-text description.

    The prompt requests four locations as a comma-separated list. The reply
    is split on commas and each piece is stripped of surrounding whitespace.

    Args:
        description: The user's description of the desired vacation.

    Returns:
        The non-empty location names found in the reply, in reply order.
        The list may be empty if the reply contains no usable names.
    """
    messages = [
        {"role": "user", "content": f"List 4 popular vacation locations that match the following description: {description}. Return the locations as a comma-separated list. Only return the location names without any additional text. The locations should be geo-searchable cities."},
    ]
    reply = llm_request(messages).strip()
    # An empty reply or a trailing/double comma would otherwise produce ""
    # entries, which would then be sent on as empty geocoding queries.
    pieces = (piece.strip() for piece in reply.split(","))
    return [name for name in pieces if name]
@st.cache_data
def search_location(location_name) -> Optional[dict[str, float]]:
    """Geocode a place name with the LocationIQ search endpoint.

    The API key is read from the ``pk_key`` environment variable. The first
    result of the response is used.

    Args:
        location_name: Free-text place name to look up.

    Returns:
        ``{"lat": <float>, "lon": <float>}`` for the first match, or ``None``
        when the response is not a non-empty JSON list or the first entry
        has no usable coordinates.

    Raises:
        requests.RequestException: If the HTTP request itself fails or
            times out.

    Note:
        The function is wrapped in ``st.cache_data``, so a ``None`` result is
        cached like any other return value.
    """
    url = "https://us1.locationiq.com/v1/search"
    headers = {"accept": "application/json"}
    params = {
        "key": os.environ.get("pk_key"),
        "q": location_name,
        "format": "json",
    }
    # A timeout keeps a stalled connection from hanging the app indefinitely.
    response = requests.get(url, params=params, headers=headers, timeout=10)
    # To respect rate limiting. Done right after the request so that failed
    # lookups are throttled as well as successful ones.
    time.sleep(1)

    try:
        payload = response.json()
    except ValueError:
        # Body was not valid JSON.
        return None

    # A failed lookup may come back as a JSON object (e.g. an error payload)
    # or an empty list; indexing [0] on either would raise.
    if not isinstance(payload, list) or not payload:
        return None

    first_match = payload[0]
    try:
        return {
            'lat': float(first_match['lat']),
            'lon': float(first_match['lon']),
        }
    except (KeyError, TypeError, ValueError):
        # Entry lacks coordinates or they are not numeric.
        return None
@st.cache_data
def describe_location_with_llm(location_name: str) -> str:
    """Return an LLM-written description of ``location_name``.

    Sends a single user message to ``llm_request`` and returns the reply
    with leading and trailing whitespace removed.
    """
    prompt = f"Describe the location: {location_name} in detail."
    reply = llm_request([{"role": "user", "content": prompt}])
    return reply.strip()
@st.cache_data
def llm_request(messages) -> str:
    """Run a chat-style message list through the pipeline and return the reply.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts passed
            straight to the pipeline returned by ``model_loader``.

    Returns:
        The ``"content"`` of the last message in the first output's
        ``"generated_text"``.

    Note:
        Only ``max_new_tokens`` is passed to the pipeline; no sampling
        parameters (temperature, min_p, ...) are set here, so generation
        uses the pipeline's own defaults.
    """
    max_new_tokens = 128
    pipe = model_loader()
    outputs = pipe(messages, max_new_tokens=max_new_tokens)
    # Raw pipeline output is echoed to stdout for debugging.
    print(outputs)
    return outputs[0]['generated_text'][-1]["content"]
async def main():
    """Render the page: query box, recommended locations, map, description.

    Flow:
      1. Read a free-text vacation description from a text input.
      2. Ask the LLM for matching locations and list them.
      3. Geocode each location and plot the ones that were found.
      4. When a point is selected on the map, show an LLM-written
         description of that location.

    Returns early (after writing a hint to the page) when there is no query,
    when no location could be geocoded, or when nothing is selected yet.
    """
    st.write("""
# Vacation Location Explorer with LLM and Maps
""")

    # Text box for what to search for
    search_query = st.text_input("Describe a vacation experience you would like to have:")
    if not search_query:
        st.write("Please enter a vacation experience to get recommendations.")
        return

    # Use the LLM to generate locations based on the search query
    st.write(f"Based on your input: '{search_query}', here are some recommended locations:")
    recommended_locations = text_description_to_locations(search_query)
    for location in recommended_locations:
        st.write(f"- {location}")

    # Geocode every recommendation; keep only the ones that resolve.
    locations = []
    for location in recommended_locations:
        result = search_location(location)
        if result:
            locations.append((location, result['lat'], result['lon']))
        else:
            st.write(f"Location {location} not found.")

    if not locations:
        st.write("No locations found to display on the map.")
        return

    df_locations = pd.DataFrame(locations, columns=["Location", "Latitude", "Longitude"])
    df_locations['size'] = 3  # Fixed size for all points, can be modified as needed
    fig = px.scatter_map(df_locations, lat="Latitude", lon="Longitude", hover_name="Location", zoom=3, size="size", color="Location")
    # scatter_map draws on the `map` subplot, so the tile style is set via
    # `map_style`; `mapbox_style` only affects the separate `mapbox` subplot.
    fig.update_layout(
        map_style="open-street-map",
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )

    event = st.plotly_chart(fig, on_select="rerun")
    selected_points = event['selection']["points"]
    if not selected_points:
        st.write("Please select a location on the map to get more information.")
        return

    # NOTE(review): relies on the selected point exposing "hovertext",
    # presumably populated from hover_name="Location" — confirm.
    description = describe_location_with_llm(selected_points[0]["hovertext"])
    st.write(description)
# Script entry point. asyncio is imported here rather than with the other
# imports at the top of the file.
import asyncio
# Run the async `main` coroutine to completion on a fresh event loop.
asyncio.run(main())
|