|
|
import gradio as gr |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline |
|
|
from sentence_transformers import SentenceTransformer |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import zipfile |
|
|
import os |
|
|
import re |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 4-bit NF4 quantization with double quantization; matmuls are computed in
# float16. This is what lets the 7B model be loaded in reduced precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model_name = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # let the loader decide where each layer is placed
)

# Text-generation pipeline used to answer questions.
# `temperature` only takes effect in sampling mode; without `do_sample=True`
# decoding is greedy and the temperature value is ignored, so sampling is
# enabled explicitly here.
llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.4,
)

# Encoder used for both the document paragraphs and the incoming queries.
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Location of the bundled province documents and where they get unpacked.
zip_path = "/app/provinces.zip"
extract_folder = "/app/provinces_texts"

# Remove any previous extraction first so stale files never linger.
if os.path.exists(extract_folder):
    import shutil

    shutil.rmtree(extract_folder)

with zipfile.ZipFile(zip_path, mode="r") as archive:
    archive.extractall(path=extract_folder)

# Matches dates written as YYYY-MM-DD or YYYY_MM_DD (separators may be mixed).
date_regex = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")
|
|
|
|
|
def parse_metadata_and_content(raw):
    """Split a raw document into its header metadata and its body text.

    Everything before the first ``CONTENT:`` marker is treated as a header.
    Header lines starting with ``-`` are collected (stripped) and stored,
    newline-joined, under ``PDF_LINKS``. Any other header line containing a
    colon becomes a ``KEY: value`` entry, with the key upper-cased.

    Returns:
        A ``(metadata, content)`` tuple: the metadata dict and the stripped
        text that follows the marker.

    Raises:
        ValueError: if ``raw`` has no ``CONTENT:`` marker.
    """
    head, marker, body = raw.partition("CONTENT:")
    if not marker:
        raise ValueError("Missing CONTENT: block.")

    metadata = {}
    pdf_lines = []
    for header_line in head.split("\n"):
        stripped = header_line.strip()
        if stripped.startswith("-"):
            # Bullet lines are link entries, even when they contain a colon.
            pdf_lines.append(stripped)
            continue
        if ":" in header_line:
            name, _, val = header_line.partition(":")
            metadata[name.strip().upper()] = val.strip()

    if pdf_lines:
        metadata["PDF_LINKS"] = "\n".join(pdf_lines)

    return metadata, body.strip()
|
|
|
|
|
def _read_text(path):
    """Return the text of ``path``, decoding as UTF-8 with a Latin-1 fallback.

    ``utf-8-sig`` also drops a leading byte-order mark, which would otherwise
    end up glued to the first metadata key. Latin-1 accepts any byte sequence,
    so the fallback keeps the loader working on files that are not UTF-8.
    """
    try:
        with open(path, "r", encoding="utf-8-sig") as f:
            return f.read()
    except UnicodeDecodeError:
        with open(path, "r", encoding="latin-1") as f:
            return f.read()


# One record per non-empty paragraph of every extracted ``.txt`` document.
documents = []

for root, _dirs, files in os.walk(extract_folder):
    for filename in files:
        # Skip "._*" resource-fork style files and anything that is not text.
        if filename.startswith("._") or not filename.endswith(".txt"):
            continue

        filepath = os.path.join(root, filename)
        try:
            metadata, content = parse_metadata_and_content(_read_text(filepath))
        except (OSError, ValueError) as e:
            # Best effort: an unreadable or malformed file is reported and
            # skipped rather than aborting the whole load.
            print("Skipping:", filepath, str(e))
            continue

        # Metadata is identical for every paragraph of the same file.
        shared_fields = {
            "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
            "province": metadata.get("PROVINCE", "Unknown"),
            "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
            "url": metadata.get("URL", "N/A"),
            "pdf_links": metadata.get("PDF_LINKS", ""),
        }
        for paragraph in content.split("\n\n"):
            paragraph = paragraph.strip()
            if paragraph:
                documents.append({**shared_fields, "text": paragraph})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tabular view of the corpus: one row per paragraph record.
df = pd.DataFrame(documents)

# Encode every paragraph once up front; vectors are stored as float16.
texts = [entry["text"] for entry in documents]
embs = embedding_model.encode(texts).astype("float16")

# One embedding vector per row, in the same order as ``documents``.
df["Embedding"] = list(embs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the ``top_k`` paragraphs most similar to ``query``.

    Args:
        query: Free-text question to embed and compare against the corpus.
        province: When given, only rows whose ``province`` column equals this
            value are considered; ``None`` searches every row.
        top_k: Maximum number of rows to return.

    Returns:
        A new DataFrame with the columns of ``df`` plus a ``Similarity``
        column (cosine similarity), sorted from most to least similar. The
        module-level ``df`` is never modified.
    """
    q_emb = np.asarray(embedding_model.encode([query])[0], dtype=np.float32)

    subset = df if province is None else df[df["province"] == province]
    if subset.empty:
        # Nothing to score; keep the result shape callers expect.
        return subset.assign(Similarity=np.empty(0, dtype=np.float32))

    # Stored vectors are float16; upcast so norms and dot products are
    # computed in float32 for the whole subset at once.
    doc_embs = np.stack(subset["Embedding"].tolist()).astype(np.float32)
    norms = np.linalg.norm(doc_embs, axis=1) * np.linalg.norm(q_emb)
    # A zero-norm vector scores 0 instead of producing NaN.
    scores = np.divide(
        doc_embs @ q_emb,
        norms,
        out=np.zeros_like(norms),
        where=norms > 0,
    )

    # ``assign`` returns a new frame, so the shared corpus is left untouched.
    ranked = subset.assign(Similarity=scores)
    return ranked.sort_values("Similarity", ascending=False).head(top_k)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lower-case alias -> canonical province/territory name. Order matters: the
# first alias (in this order) found in the query wins.
_PROVINCE_ALIASES = {
    "yukon": "Yukon",
    "alberta": "Alberta",
    "bc": "British Columbia",
    "british columbia": "British Columbia",
    "manitoba": "Manitoba",
    "newfoundland": "Newfoundland and Labrador",
    "labrador": "Newfoundland and Labrador",
    "sask": "Saskatchewan",
    "saskatchewan": "Saskatchewan",
    "ontario": "Ontario",
    "pei": "Prince Edward Island",
    "prince edward island": "Prince Edward Island",
    "quebec": "Quebec",
    "new brunswick": "New Brunswick",
    "nb": "New Brunswick",
    "nova scotia": "Nova Scotia",
    "nunavut": "Nunavut",
    "nwt": "Northwest Territories",
    "northwest territories": "Northwest Territories",
}

# Compiled once. Word boundaries keep short aliases such as "bc" or "nb" from
# matching inside unrelated words ("subcontractor", "inbox").
_PROVINCE_PATTERNS = tuple(
    (re.compile(rf"\b{re.escape(alias)}\b"), name)
    for alias, name in _PROVINCE_ALIASES.items()
)


def detect_province(query):
    """Return the province or territory named in ``query``, or ``None``.

    Matching is case-insensitive and only accepts an alias that appears as a
    whole word or phrase. If several aliases are present, the one listed
    first in ``_PROVINCE_ALIASES`` is returned.
    """
    q = query.lower()
    for pattern, name in _PROVINCE_PATTERNS:
        if pattern.search(q):
            return name
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_disallowed(q):
    """Return True when ``q`` contains any banned term (case-insensitive).

    Terms are matched as plain substrings of the lower-cased question.
    """
    lowered = q.lower()
    for term in ("kill", "suicide", "harm yourself", "bomb", "weapon"):
        if term in lowered:
            return True
    return False
|
|
|
|
|
def is_off_topic(q):
    """Return True when ``q`` mentions none of the housing-related keywords.

    Keywords are matched as plain substrings of the lower-cased question.
    """
    text = q.lower()
    topic_terms = (
        "tenant", "landlord", "rent", "evict", "lease",
        "deposit", "tenancy", "rental", "apartment",
        "unit", "heating", "notice", "repair", "pets",
    )
    for term in topic_terms:
        if term in text:
            return False
    return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_with_rag(query):
    """Answer ``query`` from retrieved tenancy paragraphs and list the sources.

    The question is first screened by ``is_disallowed`` and ``is_off_topic``.
    Otherwise the most similar paragraphs are retrieved (restricted to a
    province when one is detected in the question), placed into a prompt, and
    sent to the LLM.

    Returns:
        A markdown string: either a refusal/fallback message, or the model's
        answer followed by a "Sources Used" list.
    """
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful or dangerous topics."
    if is_off_topic(query):
        return "Sorry — I can only answer questions about Canadian tenancy and housing law."

    province = detect_province(query)
    top_docs = retrieve_with_pandas(query, province)

    # With nothing retrieved there is no context to ground an answer in, so
    # skip the model call instead of prompting it with an empty context.
    if top_docs.empty:
        return "Sorry — I couldn’t find any tenancy documents to answer that question."

    context = " ".join(top_docs["text"].tolist())

    prompt = (
        "\n"
        "Use ONLY the context below to answer.\n"
        "If the context does not contain the answer, say so.\n"
        "Answer in a simple, conversational way.\n"
        "\n"
        "Context:\n"
        f"{context}\n"
        "\n"
        f"Question: {query}\n"
        "Answer:\n"
    )

    out = llm(prompt)[0]["generated_text"]
    if out.startswith(prompt):
        # The pipeline echoes the prompt; drop exactly that prefix. Splitting
        # on "Answer:" would cut at the wrong place whenever the context or
        # the question itself contains that marker.
        answer = out[len(prompt):].strip()
    else:
        # Prompt not echoed verbatim: fall back to the marker-based split.
        answer = out.split("Answer:", 1)[-1].strip()

    # List each distinct source once, in retrieval order.
    seen = set()
    source_lines = []
    for _, r in top_docs.iterrows():
        key = (r["province"], r["source_title"], r["last_updated"], r["url"])
        if key in seen:
            continue
        seen.add(key)
        source_lines.append(
            f"- **Province:** {r['province']}\n"
            f" Source: {r['source_title']} (Updated {r['last_updated']})\n"
            f" URL: {r['url']}\n"
        )
    meta = "".join(source_lines)

    return f"{answer}\n\n**Sources Used:**\n{meta}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Greeting shown as the assistant's first message in the chat window.
INTRO = "".join(
    [
        "👋 **Welcome!** I'm a Canadian rental housing assistant.",
        "\n\n",
        "I can help you find and explain information from tenancy laws across all provinces.",
        "\n",
        "I am **not a lawyer** — this is not legal advice.",
        "\n\n",
        "What would you like to know?",
    ]
)
|
|
|
|
|
def start_chat():
    """Return the initial chat history: a single assistant-only greeting."""
    # ``None`` in the user slot leaves that side of the first turn empty.
    greeting_turn = (None, INTRO)
    return [greeting_turn]
|
|
|
|
|
def respond(message, history):
    """Answer ``message`` and record the exchange in ``history``.

    ``history`` is extended in place with a ``(message, answer)`` pair and
    returned twice, one value per output the handler is wired to.
    """
    turn = (message, generate_with_rag(message))
    history.append(turn)
    return history, history
|
|
|
|
|
# Chat UI: a chat window pre-filled with the greeting, plus a question box.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=start_chat())
    msg = gr.Textbox(label="Ask your question")
    # Pressing Enter sends the question and the current history to `respond`.
    msg.submit(fn=respond, inputs=[msg, chatbot], outputs=[chatbot, chatbot])


if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)
|
|
|