raphael-gl (HF Staff) committed
Commit 6ab4719 · verified · 1 Parent(s): eb6b4ca

Update app.py


1. Preload the model at import time (ZeroGPU tensor packing, so loading the weights does not consume the user's GPU quota)
2. Stream the model output as it is generated
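
In outline, the two changes reduce to the sketch below. This is a condensed, hypothetical rewrite rather than the literal app.py: the default model id and the TextIteratorStreamer pattern come from the diff that follows, while stream_reply is an illustrative helper name.

import threading
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

MODEL_ID = "tlhv/osb-minier"  # default model id used in this commit

# 1) Load once at import time, outside any @spaces.GPU function, so (per the
#    commit message) ZeroGPU can pack the weights up front instead of charging
#    the load to a user request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

def stream_reply(messages, max_new_tokens=512):
    # 2) Stream tokens as they are produced instead of returning one final string.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    input_ids = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(model.device)
    threading.Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, max_new_tokens=max_new_tokens, streamer=streamer),
    ).start()
    text = ""
    for chunk in streamer:  # yields decoded text pieces as generation progresses
        text += chunk
        yield text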

Files changed (1)
  1. app.py +95 -78
app.py CHANGED
@@ -1,19 +1,30 @@
+import spaces
+
+import logging
 import os
+import re
 import time
 from typing import List, Dict, Tuple
+import threading
 
+import torch
 import gradio as gr
-from transformers import pipeline
-import spaces
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+
+logging.basicConfig(level=logging.INFO)
+LOG = logging.getLogger(__name__)
 
 # === Config (override via Space secrets/env vars) ===
-MODEL_ID = os.environ.get("MODEL_ID", "gpt-oss-safeguard-20b")
+MODEL_ID = os.environ.get("MODEL_ID", "tlhv/osb-minier")
 DEFAULT_MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", 512))
 DEFAULT_TEMPERATURE = float(os.environ.get("TEMPERATURE", 1))
 DEFAULT_TOP_P = float(os.environ.get("TOP_P", 1.0))
 DEFAULT_REPETITION_PENALTY = float(os.environ.get("REPETITION_PENALTY", 1.0))
 ZGPU_DURATION = int(os.environ.get("ZGPU_DURATION", 120)) # seconds
 
+ANALYSIS_PATTERN = analysis_match = re.compile(r'^(.*)assistantfinal', flags=re.DOTALL)
+
 SAMPLE_POLICY = """
 Spam Policy (#SP)
 GOAL: Identify spam. Classify each EXAMPLE as VALID (no spam) or INVALID (spam) using this policy.
@@ -123,13 +134,42 @@ If financial harm or fraud → classify SP4.
 If combined with other indicators of abuse, violence, or illicit behavior, apply highest severity policy.
 """
 
-_pipe = None # cached pipeline
+# Globals so we only load once
+_tokenizer = None
+_model = None
+_device = None
+
+
+def _ensure_loaded():
+    LOG.info("Loading model and tokenizer")
+    global _tokenizer, _model, _device
+    if _tokenizer is not None and _model is not None:
+        return
+    _tokenizer = AutoTokenizer.from_pretrained(
+        MODEL_ID, trust_remote_code=True
+    )
+    _model = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        trust_remote_code=True,
+        # torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        low_cpu_mem_usage=True,
+        device_map="auto" if torch.cuda.is_available() else None,
+    )
+    if _tokenizer.pad_token_id is None and _tokenizer.eos_token_id is not None:
+        _tokenizer.pad_token = _tokenizer.eos_token
+    _model.eval()
+    _device = next(_model.parameters()).device
+
+
+_ensure_loaded()
+LOG.info("DEVICE %s", _device)
 
 
 # ----------------------------
 # Helpers (simple & explicit)
 # ----------------------------
 
+
 def _to_messages(policy: str, user_prompt: str) -> List[Dict[str, str]]:
     msgs: List[Dict[str, str]] = []
     if policy.strip():
@@ -138,94 +178,71 @@ def _to_messages(policy: str, user_prompt: str) -> List[Dict[str, str]]:
     return msgs
 
 
-def _extract_assistant_content(outputs) -> str:
-    """Extract the assistant's content from the known shape:
-    outputs = [
-        {
-            'generated_text': [
-                {'role': 'system', 'content': ...},
-                {'role': 'user', 'content': ...},
-                {'role': 'assistant', 'content': 'analysis...assistantfinal...'}
-            ]
-        }
-    ]
-    Keep this forgiving and minimal.
-    """
-    try:
-        msgs = outputs[0]["generated_text"]
-        for m in reversed(msgs):
-            if isinstance(m, dict) and m.get("role") == "assistant":
-                return m.get("content", "")
-        last = msgs[-1]
-        return last.get("content", "") if isinstance(last, dict) else str(last)
-    except Exception:
-        return str(outputs)
-
-
-def _parse_harmony_output_from_string(s: str) -> Tuple[str, str]:
-    """Split a Harmony-style concatenated string into (analysis, final).
-    Expects markers 'analysis' ... 'assistantfinal'.
-    No heavy parsing — just string finds.
-    """
-    if not isinstance(s, str):
-        s = str(s)
-    final_key = "assistantfinal"
-    j = s.find(final_key)
-    if j != -1:
-        final_text = s[j + len(final_key):].strip()
-        i = s.find("analysis")
-        if i != -1 and i < j:
-            analysis_text = s[i + len("analysis"): j].strip()
-        else:
-            analysis_text = s[:j].strip()
-        return analysis_text, final_text
-    # no explicit final marker
-    if s.startswith("analysis"):
-        return s[len("analysis"):].strip(), ""
-    return "", s.strip()
-
-
 # ----------------------------
 # Inference
 # ----------------------------
 
 @spaces.GPU(duration=ZGPU_DURATION)
-def generate_long_prompt(
-    policy: str,
-    prompt: str,
-    max_new_tokens: int,
-    temperature: float,
-    top_p: float,
-    repetition_penalty: float,
+def generate_stream(
+    policy: str,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    repetition_penalty: float,
 ) -> Tuple[str, str, str]:
-    global _pipe
-    start = time.time()
 
-    if _pipe is None:
-        _pipe = pipeline(
-            task="text-generation",
-            model=MODEL_ID,
-            torch_dtype="auto",
-            device_map="auto",
-        )
+    start = time.time()
 
     messages = _to_messages(policy, prompt)
 
-    outputs = _pipe(
+    streamer = TextIteratorStreamer(
+        _tokenizer,
+        skip_special_tokens=True,
+        skip_prompt=True,  # <-- key fix
+    )
+
+    inputs = _tokenizer.apply_chat_template(
         messages,
+        return_tensors="pt",
+        add_generation_prompt=True,
+    )
+    input_ids = inputs["input_ids"] if isinstance(inputs, dict) else inputs
+    input_ids = input_ids.to(_device)
+
+    gen_kwargs = dict(
+        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
+        do_sample=temperature > 0.0,
+        temperature=float(temperature),
        top_p=top_p,
-        repetition_penalty=repetition_penalty,
+        pad_token_id=_tokenizer.pad_token_id,
+        eos_token_id=_tokenizer.eos_token_id,
+        streamer=streamer,
    )
 
-    assistant_str = _extract_assistant_content(outputs)
-    analysis_text, final_text = _parse_harmony_output_from_string(assistant_str)
-
-    elapsed = time.time() - start
-    meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
-    return analysis_text or "(No analysis)", final_text or "(No answer)", meta
+    thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
+    thread.start()
+
+    analysis = ""
+    output = ""
+    for new_text in streamer:
+        output += new_text
+        if not analysis:
+            m = ANALYSIS_PATTERN.match(output)
+            if m:
+                analysis = re.sub(r'^analysis\s*', '', m.group(1))
+                output = ""
+
+    if not analysis:
+        analysis_text = re.sub(r'^analysis\s*', '', output)
+        final_text = None
+    else:
+        analysis_text = analysis
+        final_text = output
+    elapsed = time.time() - start
+    meta = f"Model: {MODEL_ID} | Time: {elapsed:.1f}s | max_new_tokens={max_new_tokens}"
+    yield analysis_text or "(No analysis)", final_text or "(No answer)", meta
 
 
 # ----------------------------
@@ -269,7 +286,7 @@ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
         meta = gr.Markdown()
 
     btn.click(
-        fn=generate_long_prompt,
+        fn=generate_stream,
        inputs=[policy, prompt, max_new_tokens, temperature, top_p, repetition_penalty],
        outputs=[analysis, answer, meta],
        concurrency_limit=1,