# app.py — Yukon rental-rules RAG chatbot (Mistral-7B generation + Legal-BERT retrieval).
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np
import zipfile
import os
import torch
# -----------------------------
# Load Mistral pipeline
# -----------------------------
# Text-generation pipeline used by generate_with_rag(). Half-precision
# weights; device_map="auto" lets HF place layers on available devices.
# NOTE(review): loading a 7B model at import time requires a GPU with
# sufficient VRAM — confirm the deployment hardware.
llm = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    device_map="auto"
)
# -----------------------------
# Load SentenceTransformer embeddings
# -----------------------------
# Encoder used for both document paragraphs and user queries.
# NOTE(review): legal-bert-base-uncased is a plain BERT checkpoint, not a
# model fine-tuned for sentence similarity — verify retrieval quality.
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
# -----------------------------
# Extract Yukon ZIP
# -----------------------------
zip_path = "/app/yukon.zip"  # make sure you uploaded here
extract_folder = "/app/yukon_texts"

# Start from a clean extraction folder so stale files never linger.
if os.path.exists(extract_folder):
    import shutil
    shutil.rmtree(extract_folder)

with zipfile.ZipFile(zip_path, "r") as archive:
    archive.extractall(extract_folder)
# -----------------------------
# Parse TXT files and create dataframe
# -----------------------------
# Header keys this app actually consumes downstream (see the document loop).
_METADATA_KEYS = {"SOURCE_TITLE", "PROVINCE", "LAST_UPDATED", "URL", "PDF_LINKS"}

def parse_metadata_and_content(raw):
    """Split a raw document into (metadata dict, body text).

    Only lines whose key (the text before the first ":") matches one of the
    known header keys in ``_METADATA_KEYS`` are treated as metadata; every
    other line is kept verbatim in the content. Keys are upper-cased and
    values stripped.

    The previous implementation treated *any* line containing ":" as
    metadata and then deleted every content line containing any harvested
    key, so ordinary sentences with colons (e.g. "Note: pay by 5:00 pm")
    were silently removed from the body.

    Returns:
        (metadata, content): dict of header fields, and the remaining
        document text joined with newlines.
    """
    metadata = {}
    content_lines = []
    for line in raw.split("\n"):
        key, sep, value = line.partition(":")
        normalized = key.strip().upper()
        if sep and normalized in _METADATA_KEYS:
            metadata[normalized] = value.strip()
        else:
            content_lines.append(line)
    return metadata, "\n".join(content_lines)
# Walk the extracted tree, parse each .txt into paragraph-level records,
# embed every paragraph, and build the retrieval DataFrame.
documents = []
for root, _dirs, filenames in os.walk(extract_folder):
    for fname in filenames:
        # Skip macOS resource-fork artifacts ("._*") and non-text files.
        if fname.startswith("._"):
            continue
        if not fname.endswith(".txt"):
            continue
        path = os.path.join(root, fname)
        with open(path, "r", encoding="latin-1") as handle:
            raw = handle.read()
        metadata, content = parse_metadata_and_content(raw)
        # One record per non-empty paragraph (blank-line separated).
        for para in content.split("\n\n"):
            para = para.strip()
            if not para:
                continue
            documents.append({
                "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                "province": metadata.get("PROVINCE", "Unknown"),
                "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                "url": metadata.get("URL", "N/A"),
                "pdf_links": metadata.get("PDF_LINKS", ""),
                "text": para
            })

texts = [d["text"] for d in documents]
embeddings = embedding_model.encode(texts).astype("float32")

df = pd.DataFrame(documents)
df["Embedding"] = list(embeddings)
print("Loaded documents:", len(df))
# -----------------------------
# Retrieval function
# -----------------------------
def retrieve_with_pandas(query, top_k=2):
    """Return the ``top_k`` rows of the global ``df`` most similar to *query*.

    Embeds the query once, scores every stored paragraph embedding with
    cosine similarity, and returns the best rows (highest first) with a
    ``Similarity`` column attached.

    Fixes over the original: the query norm is hoisted out of the per-row
    lambda instead of being recomputed for every row, and ``df.assign`` is
    used so the global DataFrame is no longer mutated on every call.
    """
    query_emb = embedding_model.encode([query])[0]
    query_norm = np.linalg.norm(query_emb)
    sims = df["Embedding"].apply(
        lambda emb: np.dot(query_emb, emb) / (query_norm * np.linalg.norm(emb))
    )
    return df.assign(Similarity=sims).sort_values("Similarity", ascending=False).head(top_k)
# -----------------------------
# RAG generation
# -----------------------------
def generate_with_rag(query, top_k=2):
    """Answer *query* with retrieval-augmented generation.

    Retrieves the ``top_k`` most similar paragraphs, builds a grounded
    prompt, generates a short answer with the LLM, and appends a
    "Sources Used" block listing each retrieved paragraph's metadata.

    Returns:
        str: the model's answer followed by the formatted source list.
    """
    top_docs = retrieve_with_pandas(query, top_k)
    context = " ".join(top_docs["text"].tolist())
    input_text = f"""
Use ONLY the following context to answer the question briefly (2–3 sentences).
Do NOT guess. Do NOT add external information.
Context:
{context}
Question: {query}
"""
    # return_full_text=False is the fix: by default the text-generation
    # pipeline echoes the entire prompt (instructions + context) inside
    # 'generated_text', so the chatbot would display the prompt before
    # the actual answer.
    response = llm(
        input_text,
        max_new_tokens=150,
        num_return_sequences=1,
        return_full_text=False,
    )[0]['generated_text']

    # Build a human-readable citation line per retrieved paragraph.
    meta = []
    for _, row in top_docs.iterrows():
        meta.append(
            f"- Province: {row['province']}\n"
            f" Source: {row['source_title']}\n"
            f" Updated: {row['last_updated']}\n"
            f" URL: {row['url']}\n"
        )
    metadata_block = "\n".join(meta)
    return f"{response.strip()}\n\nSources Used:\n{metadata_block}"
# -----------------------------
# Gradio Chat
# -----------------------------
def respond(message, history):
    """Gradio submit callback: answer *message* via RAG and extend the chat.

    Appends the (user message, generated answer) pair to *history* and
    returns it twice to satisfy the two declared output components.
    """
    history.append((message, generate_with_rag(message)))
    return history, history
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your question")
    # NOTE(review): `respond` returns the history twice and both outputs
    # target the same `chatbot` component, so the second assignment is
    # redundant; the textbox is also never cleared after submit. Consider
    # outputs=[chatbot, msg] with respond returning (history, "").
    msg.submit(respond, [msg, chatbot], [chatbot, chatbot])
    gr.Markdown("Ask questions about Yukon rental rules and landlord responsibilities.")
if __name__ == "__main__":
    # NOTE(review): share=True requests a public tunnel link — confirm this
    # is intended for the deployment target (it is unnecessary on HF Spaces).
    demo.launch(share=True)