zm-f21 committed
Commit 8bff74b · verified · 1 Parent(s): 39abde4

Update app.py

Files changed (1):
  1. app.py +101 -72
app.py CHANGED
@@ -1,82 +1,118 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 
+import gradio as gr
+import os
+import zipfile
+import pandas as pd
+import numpy as np
 from transformers import pipeline
+from sentence_transformers import SentenceTransformer
 import torch
 
-import zipfile
-import os
-
-
-extract_folder = "yukon_texts"
-
-if not os.path.exists(extract_folder):
-    with zipfile.ZipFile("yukon.zip", 'r') as zip_ref:
-        zip_ref.extractall(extract_folder)
-
-
+# ----------------------------- #
+# Load Mistral model
+# ----------------------------- #
 llm = pipeline(
-    'text-generation',
-    model='mistralai/Mistral-7B-Instruct-v0.2',
+    "text-generation",
+    model="mistralai/Mistral-7B-Instruct-v0.2",
     torch_dtype=torch.float16,
     device_map="auto"
 )
 
-import gradio as gr
+embedding_model = SentenceTransformer("nlpaueb/legal-bert-base-uncased")
+
+# ----------------------------- #
+# Extract and load Yukon dataset
+# ----------------------------- #
+extract_folder = "yukon_texts"
+zip_path = "yukon.zip"
 
-def chat(query):
-    return generate_with_rag(query)
+if not os.path.exists(extract_folder):
+    with zipfile.ZipFile(zip_path, "r") as zip_ref:
+        zip_ref.extractall(extract_folder)
 
-iface = gr.Interface(
-    fn=chat,
-    inputs="text",
-    outputs="text",
-    title="Yukon Residential Tenancy Chatbot"
-)
-iface.launch()
-
-
-def respond(
-    message,
-    history: list[dict[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token: gr.OAuthToken,
-):
+# ----------------------------- #
+# Parse files and create embeddings
+# ----------------------------- #
+def parse_metadata_and_content(raw_text):
     """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+    Replace this with your actual parsing function from Colab.
+    Should return metadata dict and content string.
     """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-    messages = [{"role": "system", "content": system_message}]
-
-    messages.extend(history)
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+    metadata = {}
+    content = raw_text
+    return metadata, content
+
+documents = []
+for root, dirs, files in os.walk(extract_folder):
+    for filename in files:
+        if filename.startswith("._") or not filename.endswith(".txt"):
+            continue
+        filepath = os.path.join(root, filename)
+        with open(filepath, "r", encoding="latin-1") as f:
+            raw = f.read()
+        metadata, content = parse_metadata_and_content(raw)
+        paragraphs = [p.strip() for p in content.split("\n\n") if p.strip()]
+        for p in paragraphs:
+            documents.append({
+                "source_title": metadata.get("SOURCE_TITLE", "Unknown"),
+                "province": metadata.get("PROVINCE", "Unknown"),
+                "last_updated": metadata.get("LAST_UPDATED", "Unknown"),
+                "url": metadata.get("URL", "N/A"),
+                "pdf_links": metadata.get("PDF_LINKS", ""),
+                "text": p
+            })
+
+texts = [d["text"] for d in documents]
+embeddings = embedding_model.encode(texts).astype("float32")
+df = pd.DataFrame(documents)
+df["Embedding"] = list(embeddings)
+
+# ----------------------------- #
+# RAG Retrieval function
+# ----------------------------- #
+def retrieve_with_pandas(query, top_k=2):
+    query_emb = embedding_model.encode([query])[0]
+    df["Similarity"] = df["Embedding"].apply(
+        lambda x: np.dot(query_emb, x) / (np.linalg.norm(query_emb) * np.linalg.norm(x))
+    )
+    return df.sort_values("Similarity", ascending=False).head(top_k)
+
+def generate_with_rag(query, top_k=2):
+    top_docs = retrieve_with_pandas(query, top_k)
+    context = " ".join(top_docs["text"].tolist())
+
+    input_text = f"""
+Use ONLY the following context to answer the question briefly (2–3 sentences).
+Do NOT guess. Do NOT add external information.
+
+Context:
+{context}
+
+Question: {query}
 """
+    response = llm(input_text, max_new_tokens=200, num_return_sequences=1)[0]["generated_text"]
+
+    meta = []
+    for _, row in top_docs.iterrows():
+        meta.append(
+            f"- Province: {row['province']}\n"
+            f"  Source: {row['source_title']}\n"
+            f"  Updated: {row['last_updated']}\n"
+            f"  URL: {row['url']}\n"
+        )
+    metadata_block = "\n".join(meta)
+    return f"{response.strip()}\n\nSources Used:\n{metadata_block}"
+
+# ----------------------------- #
+# Gradio ChatInterface
+# ----------------------------- #
+def respond(message, history: list[dict[str, str]], system_message, max_tokens, temperature, top_p, hf_token: gr.OAuthToken):
+    # We ignore the system_message, max_tokens, temperature, top_p for simplicity; adjust if needed
+    response = generate_with_rag(message)
+    yield response
+
 chatbot = gr.ChatInterface(
     respond,
     type="messages",
@@ -84,13 +120,7 @@ chatbot = gr.ChatInterface(
         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
+        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
     ],
 )
 
@@ -99,6 +129,5 @@ with gr.Blocks() as demo:
     gr.LoginButton()
    chatbot.render()
 
-
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
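A note on the retrieval step this commit introduces: retrieve_with_pandas scores every stored paragraph by cosine similarity between its embedding and the query embedding, then keeps the top_k rows. Below is a minimal sketch of the same computation on made-up two-dimensional vectors, so it runs without downloading either model; the texts and vectors are invented purely for illustration.

import numpy as np
import pandas as pd

# Toy stand-ins for the legal-bert paragraph embeddings built in app.py.
df = pd.DataFrame({
    "text": [
        "rent increase rules",
        "security deposit limits",
        "eviction notice periods",
    ],
    "Embedding": [
        np.array([1.0, 0.0]),
        np.array([0.0, 1.0]),
        np.array([0.7, 0.7]),
    ],
})

def retrieve(query_emb, top_k=2):
    # Cosine similarity: dot product divided by the product of the norms.
    df["Similarity"] = df["Embedding"].apply(
        lambda x: np.dot(query_emb, x) / (np.linalg.norm(query_emb) * np.linalg.norm(x))
    )
    return df.sort_values("Similarity", ascending=False).head(top_k)

print(retrieve(np.array([0.9, 0.1]))[["text", "Similarity"]])

With a query vector leaning toward the first axis, "rent increase rules" (similarity ≈ 0.99) and "eviction notice periods" (≈ 0.78) come back as the top two; generate_with_rag then joins exactly those top_k paragraphs into the prompt context.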
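One behaviour worth knowing when reading generate_with_rag: a transformers text-generation pipeline returns the prompt plus the completion in generated_text by default, so the chat reply here will echo the whole instruction block and context ahead of the answer. A small runnable demonstration, using a tiny public model as a stand-in for Mistral-7B; passing return_full_text=False is the standard pipeline flag for getting the completion alone, should the echo ever need trimming.

from transformers import pipeline

# A tiny model stands in for Mistral-7B-Instruct so this runs in seconds;
# the return_full_text behaviour is the same for any text-generation pipeline.
llm = pipeline("text-generation", model="sshleifer/tiny-gpt2")

prompt = "Context: rent may rise once per year.\nQuestion: how often can rent rise?\n"
full = llm(prompt, max_new_tokens=10)[0]["generated_text"]
completion = llm(prompt, max_new_tokens=10, return_full_text=False)[0]["generated_text"]

print(full.startswith(prompt))      # True: the prompt is echoed by default
print(len(completion) < len(full))  # True: only the newly generated tokens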
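Finally, the UI wiring the commit keeps: a generator respond streamed through gr.ChatInterface, re-rendered inside a gr.Blocks. A stripped-down, runnable sketch of that pattern, with an echo stub in place of generate_with_rag (gr.LoginButton is omitted here since it requires Space OAuth to be configured):

import gradio as gr

# Echo stub stands in for generate_with_rag so the wiring runs without models.
def respond(message, history):
    yield f"(stub) You asked: {message}"

chatbot = gr.ChatInterface(respond, type="messages")

with gr.Blocks() as demo:
    chatbot.render()

if __name__ == "__main__":
    demo.launch()

Because respond is a generator, ChatInterface streams each yielded string to the chat window, which is why the committed respond yields its single response rather than returning it.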