# SpotChatbot / app.py
# Author: Nathan Gebreab
# Commit 7bf136e: "updated with proper gradio formatting"
"""
Spot: The Spotify Chatbot
IAT360 Final Project
By Nathan Gebreab (301582871) & EmXi Vo (301600699)
Spot is a chatbot that combines Meta's Llama-3.2-3B-Instruct model with
RAG (Retrieval-Augmented Generation) to recommend songs based on the user's
prompt. Using RAG, Spot searches a dataset of approximately 30,000 Spotify
songs and their descriptive attributes (genre, danceability, energy, tempo, etc.)
to find the closest matches, while the language model writes a short, friendly
introduction to the list.
Links to Model (Authentication from Meta Required):
https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct
https://www.llama.com/llama-downloads/
Link to Dataset (created by Joakim Arvidsson):
https://www.kaggle.com/datasets/joebeachcapital/30000-spotify-songs
"""
import re
import warnings
from functools import lru_cache

import gradio as gr
import numpy as np
import pandas as pd
import torch
from huggingface_hub import InferenceClient
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Hugging Face Hub id of the chat model that writes Spot's intro message.
model_id = "meta-llama/Llama-3.2-3B-Instruct"

# Silence every warning category so the console output stays readable.
warnings.simplefilter("ignore")
# Load the Spotify dataset once at start-up so every chat request reuses it.
print("Loading Spotify songs database...")
spotify_df = pd.read_csv('spotify_songs.csv')
# Rows without a title or artist would be shown to users as "nan", so drop
# them; then remove duplicates based on track name and artist name.
spotify_df = spotify_df.dropna(subset=["track_name", "track_artist"])
spotify_df = spotify_df.drop_duplicates(subset=["track_name", "track_artist"])


def _song_document(row):
    """Render one dataset row as the text that is embedded for retrieval and
    shown verbatim to the user as a recommendation."""
    return (
        f"Song: {row['track_name']},\n"
        f"Album: {row['track_album_name']},\n"
        f"Album Release Date: {row['track_album_release_date']},\n"
        f"Artist: {row['track_artist']},\n"
        f"Playlist Genre: {row['playlist_genre']},\n"
        f"Playlist Subgenre: {row['playlist_subgenre']},\n"
        f"Danceability: {row['danceability']},\n"
        f"Energy: {row['energy']},\n"
        f"Key: {row['key']},\n"
        f"Loudness: {row['loudness']},\n"
        f"Mode: {row['mode']},\n"
        f"Speechiness: {row['speechiness']},\n"
        f"Acousticness: {row['acousticness']},\n"
        f"Instrumentalness: {row['instrumentalness']},\n"
        f"Liveness: {row['liveness']},\n"
        f"Valence: {row['valence']},\n"
        f"Tempo: {row['tempo']},\n"
        # The source column is `duration_ms`; label the unit because this
        # text is displayed to the user as-is.
        f"Duration (ms): {row['duration_ms']}\n"
    )


documents = spotify_df.apply(_song_document, axis=1).tolist()

# Embed every song document once; each query is compared against these vectors.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = embedding_model.encode(documents, show_progress_bar=False)
# Lookup table pairing each document with its embedding, in the same row order
# as `embeddings`.
df = pd.DataFrame({
    "Document": documents,
    "Embedding": list(embeddings),
})
print("Database loaded! Ready to chat.\n")
def retrieve_with_pandas(query, top_k=10):
    """Return the `top_k` song documents most similar to `query`.

    Similarity is the cosine similarity between the query's embedding and each
    precomputed song embedding.

    Args:
        query: Free-text user request.
        top_k: Number of rows to return.

    Returns:
        DataFrame with "Document" and "Similarity" columns, sorted from most
        to least similar.
    """
    query_vec = np.asarray(embedding_model.encode([query])[0])
    # `df` was built from `embeddings` row by row, so row i of this matrix is
    # the embedding of df row i.
    song_matrix = np.asarray(embeddings)
    dots = song_matrix @ query_vec
    norms = np.linalg.norm(song_matrix, axis=1) * np.linalg.norm(query_vec)
    # A zero-length vector would otherwise yield NaN; score it 0 instead.
    scores = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)
    # Score a copy instead of writing a column into the shared global `df`,
    # so concurrent requests never observe each other's similarities.
    ranked = df[["Document"]].assign(Similarity=scores)
    return ranked.sort_values(by="Similarity", ascending=False).head(top_k)
@lru_cache(maxsize=1)
def _load_llm():
    """Build the Llama text-generation pipeline once and reuse it for every request."""
    return pipeline(
        "text-generation",
        model=model_id,
        dtype=torch.bfloat16,
        device_map="auto",
    )


# Instructions for the intro; each sentence ends with a space so the implicitly
# concatenated pieces don't run together ("chatbot.Respond ...").
_INTRO_SYSTEM_PROMPT = (
    "You are Spot, a friendly music recommendation chatbot. "
    "Respond to the user in 1–3 natural sentences. "
    "Do NOT list songs. Do NOT number anything. Do NOT name any songs. "
    "Do NOT name any artists. Do NOT name any musicians. Do NOT name any famous works. "
    "Just give a short, warm and friendly message that leads into the list of recommended songs"
)


def generate_intro(query, temperature=0.7, max_new_tokens=60):
    """Ask the LLM for a short, friendly lead-in to the song list.

    Args:
        query: The user's message.
        temperature: Sampling temperature. The previous hard-coded 2.0 made
            sampling close to random and produced incoherent text.
        max_new_tokens: Upper bound on the length of the generated intro.

    Returns:
        The assistant's reply text, stripped of surrounding whitespace.
    """
    # Passing chat messages lets the pipeline apply the model's own chat
    # template (including the <|eot_id|> turn terminators the hand-built
    # prompt was missing) instead of a hand-assembled string.
    messages = [
        {"role": "system", "content": _INTRO_SYSTEM_PROMPT},
        {"role": "user", "content": query},
    ]
    outputs = _load_llm()(
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )
    # For chat input the pipeline returns the whole conversation; the last
    # message is the newly generated assistant turn.
    return outputs[0]["generated_text"][-1]["content"].strip()
# A standalone run of digits ("8", "8,", "8-song"), but not part of a word
# like "90s" or "8th". `\d` only matches decimal digits, which int() accepts.
_COUNT_PATTERN = re.compile(r"\b\d+\b")


def num_requested_songs(query, default=3, max_songs=10):
    """Return how many songs the user asked for in `query`.

    The first standalone number in the text is taken as the requested count
    and clamped to the range 1..max_songs. If no number is present,
    `default` is returned.

    >>> num_requested_songs("Give me 8 upbeat songs")
    8
    >>> num_requested_songs("I want something energetic")
    3
    """
    match = _COUNT_PATTERN.search(query)
    if match is None:
        return default
    # Clamp so "0 songs" still returns something and huge requests stay small.
    return max(1, min(int(match.group()), max_songs))
def generate_response(query, num_songs):
    """Compose Spot's full reply to `query`.

    The reply is an LLM-written intro, a fixed heading, and a numbered list of
    the `num_songs` best-matching song documents.
    """
    intro = generate_intro(query)
    matches = retrieve_with_pandas(query, top_k=num_songs)
    numbered_lines = [
        f"{position}. {document}"
        for position, document in enumerate(matches["Document"], start=1)
    ]
    songs_list = "\n".join(numbered_lines)
    return f"{intro}\nHere are my recommendations:\n{songs_list}\n"
# Messages that end the conversation, compared after lower-casing and trimming
# surrounding whitespace and end punctuation (so "Bye!" also matches).
_GOODBYE_MESSAGES = frozenset({'quit', 'exit', 'bye', 'goodbye'})


def respond(
    message,
    history: list[dict[str, str]],
):
    """Gradio ChatInterface callback that returns Spot's reply to `message`.

    Args:
        message: The user's latest chat message.
        history: Prior conversation supplied by Gradio. It is currently unused,
            so every message is answered independently.

    Returns:
        A farewell for goodbye words, a prompt for blank input, otherwise an
        intro followed by numbered song recommendations.
    """
    if message.strip().lower().strip("!.? ") in _GOODBYE_MESSAGES:
        return "Thanks for chatting!"
    if not message.strip():
        return "Please ask me something!"
    num_songs = num_requested_songs(message)
    return generate_response(message, num_songs)
# Text shown above the chat window.
_CHAT_DESCRIPTION = """
Hello! My name's Spot and I'm here to give song recommendations!
You can request a specific song, or just let me know how you're feeling!
*Type 'quit' or 'exit' to end the conversation.*
"""

# Clickable starter prompts; the first two include an explicit song count.
_CHAT_EXAMPLES = [
    "Give me 8 upbeat songs",
    "Show me 5 chill songs for studying",
    "Recommend songs by Drake",
    "I want something energetic",
]

# Chat UI wired to `respond`; rendered inside the page layout below.
chatbot = gr.ChatInterface(
    fn=respond,
    title="Spot: The Spotify Chatbot",
    description=_CHAT_DESCRIPTION,
    examples=_CHAT_EXAMPLES,
    theme="glass",
)
# Page layout: a sidebar holding the Hugging Face login button, with the chat
# interface rendered in the main area.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()
    chatbot.render()


def _serve():
    """Launch the Gradio app; only called when this file is run directly."""
    demo.launch()


if __name__ == "__main__":
    _serve()