import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np
import zipfile, os, re, torch

# -----------------------------
# Load Mistral (FP16, GPU if available)
# -----------------------------
llm = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.float16,
    device_map="auto"
)

# -----------------------------
# Load embedding model
# -----------------------------
embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")

# -----------------------------
# Extract ZIP with provincial legal texts
# -----------------------------
zip_path = "/app/provinces.zip"
extract_folder = "/app/provinces_texts"

# Re-extract from scratch so stale files from a previous run never linger.
if os.path.exists(extract_folder):
    import shutil
    shutil.rmtree(extract_folder)

with zipfile.ZipFile(zip_path, "r") as z:
    z.extractall(extract_folder)

# NOTE(review): compiled but never referenced in this file — confirm no
# external consumer imports it before removing.
date_pattern = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")


# -----------------------------
# Parse documents
# -----------------------------
def parse_metadata_and_content(raw):
    """Split a raw document into (metadata dict, content string).

    The header (everything before the first "CONTENT:") is parsed as
    "KEY: value" lines with keys upper-cased; header lines starting with
    "-" are collected as PDF links and joined under the "PDF_LINKS" key.

    Raises:
        ValueError: if the document contains no "CONTENT:" marker.
    """
    if "CONTENT:" not in raw:
        raise ValueError("Missing CONTENT block")
    header, content = raw.split("CONTENT:", 1)
    metadata = {}
    pdfs = []
    for line in header.split("\n"):
        if ":" in line and not line.startswith("-"):
            k, v = line.split(":", 1)
            metadata[k.strip().upper()] = v.strip()
        elif line.strip().startswith("-"):
            pdfs.append(line.strip())
    if pdfs:
        metadata["PDF_LINKS"] = "\n".join(pdfs)
    return metadata, content.strip()


documents = []
for root, dirs, files in os.walk(extract_folder):
    for filename in files:
        # Skip non-text files and macOS resource-fork artifacts ("._*").
        if not filename.endswith(".txt") or filename.startswith("._"):
            continue
        path = os.path.join(root, filename)
        try:
            # Context manager fixes the original's leaked file handle.
            with open(path, "r", encoding="latin-1") as f:
                raw = f.read()
            metadata, content = parse_metadata_and_content(raw)
            # One record per non-empty paragraph (blank-line separated).
            for p in [x.strip() for x in content.split("\n\n") if x.strip()]:
                documents.append({
                    "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
                    "province": metadata.get("PROVINCE", "Unknown"),
                    "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
                    "url": metadata.get("URL", "N/A"),
                    "pdf_links": metadata.get("PDF_LINKS", ""),
                    "text": p
                })
        except Exception as e:
            # Best-effort ingestion: one bad file must not abort the build.
            print("Skipped:", path, e)

print("Loaded paragraphs:", len(documents))

# -----------------------------
# Build embeddings dataframe
# -----------------------------
df = pd.DataFrame(documents)
texts = df["text"].tolist()
# float16 halves index memory; the precision loss is acceptable for ranking.
embeddings = embedding_model.encode(texts).astype("float16")
df["Embedding"] = list(embeddings)
print("Embedding index ready:", len(df))


# -----------------------------
# Retrieval
# -----------------------------
def retrieve_with_pandas(query, province=None, top_k=2):
    """Return the top_k paragraphs most cosine-similar to the query.

    Args:
        query: free-text question to embed and match.
        province: optional exact match against the "province" column;
            None searches every province.
        top_k: number of rows to return.

    Returns:
        A DataFrame slice sorted by descending "Similarity".
    """
    query_emb = embedding_model.encode([query])[0]
    # Hoisted out of the per-row lambda: the query norm is loop-invariant,
    # so the original recomputed it once per paragraph for nothing.
    query_norm = np.linalg.norm(query_emb)
    subset = df if province is None else df[df["province"] == province]
    subset = subset.copy()
    subset["Similarity"] = subset["Embedding"].apply(
        lambda x: np.dot(query_emb, x) / (query_norm * np.linalg.norm(x))
    )
    return subset.sort_values("Similarity", ascending=False).head(top_k)


# -----------------------------
# Province detection
# -----------------------------
def detect_province(q):
    """Return the canonical province/territory name mentioned in q, or None."""
    provinces = {
        "yukon": "Yukon",
        "alberta": "Alberta",
        "bc": "British Columbia",
        "british columbia": "British Columbia",
        "manitoba": "Manitoba",
        "newfoundland": "Newfoundland and Labrador",
        "saskatchewan": "Saskatchewan",
        "sask": "Saskatchewan",
        "ontario": "Ontario",
        "pei": "Prince Edward Island",
        "quebec": "Quebec",
        "new brunswick": "New Brunswick",
        "nova scotia": "Nova Scotia",
        "nunavut": "Nunavut",
        "northwest territories": "Northwest Territories"
    }
    q = q.lower()
    # NOTE(review): plain substring match — "bc" also matches inside words
    # such as "abc"; confirm acceptable or move to word-boundary matching.
    for key, prov in provinces.items():
        if key in q:
            return prov
    return None


# -----------------------------
# Filters
# -----------------------------
def is_disallowed(q):
    """True if the query contains a banned (harm-related) term."""
    banned = ["kill", "suicide", "bomb", "weapon", "harm yourself"]
    return any(b in q.lower() for b in banned)


def is_off_topic(q):
    """True if the query mentions no tenancy-related keyword at all."""
    keys = ["tenant", "landlord", "rent", "evict", "lease", "repair", "notice", "unit"]
    return not any(k in q.lower() for k in keys)
# -----------------------------
# Intro (sent once)
# -----------------------------
INTRO = (
    "Hi! I'm a Canadian rental housing assistant. I help summarize and explain "
    "information from Residential Tenancies Acts across Canada.\n\n"
    "**Note:** I'm not a lawyer — this is not legal advice.\n\n"
)


# -----------------------------
# RAG Generation
# -----------------------------
def generate_with_rag(query):
    """Answer a tenancy question via retrieval-augmented generation.

    Guardrail order matters: safety filter first, then topic filter, then
    retrieval (narrowed to a province when one is detected in the query)
    and finally generation grounded in the retrieved context.

    Returns:
        The model's answer, or a canned refusal / not-found message.
    """
    if is_disallowed(query):
        return "Sorry — I can’t help with harmful topics."
    if is_off_topic(query):
        return "Sorry — I only answer questions about Canadian tenancy law."

    prov = detect_province(query)
    docs = retrieve_with_pandas(query, province=prov, top_k=2)
    if len(docs) == 0:
        return "I couldn’t find anything relevant in the tenancy database."

    context = " ".join(docs["text"].tolist())
    prompt = f"""
Use only the context below. Do NOT invent laws.

Context:
{context}

Question: {query}

Answer conversationally:
"""
    out = llm(prompt, max_new_tokens=150)[0]["generated_text"]
    # The pipeline echoes the prompt in generated_text; keep only the text
    # after the final cue line.
    answer = out.split("Answer conversationally:", 1)[-1].strip()
    return answer


# -----------------------------
# Gradio Chat (Intro only once)
# -----------------------------
def start_chat():
    """Initial chat history: the intro message with no user turn."""
    return [(None, INTRO)]


def respond(msg, history):
    """Generate an answer, append the (msg, answer) turn, clear the box.

    Bug fix: the original returned only the history, so the submitted
    question stayed in the textbox; returning "" as a second output (wired
    to the Textbox below) clears it after each turn.
    """
    answer = generate_with_rag(msg)
    history.append((msg, answer))
    return history, ""


with gr.Blocks() as demo:
    chatbot = gr.Chatbot(value=start_chat())
    inp = gr.Textbox(label="Ask a question:")
    # Second output clears the input field (see respond).
    inp.submit(respond, [inp, chatbot], [chatbot, inp])

demo.launch(share=True)