zm-f21 committed · Commit bf5af54 · verified · 1 Parent(s): 129c283

Update app.py

Files changed (1): app.py +217 -53

app.py CHANGED
@@ -1,70 +1,234 @@
  import gradio as gr
- from huggingface_hub import InferenceClient


- def respond(
-     message,
-     history: list[dict[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     hf_token: gr.OAuthToken,
- ):
-     """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-     """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")

-     messages = [{"role": "system", "content": system_message}]

-     messages.extend(history)

-     messages.append({"role": "user", "content": message})

-     response = ""

-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content

-         response += token
-         yield response


  """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )

- with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()


  if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
+ from sentence_transformers import SentenceTransformer
+ import pandas as pd
+ import numpy as np
+ import zipfile
+ import os
+ import re
+ import torch

+ ###############################################################################
+ # 1) LOAD MISTRAL IN 4-BIT (MUCH FASTER)
+ ###############################################################################
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.float16,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4",
+ )

+ model_name = "mistralai/Mistral-7B-Instruct-v0.2"

+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     quantization_config=bnb_config,
+     device_map="auto"
+ )

+ llm = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     max_new_tokens=200,
+     temperature=0.4,
+ )
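+ # Note: transformers only applies `temperature` when sampling is enabled;
+ # with the default greedy decoding it is ignored, so `do_sample=True` would
+ # be needed here if varied phrasing is wanted.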

+ ###############################################################################
+ # 2) LOAD EMBEDDINGS
+ ###############################################################################
+ embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
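+ # nlpaueb/legal-bert-base-uncased is a plain BERT checkpoint rather than a
+ # sentence-transformers model, so the library wraps it with default mean
+ # pooling (and logs a warning saying so) to produce sentence vectors.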

+ ###############################################################################
+ # 3) EXTRACT ZIP + PARSE PROVINCE FILES
+ ###############################################################################
+ zip_path = "/app/provinces.zip"
+ extract_folder = "/app/provinces_texts"

+ if os.path.exists(extract_folder):
+     import shutil
+     shutil.rmtree(extract_folder)

+ with zipfile.ZipFile(zip_path, "r") as zip_ref:
+     zip_ref.extractall(extract_folder)

+ date_regex = re.compile(r"(\d{4}[-_]\d{2}[-_]\d{2})")

+ def parse_metadata_and_content(raw):
+     if "CONTENT:" not in raw:
+         raise ValueError("Missing CONTENT: block.")
+     header, content = raw.split("CONTENT:", 1)
+
+     metadata = {}
+     pdfs = []
+     for line in header.split("\n"):
+         if ":" in line and not line.strip().startswith("-"):
+             key, value = line.split(":", 1)
+             metadata[key.strip().upper()] = value.strip()
+         elif line.strip().startswith("-"):
+             pdfs.append(line.strip())
+
+     if pdfs:
+         metadata["PDF_LINKS"] = "\n".join(pdfs)
+
+     return metadata, content.strip()
+
+ documents = []
+
+ for root, dirs, files in os.walk(extract_folder):
+     for filename in files:
+         if filename.startswith("._") or not filename.endswith(".txt"):
+             continue
+         filepath = os.path.join(root, filename)
+         try:
+             with open(filepath, "r", encoding="latin-1") as f:
+                 raw = f.read()
+             metadata, content = parse_metadata_and_content(raw)
+             for p in [x.strip() for x in content.split("\n\n") if x.strip()]:
+                 documents.append({
+                     "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
+                     "province": metadata.get("PROVINCE", "Unknown"),
+                     "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
+                     "url": metadata.get("URL", "N/A"),
+                     "pdf_links": metadata.get("PDF_LINKS", ""),
+                     "text": p
+                 })
+         except Exception as e:
+             print("Skipping:", filepath, str(e))
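+ # The loop above expects each province .txt to look roughly like this
+ # (illustrative values, inferred from the parser):
+ #   PROVINCE: Ontario
+ #   SOURCE_TITLE: Residential Tenancies Act guide
+ #   LAST_UPDATED: 2024-01-15
+ #   URL: https://example.org/rta-guide
+ #   - https://example.org/rta-guide.pdf
+ #   CONTENT:
+ #   <paragraphs separated by blank lines>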
+
+ ###############################################################################
+ # 4) EMBEDDINGS + DATAFRAME
+ ###############################################################################
+ texts = [d["text"] for d in documents]
+ embs = embedding_model.encode(texts).astype("float16")
+
+ df = pd.DataFrame(documents)
+ df["Embedding"] = list(embs)
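+ # Casting to float16 halves the memory held by the stored vectors; for
+ # cosine ranking the precision loss is usually negligible.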
+
+ ###############################################################################
+ # 5) RAG RETRIEVAL
+ ###############################################################################
+ def retrieve_with_pandas(query, province=None, top_k=2):
+     q_emb = embedding_model.encode([query])[0]
+
+     # Copy in both branches so the temporary Similarity column never
+     # mutates the global df.
+     subset = df.copy() if province is None else df[df["province"] == province].copy()
+
+     subset["Similarity"] = subset["Embedding"].apply(
+         lambda x: np.dot(q_emb, x) /
+                   (np.linalg.norm(q_emb) * np.linalg.norm(x))
+     )
+
+     return subset.sort_values("Similarity", ascending=False).head(top_k)
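+ # Cosine similarity: sim(q, d) = (q · d) / (||q|| * ||d||). For example,
+ #   retrieve_with_pandas("Can my landlord raise the rent?", "Ontario")
+ # returns the top_k=2 most similar paragraphs tagged Ontario.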
+
+ ###############################################################################
+ # 6) Province detection
+ ###############################################################################
+ def detect_province(query):
+     provinces = {
+         "yukon": "Yukon",
+         "alberta": "Alberta",
+         "bc": "British Columbia",
+         "british columbia": "British Columbia",
+         "manitoba": "Manitoba",
+         "newfoundland": "Newfoundland and Labrador",
+         "labrador": "Newfoundland and Labrador",
+         "sask": "Saskatchewan",
+         "saskatchewan": "Saskatchewan",
+         "ontario": "Ontario",
+         "pei": "Prince Edward Island",
+         "prince edward island": "Prince Edward Island",
+         "quebec": "Quebec",
+         "new brunswick": "New Brunswick",
+         "nb": "New Brunswick",
+         "nova scotia": "Nova Scotia",
+         "nunavut": "Nunavut",
+         "nwt": "Northwest Territories",
+         "northwest territories": "Northwest Territories",
+     }
+     q = query.lower()
+     for k, p in provinces.items():
+         if k in q:
+             return p
+     return None
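+ # Caveat: plain substring matching is crude; short keys like "nb" or "bc"
+ # can match inside unrelated words ("unbearable" contains "nb"), so a
+ # word-boundary check such as re.search(rf"\b{k}\b", q) would be stricter.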
+
+ ###############################################################################
+ # 7) Guardrails
+ ###############################################################################
+ def is_disallowed(q):
+     banned = ["kill", "suicide", "harm yourself", "bomb", "weapon"]
+     return any(b in q.lower() for b in banned)
+
+ def is_off_topic(q):
+     keys = [
+         "tenant", "landlord", "rent", "evict", "lease",
+         "deposit", "tenancy", "rental", "apartment",
+         "unit", "heating", "notice", "repair", "pets"
+     ]
+     return not any(k in q.lower() for k in keys)
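+ # Note that is_off_topic rejects anything without a tenancy keyword, so a
+ # bare greeting like "hello" is refused too; an allowlist of small-talk
+ # phrases could soften that if desired.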
+
+ ###############################################################################
+ # 8) MAIN RAG PIPELINE
+ ###############################################################################
+ def generate_with_rag(query):
+     if is_disallowed(query):
+         return "Sorry — I can’t help with harmful or dangerous topics."
+     if is_off_topic(query):
+         return "Sorry — I can only answer questions about Canadian tenancy and housing law."
+
+     province = detect_province(query)
+     top_docs = retrieve_with_pandas(query, province)
+
+     context = " ".join(top_docs["text"].tolist())
+
+     prompt = f"""
+ Use ONLY the context below to answer.
+ If the context does not contain the answer, say so.
+ Answer in a simple, conversational way.
+
+ Context:
+ {context}
+
+ Question: {query}
+ Answer:
  """
+
+     out = llm(prompt)[0]["generated_text"]
+     answer = out.split("Answer:", 1)[-1].strip()
+
+     # metadata section
+     meta = ""
+     for _, r in top_docs.iterrows():
+         meta += (
+             f"- **Province:** {r['province']}\n"
+             f" Source: {r['source_title']} (Updated {r['last_updated']})\n"
+             f" URL: {r['url']}\n"
+         )
+
+     return f"{answer}\n\n**Sources Used:**\n{meta}"
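+ # The text-generation pipeline returns the prompt plus the completion by
+ # default, hence the split on "Answer:"; calling llm(prompt,
+ # return_full_text=False) would skip the echoed prompt entirely.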
+
+ ###############################################################################
+ # 9) GRADIO CHAT — INTRO ONLY ONCE
+ ###############################################################################
+ INTRO = (
+     "👋 **Welcome!** I'm a Canadian rental housing assistant.\n\n"
+     "I can help you find and explain information from tenancy laws across all provinces.\n"
+     "I am **not a lawyer** — this is not legal advice.\n\n"
+     "What would you like to know?"
  )

+ def start_chat():
+     return [(None, INTRO)]
+
+ def respond(message, history):
+     answer = generate_with_rag(message)
+     history.append((message, answer))
+     return history, history

+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot(value=start_chat())
+     msg = gr.Textbox(label="Ask your question")
+     msg.submit(respond, [msg, chatbot], [chatbot, chatbot])
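+     # The duplicated [chatbot, chatbot] outputs consume the two returned
+     # copies of history; a common alternative is outputs=[msg, chatbot]
+     # with `return "", history`, which also clears the textbox each turn.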

  if __name__ == "__main__":
+     demo.launch(share=True)