zm-f21 commited on
Commit
8f996bb
·
verified ·
1 Parent(s): bf5af54

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -139
app.py CHANGED
@@ -1,48 +1,28 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
3
  from sentence_transformers import SentenceTransformer
4
  import pandas as pd
5
  import numpy as np
6
- import zipfile
7
- import os
8
- import re
9
- import torch
10
-
11
- ###############################################################################
12
- # 1) LOAD MISTRAL IN 4-BIT (MUCH FASTER)
13
- ###############################################################################
14
- bnb_config = BitsAndBytesConfig(
15
- load_in_4bit=True,
16
- bnb_4bit_compute_dtype=torch.float16,
17
- bnb_4bit_use_double_quant=True,
18
- bnb_4bit_quant_type="nf4",
19
- )
20
-
21
- model_name = "mistralai/Mistral-7B-Instruct-v0.2"
22
-
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
- model = AutoModelForCausalLM.from_pretrained(
25
- model_name,
26
- quantization_config=bnb_config,
27
- device_map="auto"
28
- )
29
 
 
 
 
30
  llm = pipeline(
31
  "text-generation",
32
- model=model,
33
- tokenizer=tokenizer,
34
- max_new_tokens=200,
35
- temperature=0.4,
36
  )
37
 
38
- ###############################################################################
39
- # 2) LOAD EMBEDDINGS
40
- ###############################################################################
41
  embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
42
 
43
- ###############################################################################
44
- # 3) EXTRACT ZIP + PARSE PROVINCE FILES
45
- ###############################################################################
46
  zip_path = "/app/provinces.zip"
47
  extract_folder = "/app/provinces_texts"
48
 
@@ -50,22 +30,26 @@ if os.path.exists(extract_folder):
50
  import shutil
51
  shutil.rmtree(extract_folder)
52
 
53
- with zipfile.ZipFile(zip_path, "r") as zip_ref:
54
- zip_ref.extractall(extract_folder)
55
 
56
- date_regex = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")
57
 
 
 
 
58
  def parse_metadata_and_content(raw):
59
  if "CONTENT:" not in raw:
60
- raise ValueError("Missing CONTENT: block.")
61
- header, content = raw.split("CONTENT:", 1)
62
 
 
63
  metadata = {}
64
  pdfs = []
 
65
  for line in header.split("\n"):
66
- if ":" in line and not line.strip().startswith("-"):
67
- key, value = line.split(":", 1)
68
- metadata[key.strip().upper()] = value.strip()
69
  elif line.strip().startswith("-"):
70
  pdfs.append(line.strip())
71
 
@@ -75,16 +59,16 @@ def parse_metadata_and_content(raw):
75
  return metadata, content.strip()
76
 
77
  documents = []
78
-
79
  for root, dirs, files in os.walk(extract_folder):
80
  for filename in files:
81
- if filename.startswith("._") or not filename.endswith(".txt"):
82
  continue
83
- filepath = os.path.join(root, filename)
 
84
  try:
85
- with open(filepath, "r", encoding="latin-1") as f:
86
- raw = f.read()
87
  metadata, content = parse_metadata_and_content(raw)
 
88
  for p in [x.strip() for x in content.split("\n\n") if x.strip()]:
89
  documents.append({
90
  "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
@@ -94,141 +78,128 @@ for root, dirs, files in os.walk(extract_folder):
94
  "pdf_links": metadata.get("PDF_LINKS", ""),
95
  "text": p
96
  })
 
97
  except Exception as e:
98
- print("Skipping:", filepath, str(e))
99
 
100
- ###############################################################################
101
- # 4) EMBEDDINGS + DATAFRAME
102
- ###############################################################################
103
- texts = [d["text"] for d in documents]
104
- embs = embedding_model.encode(texts).astype("float16")
105
 
 
 
 
106
  df = pd.DataFrame(documents)
107
- df["Embedding"] = list(embs)
 
108
 
109
- ###############################################################################
110
- # 5) RAG RETRIEVAL
111
- ###############################################################################
 
 
 
112
  def retrieve_with_pandas(query, province=None, top_k=2):
113
- q_emb = embedding_model.encode([query])[0]
114
 
115
- subset = df if province is None else df[df["province"] == province].copy()
116
 
 
117
  subset["Similarity"] = subset["Embedding"].apply(
118
- lambda x: np.dot(q_emb, x) /
119
- (np.linalg.norm(q_emb) * np.linalg.norm(x))
120
  )
121
 
122
  return subset.sort_values("Similarity", ascending=False).head(top_k)
123
 
124
- ###############################################################################
125
- # 6) Province detection
126
- ###############################################################################
127
- def detect_province(query):
128
  provinces = {
129
- "yukon": "Yukon",
130
- "alberta": "Alberta",
131
- "bc": "British Columbia",
132
- "british columbia": "British Columbia",
133
- "manitoba": "Manitoba",
134
  "newfoundland": "Newfoundland and Labrador",
135
- "labrador": "Newfoundland and Labrador",
136
- "sask": "Saskatchewan",
137
- "saskatchewan": "Saskatchewan",
138
- "ontario": "Ontario",
139
- "pei": "Prince Edward Island",
140
- "prince edward island": "Prince Edward Island",
141
- "quebec": "Quebec",
142
- "new brunswick": "New Brunswick",
143
- "nb": "New Brunswick",
144
- "nova scotia": "Nova Scotia",
145
- "nunavut": "Nunavut",
146
- "nwt": "Northwest Territories",
147
- "northwest territories": "Northwest Territories",
148
  }
149
- q = query.lower()
150
- for k, p in provinces.items():
151
- if k in q:
152
- return p
153
  return None
154
 
155
- ###############################################################################
156
- # 7) Guardrails
157
- ###############################################################################
158
  def is_disallowed(q):
159
- banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
160
  return any(b in q.lower() for b in banned)
161
 
162
  def is_off_topic(q):
163
- keys = [
164
- "tenant","landlord","rent","evict","lease",
165
- "deposit","tenancy","rental","apartment",
166
- "unit","heating","notice","repair","pets"
167
- ]
168
  return not any(k in q.lower() for k in keys)
169
 
170
- ###############################################################################
171
- # 8) MAIN RAG PIPELINE
172
- ###############################################################################
 
 
 
 
 
 
 
 
 
173
  def generate_with_rag(query):
174
  if is_disallowed(query):
175
- return "Sorry — I can’t help with harmful or dangerous topics."
 
176
  if is_off_topic(query):
177
- return "Sorry — I can only answer questions about Canadian tenancy and housing law."
 
 
 
178
 
179
- province = detect_province(query)
180
- top_docs = retrieve_with_pandas(query, province)
181
 
182
- context = " ".join(top_docs["text"].tolist())
183
 
184
  prompt = f"""
185
- Use ONLY the context below to answer.
186
- If the context does not contain the answer, say so.
187
- Answer in a simple, conversational way.
188
 
189
  Context:
190
  {context}
191
 
192
- Question: {query}
193
- Answer:
194
- """
195
-
196
- out = llm(prompt)[0]["generated_text"]
197
- answer = out.split("Answer:", 1)[-1].strip()
198
 
199
- # metadata section
200
- meta = ""
201
- for _, r in top_docs.iterrows():
202
- meta += (
203
- f"- **Province:** {r['province']}\n"
204
- f" Source: {r['source_title']} (Updated {r['last_updated']})\n"
205
- f" URL: {r['url']}\n"
206
- )
207
 
208
- return f"{answer}\n\n**Sources Used:**\n{meta}"
 
209
 
210
- ###############################################################################
211
- # 9) GRADIO CHAT — INTRO ONLY ONCE
212
- ###############################################################################
213
- INTRO = (
214
- "👋 **Welcome!** I'm a Canadian rental housing assistant.\n\n"
215
- "I can help you find and explain information from tenancy laws across all provinces.\n"
216
- "I am **not a lawyer** — this is not legal advice.\n\n"
217
- "What would you like to know?"
218
- )
219
 
 
 
 
220
  def start_chat():
221
  return [(None, INTRO)]
222
 
223
- def respond(message, history):
224
- answer = generate_with_rag(message)
225
- history.append((message, answer))
226
- return history, history
227
 
228
  with gr.Blocks() as demo:
229
  chatbot = gr.Chatbot(value=start_chat())
230
- msg = gr.Textbox(label="Ask your question")
231
- msg.submit(respond, [msg, chatbot], [chatbot, chatbot])
 
232
 
233
- if __name__ == "__main__":
234
- demo.launch(share=True)
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  from sentence_transformers import SentenceTransformer
4
  import pandas as pd
5
  import numpy as np
6
+ import zipfile, os, re, torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # -----------------------------
9
+ # Load Mistral (FP16, GPU if available)
10
+ # -----------------------------
11
  llm = pipeline(
12
  "text-generation",
13
+ model="mistralai/Mistral-7B-Instruct-v0.2",
14
+ torch_dtype=torch.float16,
15
+ device_map="auto"
 
16
  )
17
 
18
+ # -----------------------------
19
+ # Load embedding model
20
+ # -----------------------------
21
  embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
22
 
23
+ # -----------------------------
24
+ # Extract ZIP with provincial legal texts
25
+ # -----------------------------
26
  zip_path = "/app/provinces.zip"
27
  extract_folder = "/app/provinces_texts"
28
 
 
30
  import shutil
31
  shutil.rmtree(extract_folder)
32
 
33
+ with zipfile.ZipFile(zip_path, "r") as z:
34
+ z.extractall(extract_folder)
35
 
36
+ date_pattern = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")
37
 
38
+ # -----------------------------
39
+ # Parse documents
40
+ # -----------------------------
41
  def parse_metadata_and_content(raw):
42
  if "CONTENT:" not in raw:
43
+ raise ValueError("Missing CONTENT block")
 
44
 
45
+ header, content = raw.split("CONTENT:", 1)
46
  metadata = {}
47
  pdfs = []
48
+
49
  for line in header.split("\n"):
50
+ if ":" in line and not line.startswith("-"):
51
+ k, v = line.split(":", 1)
52
+ metadata[k.strip().upper()] = v.strip()
53
  elif line.strip().startswith("-"):
54
  pdfs.append(line.strip())
55
 
 
59
  return metadata, content.strip()
60
 
61
  documents = []
 
62
  for root, dirs, files in os.walk(extract_folder):
63
  for filename in files:
64
+ if not filename.endswith(".txt") or filename.startswith("._"):
65
  continue
66
+
67
+ path = os.path.join(root, filename)
68
  try:
69
+ raw = open(path, "r", encoding="latin-1").read()
 
70
  metadata, content = parse_metadata_and_content(raw)
71
+
72
  for p in [x.strip() for x in content.split("\n\n") if x.strip()]:
73
  documents.append({
74
  "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
 
78
  "pdf_links": metadata.get("PDF_LINKS", ""),
79
  "text": p
80
  })
81
+
82
  except Exception as e:
83
+ print("Skipped:", path, e)
84
 
85
+ print("Loaded paragraphs:", len(documents))
 
 
 
 
86
 
87
+ # -----------------------------
88
+ # Build embeddings dataframe
89
+ # -----------------------------
90
  df = pd.DataFrame(documents)
91
+ texts = df["text"].tolist()
92
+ embeddings = embedding_model.encode(texts).astype("float16")
93
 
94
+ df["Embedding"] = list(embeddings)
95
+ print("Embedding index ready:", len(df))
96
+
97
+ # -----------------------------
98
+ # Retrieval
99
+ # -----------------------------
100
  def retrieve_with_pandas(query, province=None, top_k=2):
101
+ query_emb = embedding_model.encode([query])[0]
102
 
103
+ subset = df if province is None else df[df["province"] == province]
104
 
105
+ subset = subset.copy()
106
  subset["Similarity"] = subset["Embedding"].apply(
107
+ lambda x: np.dot(query_emb, x) /
108
+ (np.linalg.norm(query_emb) * np.linalg.norm(x))
109
  )
110
 
111
  return subset.sort_values("Similarity", ascending=False).head(top_k)
112
 
113
+ # -----------------------------
114
+ # Province detection
115
+ # -----------------------------
116
+ def detect_province(q):
117
  provinces = {
118
+ "yukon": "Yukon", "alberta": "Alberta", "bc": "British Columbia",
119
+ "british columbia": "British Columbia", "manitoba": "Manitoba",
 
 
 
120
  "newfoundland": "Newfoundland and Labrador",
121
+ "saskatchewan": "Saskatchewan", "sask": "Saskatchewan",
122
+ "ontario": "Ontario", "pei": "Prince Edward Island",
123
+ "quebec": "Quebec", "new brunswick": "New Brunswick",
124
+ "nova scotia": "Nova Scotia", "nunavut": "Nunavut",
125
+ "northwest territories": "Northwest Territories"
 
 
 
 
 
 
 
 
126
  }
127
+ q = q.lower()
128
+ for key, prov in provinces.items():
129
+ if key in q:
130
+ return prov
131
  return None
132
 
133
+ # -----------------------------
134
+ # Filters
135
+ # -----------------------------
136
  def is_disallowed(q):
137
+ banned = ["kill", "suicide", "bomb", "weapon", "harm yourself"]
138
  return any(b in q.lower() for b in banned)
139
 
140
  def is_off_topic(q):
141
+ keys = ["tenant","landlord","rent","evict","lease","repair","notice","unit"]
 
 
 
 
142
  return not any(k in q.lower() for k in keys)
143
 
144
+ # -----------------------------
145
+ # Intro (sent once)
146
+ # -----------------------------
147
+ INTRO = (
148
+ "Hi! I'm a Canadian rental housing assistant. I help summarize and explain "
149
+ "information from Residential Tenancies Acts across Canada.\n\n"
150
+ "**Note:** I'm not a lawyer — this is not legal advice.\n\n"
151
+ )
152
+
153
+ # -----------------------------
154
+ # RAG Generation
155
+ # -----------------------------
156
  def generate_with_rag(query):
157
  if is_disallowed(query):
158
+ return "Sorry — I can’t help with harmful topics."
159
+
160
  if is_off_topic(query):
161
+ return "Sorry — I only answer questions about Canadian tenancy law."
162
+
163
+ prov = detect_province(query)
164
+ docs = retrieve_with_pandas(query, province=prov, top_k=2)
165
 
166
+ if len(docs) == 0:
167
+ return "I couldn’t find anything relevant in the tenancy database."
168
 
169
+ context = " ".join(docs["text"].tolist())
170
 
171
  prompt = f"""
172
+ Use only the context below. Do NOT invent laws.
 
 
173
 
174
  Context:
175
  {context}
176
 
177
+ Question:
178
+ {query}
 
 
 
 
179
 
180
+ Answer conversationally:
181
+ """
 
 
 
 
 
 
182
 
183
+ out = llm(prompt, max_new_tokens=150)[0]["generated_text"]
184
+ answer = out.split("Answer conversationally:", 1)[-1].strip()
185
 
186
+ return answer
 
 
 
 
 
 
 
 
187
 
188
+ # -----------------------------
189
+ # Gradio Chat (Intro only once)
190
+ # -----------------------------
191
  def start_chat():
192
  return [(None, INTRO)]
193
 
194
+ def respond(msg, history):
195
+ answer = generate_with_rag(msg)
196
+ history.append((msg, answer))
197
+ return history
198
 
199
  with gr.Blocks() as demo:
200
  chatbot = gr.Chatbot(value=start_chat())
201
+ inp = gr.Textbox(label="Ask a question:")
202
+
203
+ inp.submit(respond, [inp, chatbot], chatbot)
204
 
205
+ demo.launch(share=True)