Finalv3.5 / app.py
zm-f21's picture
Update app.py
bf5af54 verified
raw
history blame
7.97 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np
import zipfile
import os
import re
import torch
###############################################################################
# 1) LOAD MISTRAL IN 4-BIT (MUCH FASTER)
###############################################################################
# 4-bit NF4 weights with double quantization; compute runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # let transformers/accelerate place the weights
)
# Extra keyword arguments below are forwarded to ``model.generate`` on every
# call. ``temperature`` is only honoured in sampling mode: without
# ``do_sample=True`` generation is greedy, the 0.4 setting is silently
# ignored, and transformers warns about it on each call.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.4,
)
###############################################################################
# 2) LOAD EMBEDDINGS
###############################################################################
# Single encoder used for both the corpus paragraphs and incoming queries.
embedding_model = SentenceTransformer(model_name_or_path="nlpaueb/legal-bert-base-uncased")
###############################################################################
# 3) EXTRACT ZIP + PARSE PROVINCE FILES
###############################################################################
zip_path = "/app/provinces.zip"
extract_folder = "/app/provinces_texts"

# Rebuild the extraction directory from scratch on every start so files that
# were removed from the archive do not linger from a previous run.
import shutil
if os.path.exists(extract_folder):
    shutil.rmtree(extract_folder)

with zipfile.ZipFile(zip_path, "r") as archive:
    archive.extractall(extract_folder)
date_regex = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")


def parse_metadata_and_content(raw):
    """Split a province text file into its header metadata and body text.

    Everything before the first ``CONTENT:`` marker is the header. In it:

    * a line whose stripped form starts with ``-`` is collected as a PDF link
      (joined with newlines under the ``PDF_LINKS`` key, if any exist);
    * any other line containing ``:`` is a ``KEY: value`` pair, stored with
      the key stripped and upper-cased and the value stripped;
    * all remaining lines are ignored.

    Args:
        raw: Full text of one file.

    Returns:
        ``(metadata, content)`` where *content* is the stripped text after
        the marker.

    Raises:
        ValueError: if *raw* has no ``CONTENT:`` marker.
    """
    header, marker, body = raw.partition("CONTENT:")
    if not marker:
        raise ValueError("Missing CONTENT: block.")

    metadata = {}
    pdf_lines = []
    for entry in header.split("\n"):
        stripped = entry.strip()
        if stripped.startswith("-"):
            pdf_lines.append(stripped)
            continue
        if ":" in entry:
            name, _, val = entry.partition(":")
            metadata[name.strip().upper()] = val.strip()

    if pdf_lines:
        metadata["PDF_LINKS"] = "\n".join(pdf_lines)
    return metadata, body.strip()
def _read_text(path):
    """Return the text of *path*, decoded as UTF-8 with a Latin-1 fallback.

    Always decoding as Latin-1 never fails, but it turns UTF-8 accented
    characters (e.g. "Québec") into mojibake. "utf-8-sig" also drops a
    leading byte-order mark, which would otherwise end up glued onto the
    first header key and hide it from metadata lookups. Latin-1 stays as the
    fallback because it can decode any byte sequence. Text mode keeps
    universal-newline translation, so CRLF files still split on blank lines.
    """
    try:
        with open(path, "r", encoding="utf-8-sig") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, "r", encoding="latin-1") as f:
            return f.read()


# One record per blank-line-separated paragraph of every province .txt file.
documents = []
for root, dirs, files in os.walk(extract_folder):
    # Sort in place so traversal order, and hence row order in ``df``, is the
    # same on every run and filesystem.
    dirs.sort()
    for filename in sorted(files):
        # Skip macOS "._" resource-fork files and anything that is not .txt.
        if filename.startswith("._") or not filename.endswith(".txt"):
            continue
        filepath = os.path.join(root, filename)
        try:
            metadata, content = parse_metadata_and_content(_read_text(filepath))
        except Exception as e:
            # Best effort: one unreadable or malformed file must not stop start-up.
            print("Skipping:", filepath, str(e))
            continue
        # File-level fields are identical for every paragraph; build them once.
        base = {
            "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
            "province": metadata.get("PROVINCE", "Unknown"),
            "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
            "url": metadata.get("URL", "N/A"),
            "pdf_links": metadata.get("PDF_LINKS", ""),
        }
        documents.extend(
            {**base, "text": p}
            for p in (x.strip() for x in content.split("\n\n"))
            if p
        )
###############################################################################
# 4) EMBEDDINGS + DATAFRAME
###############################################################################
# Tabular index: one row per paragraph record, plus its embedding vector.
df = pd.DataFrame(documents)
texts = [record["text"] for record in documents]
# Embeddings are cast to float16 to reduce the memory held in ``df``.
embs = embedding_model.encode(texts).astype("float16")
df["Embedding"] = list(embs)
###############################################################################
# 5) RAG RETRIEVAL
###############################################################################
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the ``top_k`` corpus paragraphs most similar to ``query``.

    Similarity is the cosine similarity between the query embedding and each
    row's stored ``Embedding``.

    Args:
        query: Free-text question to embed.
        province: If given, only rows whose ``province`` column equals this
            value exactly are searched; ``None`` searches every row.
        top_k: Maximum number of rows to return.

    Returns:
        A new DataFrame (a copy of the matching rows of ``df``) with an added
        ``Similarity`` column, sorted best-first and truncated to ``top_k``
        rows. It is empty when no row matches ``province``.
    """
    q_emb = embedding_model.encode([query])[0]
    q_norm = np.linalg.norm(q_emb)  # same for every row: compute once
    # Always copy. Previously, when province was None, "Similarity" was
    # assigned onto the shared module-level ``df`` itself, mutating the corpus
    # table on every query (and potentially racing between requests).
    candidates = df if province is None else df[df["province"] == province]
    subset = candidates.copy()
    subset["Similarity"] = subset["Embedding"].apply(
        lambda emb: np.dot(q_emb, emb) / (q_norm * np.linalg.norm(emb))
    )
    return subset.sort_values("Similarity", ascending=False).head(top_k)
###############################################################################
# 6) Province detection
###############################################################################
# (alias, canonical province name, whole_word), in lookup order.
# Aliases are matched against the lower-cased query and must begin at a word
# boundary, so short abbreviations no longer fire inside unrelated words
# ("nb" inside "unbearable", "bc" inside "abc"). Abbreviations must also end
# at a word boundary; full names may be followed by more letters so that
# e.g. "Albertan", "Quebecer" and "Saskatoon" (via "sask") still match.
_PROVINCE_ALIASES = (
    ("yukon", "Yukon", False),
    ("alberta", "Alberta", False),
    ("bc", "British Columbia", True),
    ("b.c.", "British Columbia", True),
    ("british columbia", "British Columbia", False),
    ("manitoba", "Manitoba", False),
    ("newfoundland", "Newfoundland and Labrador", False),
    ("labrador", "Newfoundland and Labrador", False),
    ("sask", "Saskatchewan", False),
    ("saskatchewan", "Saskatchewan", False),
    ("ontario", "Ontario", False),
    ("pei", "Prince Edward Island", True),
    ("p.e.i.", "Prince Edward Island", True),
    ("prince edward island", "Prince Edward Island", False),
    ("quebec", "Quebec", False),
    ("québec", "Quebec", False),
    ("new brunswick", "New Brunswick", False),
    ("nb", "New Brunswick", True),
    ("nova scotia", "Nova Scotia", False),
    ("nunavut", "Nunavut", False),
    ("nwt", "Northwest Territories", True),
    ("n.w.t.", "Northwest Territories", True),
    ("northwest territories", "Northwest Territories", False),
)
# Compiled once at import time rather than on every query.
_PROVINCE_PATTERNS = tuple(
    (
        re.compile(r"(?<!\w)" + re.escape(alias) + (r"(?!\w)" if whole_word else "")),
        name,
    )
    for alias, name, whole_word in _PROVINCE_ALIASES
)


def detect_province(query):
    """Return the canonical province/territory named in *query*, or None.

    Aliases are tried in the order of ``_PROVINCE_ALIASES``, and the first
    match wins, so a query naming several provinces resolves to whichever
    comes first in that table (not first in the query).
    """
    q = query.lower()
    for pattern, name in _PROVINCE_PATTERNS:
        if pattern.search(q):
            return name
    return None
###############################################################################
# 7) Guardrails
###############################################################################
def is_disallowed(q):
    """Return True if *q* contains any banned harmful-topic phrase.

    This is a case-insensitive plain substring test, so a phrase also
    matches inside longer words (e.g. "kill" inside "skill").
    """
    lowered = q.lower()
    for phrase in ("kill", "suicide", "harm yourself", "bomb", "weapon"):
        if phrase in lowered:
            return True
    return False
# Tenancy keywords: a query is on-topic if any of these begins a word in it.
# Matching at a word start (rather than anywhere) stops "please" matching
# "lease", "current"/"different" matching "rent" and "community" matching
# "unit", while still allowing suffixes ("rental", "evicted", "tenants").
# The sub- forms are listed explicitly because they no longer match via
# "lease"/"tenant" as a bare substring.
_ON_TOPIC_KEYWORDS = (
    "tenant", "landlord", "rent", "evict", "lease",
    "deposit", "tenancy", "rental", "apartment",
    "unit", "heating", "notice", "repair", "pets",
    "sublet", "sublease", "subleasing", "subtenant",
)
_ON_TOPIC_RE = re.compile(
    r"(?<!\w)(?:" + "|".join(re.escape(k) for k in _ON_TOPIC_KEYWORDS) + ")"
)


def is_off_topic(q):
    """Return True if *q* mentions none of the tenancy keywords.

    Matching is case-insensitive, and a keyword must start at a word
    boundary in *q*.
    """
    return _ON_TOPIC_RE.search(q.lower()) is None
###############################################################################
# 8) MAIN RAG PIPELINE
###############################################################################
def _format_sources(top_docs):
    """Render one markdown bullet per distinct source among the retrieved rows."""
    lines = []
    seen = set()
    for _, r in top_docs.iterrows():
        # Several retrieved paragraphs can come from the same document;
        # list each source only once, in retrieval order.
        key = (r["province"], r["source_title"], r["last_updated"], r["url"])
        if key in seen:
            continue
        seen.add(key)
        lines.append(
            f"- **Province:** {r['province']}\n"
            f" Source: {r['source_title']} (Updated {r['last_updated']})\n"
            f" URL: {r['url']}\n"
        )
    return "".join(lines)


def generate_with_rag(query):
    """Answer a tenancy question using retrieved corpus paragraphs as context.

    Steps: guardrail checks, province detection, similarity retrieval,
    LLM generation, then a "Sources Used" list built from the retrieved rows.

    Args:
        query: The user's question.

    Returns:
        A markdown string. It is either a refusal message, a no-sources
        message, or the model's answer followed by the sources it was given.
    """
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful or dangerous topics."
    if is_off_topic(query):
        return "Sorry — I can only answer questions about Canadian tenancy and housing law."

    province = detect_province(query)
    top_docs = retrieve_with_pandas(query, province)
    if top_docs.empty:
        # E.g. a province is detected but no paragraphs are indexed for it.
        # Don't ask the model to answer from an empty context.
        where = f" for {province}" if province else ""
        return f"Sorry — I couldn't find any relevant information{where} in my sources."

    # Keep paragraph boundaries visible to the model.
    context = "\n\n".join(top_docs["text"].tolist())
    prompt = f"""
Use ONLY the context below to answer.
If the context does not contain the answer, say so.
Answer in a simple, conversational way.
Context:
{context}
Question: {query}
Answer:
"""
    out = llm(prompt)[0]["generated_text"]
    # The text-generation pipeline returns prompt + completion by default.
    # Strip the exact prompt. Splitting on the first "Answer:" broke whenever
    # the retrieved context itself contained "Answer:".
    if out.startswith(prompt):
        out = out[len(prompt):]
    answer = out.strip()

    return f"{answer}\n\n**Sources Used:**\n{_format_sources(top_docs)}"
###############################################################################
# 9) GRADIO CHAT — INTRO ONLY ONCE
###############################################################################
# Greeting the bot shows as its first message when the page loads.
INTRO = (
    "👋 **Welcome!** I'm a Canadian rental housing assistant.\n\n"
    "I can help you find and explain information from tenancy laws across all provinces.\n"
    "I am **not a lawyer** — this is not legal advice.\n\n"
    "What would you like to know?"
)


def start_chat():
    """Build a fresh initial history: one bot-only turn (no user message) holding INTRO."""
    greeting_turn = (None, INTRO)
    return [greeting_turn]
def respond(message, history):
    """Chat submit handler: return the history extended with (message, answer).

    The caller's *history* list is never mutated; a new list is built, and
    ``None`` is treated as an empty history. Blank messages are ignored
    rather than being sent through the RAG pipeline.

    Returns:
        The updated history twice, one copy for each output the submit event
        is wired to.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history, history
    answer = generate_with_rag(message)
    history.append((message, answer))
    return history, history
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=start_chat())  # pre-seeded with the intro turn
    msg = gr.Textbox(label="Ask your question")
    # Answer, then clear the textbox so the next question starts empty
    # (previously the submitted text stayed in the box).
    msg.submit(respond, [msg, chatbot], [chatbot, chatbot]).then(
        lambda: "", None, msg
    )

if __name__ == "__main__":
    demo.launch(share=True)