"""Gradio chatbot answering Canadian rental-housing questions with RAG.

Loads Mistral-7B-Instruct in 4-bit, embeds paragraphs of per-province text
files with a legal BERT model, retrieves the most similar paragraphs for a
question, and asks the LLM to answer from that context only.
"""
import os
import re
import shutil
import zipfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

###############################################################################
# 1) LOAD MISTRAL IN 4-BIT (MUCH FASTER)
###############################################################################
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

llm = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=200,
    # `temperature` is ignored (with a warning) unless sampling is enabled.
    do_sample=True,
    temperature=0.4,
)

###############################################################################
# 2) LOAD EMBEDDINGS
###############################################################################
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")

###############################################################################
# 3) EXTRACT ZIP + PARSE PROVINCE FILES
###############################################################################
zip_path = "/app/provinces.zip"
extract_folder = "/app/provinces_texts"

# Start from a clean folder so files removed from the zip don't linger.
if os.path.exists(extract_folder):
    shutil.rmtree(extract_folder)

with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall(extract_folder)

# NOTE(review): not referenced anywhere in this file.
date_regex = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")


def parse_metadata_and_content(raw):
    """Split a province text file into header metadata and body text.

    Header lines starting with ``-`` are collected, newline-joined, under
    ``PDF_LINKS``; other header lines containing ``:`` become ``KEY: value``
    pairs with the key upper-cased. Everything after the first ``CONTENT:``
    marker is the body.

    Returns:
        ``(metadata, content)`` -- a dict and the stripped body string.

    Raises:
        ValueError: if *raw* has no ``CONTENT:`` marker.
    """
    if "CONTENT:" not in raw:
        raise ValueError("Missing CONTENT: block.")

    header, content = raw.split("CONTENT:", 1)
    metadata = {}
    pdfs = []
    for line in header.split("\n"):
        stripped = line.strip()
        if stripped.startswith("-"):
            pdfs.append(stripped)
        elif ":" in line:
            key, value = line.split(":", 1)
            metadata[key.strip().upper()] = value.strip()
    if pdfs:
        metadata["PDF_LINKS"] = "\n".join(pdfs)
    return metadata, content.strip()


def _read_text(path):
    """Read *path* as UTF-8 (tolerating a BOM), falling back to Latin-1."""
    try:
        # utf-8-sig strips a BOM that would otherwise corrupt the first header key.
        with open(path, "r", encoding="utf-8-sig") as f:
            return f.read()
    except UnicodeDecodeError:
        # Latin-1 can decode any byte sequence, so legacy files still load.
        with open(path, "r", encoding="latin-1") as f:
            return f.read()


def _load_documents(folder):
    """Return one record per non-empty paragraph of every ``.txt`` file under *folder*."""
    docs = []
    for root, dirs, files in os.walk(folder):
        dirs.sort()  # deterministic traversal order
        for filename in sorted(files):
            # Skip macOS "._" resource-fork files and anything that isn't .txt.
            if filename.startswith("._") or not filename.endswith(".txt"):
                continue
            filepath = os.path.join(root, filename)
            try:
                metadata, content = parse_metadata_and_content(_read_text(filepath))
            except (OSError, ValueError) as e:
                print("Skipping:", filepath, str(e))
                continue

            paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
            docs.extend(
                {
                    "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                    "province": metadata.get("PROVINCE", "Unknown"),
                    "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                    "url": metadata.get("URL", "N/A"),
                    "pdf_links": metadata.get("PDF_LINKS", ""),
                    "text": p,
                }
                for p in paragraphs
            )
    return docs


documents = _load_documents(extract_folder)
if not documents:
    # Fail fast: without documents every answer would be built from an empty context.
    raise RuntimeError(
        f"No usable .txt documents found in {zip_path}; cannot build the retrieval index."
    )

###############################################################################
# 4) EMBEDDINGS + DATAFRAME
###############################################################################
texts = [d["text"] for d in documents]
embs = embedding_model.encode(texts).astype("float16")

df = pd.DataFrame(documents)
df["Embedding"] = list(embs)

# Row-normalised float32 copy of the embeddings, built once so each query is a
# single matrix-vector product (cosine similarity) instead of a per-row lambda.
_emb_matrix = embs.astype(np.float32)
_emb_norms = np.linalg.norm(_emb_matrix, axis=1, keepdims=True)
_emb_matrix /= np.where(_emb_norms == 0, 1.0, _emb_norms)

###############################################################################
# 5) RAG RETRIEVAL
###############################################################################
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the *top_k* paragraphs most cosine-similar to *query*.

    Args:
        query: free-text user question.
        province: if given, only rows whose ``province`` equals it
            (case-insensitively) are searched.
        top_k: maximum number of rows to return.

    Returns:
        A new DataFrame (rows of ``df`` plus a ``Similarity`` column) sorted by
        descending similarity. Empty when no row matches *province*. The global
        ``df`` is never modified.
    """
    q_emb = np.asarray(embedding_model.encode([query])[0], dtype=np.float32)
    q_norm = np.linalg.norm(q_emb)
    if q_norm > 0:
        q_emb = q_emb / q_norm

    if province is None:
        mask = np.ones(len(df), dtype=bool)
    else:
        mask = (df["province"].str.casefold() == province.casefold()).to_numpy()

    subset = df.loc[mask].copy()
    subset["Similarity"] = _emb_matrix[mask] @ q_emb
    return subset.sort_values("Similarity", ascending=False).head(top_k)

###############################################################################
# 6) Province detection
###############################################################################
_PROVINCE_ALIASES = {
    "yukon": "Yukon",
    "alberta": "Alberta",
    "bc": "British Columbia",
    "british columbia": "British Columbia",
    "manitoba": "Manitoba",
    "newfoundland": "Newfoundland and Labrador",
    "labrador": "Newfoundland and Labrador",
    "sask": "Saskatchewan",
    "saskatchewan": "Saskatchewan",
    "ontario": "Ontario",
    "pei": "Prince Edward Island",
    "prince edward island": "Prince Edward Island",
    "quebec": "Quebec",
    "québec": "Quebec",
    "new brunswick": "New Brunswick",
    "nb": "New Brunswick",
    "nova scotia": "Nova Scotia",
    "nunavut": "Nunavut",
    "nwt": "Northwest Territories",
    "northwest territories": "Northwest Territories",
}

# Whole-word patterns: plain substring tests made short aliases fire inside
# ordinary words (e.g. "nb" in "unbearable", "bc" in "abc").
_PROVINCE_PATTERNS = [
    (re.compile(rf"\b{re.escape(alias)}\b"), name)
    for alias, name in _PROVINCE_ALIASES.items()
]


def detect_province(query):
    """Return the canonical province named in *query*, or ``None``.

    Aliases are checked in ``_PROVINCE_ALIASES`` order; the first whole-word
    match wins.
    """
    q = query.lower()
    for pattern, name in _PROVINCE_PATTERNS:
        if pattern.search(q):
            return name
    return None

###############################################################################
# 7) Guardrails
###############################################################################
_DISALLOWED_TERMS = ("kill", "suicide", "harm yourself", "bomb", "weapon")
# Anchored at a word start so "kill" still matches "killed" but not "skills".
_DISALLOWED_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(t) for t in _DISALLOWED_TERMS) + ")"
)

_TENANCY_KEYWORDS = (
    "tenant", "landlord", "rent", "evict", "lease",
    "deposit", "tenancy", "rental", "apartment",
    "unit", "heating", "notice", "repair", "pets",
)


def is_disallowed(q):
    """Return True if *q* mentions a harmful/dangerous topic."""
    return _DISALLOWED_RE.search(q.lower()) is not None


def is_off_topic(q):
    """Return True if *q* contains none of the tenancy keywords.

    Plain substring matching is lenient (e.g. "rent" also matches "current"),
    so this only screens out questions with no housing vocabulary at all.
    """
    ql = q.lower()
    return not any(k in ql for k in _TENANCY_KEYWORDS)

###############################################################################
# 8) MAIN RAG PIPELINE
###############################################################################
def generate_with_rag(query):
    """Answer *query* from retrieved tenancy documents, with a sources list.

    Returns a markdown string: a refusal for disallowed/off-topic questions,
    a "no documents" notice when retrieval finds nothing, otherwise the LLM
    answer followed by the (de-duplicated) sources used.
    """
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful or dangerous topics."
    if is_off_topic(query):
        return "Sorry — I can only answer questions about Canadian tenancy and housing law."

    province = detect_province(query)
    top_docs = retrieve_with_pandas(query, province)
    if top_docs.empty:
        scope = f" for {province}" if province else ""
        return f"Sorry — I couldn’t find any tenancy documents{scope} to answer from."

    context = "\n\n".join(top_docs["text"])

    # Mistral-Instruct expects the request wrapped in [INST] ... [/INST].
    prompt = f"""[INST] Use ONLY the context below to answer.
If the context does not contain the answer, say so.
Answer in a simple, conversational way.

Context:
{context}

Question: {query} [/INST]"""

    # return_full_text=False yields only the completion, so text inside the
    # context or question can't confuse answer extraction.
    out = llm(prompt, return_full_text=False)[0]["generated_text"]
    answer = out.strip()

    # Sources section, one entry per distinct source.
    seen = set()
    entries = []
    for _, r in top_docs.iterrows():
        key = (r["province"], r["source_title"], r["url"])
        if key in seen:
            continue
        seen.add(key)
        entries.append(
            f"- **Province:** {r['province']}\n"
            f"  Source: {r['source_title']} (Updated {r['last_updated']})\n"
            f"  URL: {r['url']}\n"
        )
    meta = "".join(entries)

    return f"{answer}\n\n**Sources Used:**\n{meta}"

###############################################################################
# 9) GRADIO CHAT — INTRO ONLY ONCE
###############################################################################
INTRO = (
    "👋 **Welcome!** I'm a Canadian rental housing assistant.\n\n"
    "I can help you find and explain information from tenancy laws across all provinces.\n"
    "I am **not a lawyer** — this is not legal advice.\n\n"
    "What would you like to know?"
)


def start_chat():
    """Initial chat history: a single bot message containing INTRO."""
    return [(None, INTRO)]


def respond(message, history):
    """Append ``(message, answer)`` to a copy of *history*; return it twice.

    Blank messages are ignored and leave the history unchanged.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history, history
    answer = generate_with_rag(message)
    history.append((message, answer))
    return history, history


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=start_chat())
    msg = gr.Textbox(label="Ask your question")
    # After answering, clear the textbox for the next question.
    msg.submit(respond, [msg, chatbot], [chatbot, chatbot]).then(
        lambda: "", None, msg
    )

if __name__ == "__main__":
    demo.launch(share=True)