Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,24 +1,21 @@
|
|
| 1 |
-
import os
|
| 2 |
-
|
| 3 |
-
# --- CRITICAL: SET ENVIRONMENT VARIABLES BEFORE IMPORTING GRADIO ---
|
| 4 |
-
# This ensures a stable Gradio environment.
|
| 5 |
os.environ["GRADIO_ENABLE_SSR"] = "0"
|
| 6 |
-
|
| 7 |
import gradio as gr
|
| 8 |
import torch
|
| 9 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
| 10 |
from huggingface_hub import login
|
| 11 |
|
| 12 |
-
# --- Hugging Face Login ---
|
| 13 |
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
|
| 14 |
-
|
| 15 |
-
login(token=HF_READONLY_API_KEY)
|
| 16 |
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
|
| 19 |
-
COT_OPENING = "<think>"
|
| 20 |
|
| 21 |
-
# --- Helper Functions ---
|
| 22 |
def format_rules(rules):
|
| 23 |
formatted_rules = "<rules>\n"
|
| 24 |
for i, rule in enumerate(rules):
|
|
@@ -30,113 +27,159 @@ def format_transcript(transcript):
|
|
| 30 |
formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
|
| 31 |
return formatted_transcript
|
| 32 |
|
| 33 |
-
def
|
| 34 |
-
""
|
| 35 |
-
|
| 36 |
-
""
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
if len(candidate.encode('utf-8')) <= max_bytes:
|
| 47 |
-
result = candidate
|
| 48 |
-
left = mid + 1
|
| 49 |
-
else:
|
| 50 |
-
right = mid - 1
|
| 51 |
-
|
| 52 |
-
# Add a truncation notice if the text was shortened
|
| 53 |
-
if len(result) < len(text):
|
| 54 |
-
notice = "\n\n[Response truncated to prevent server errors]"
|
| 55 |
-
notice_bytes = len(notice.encode('utf-8'))
|
| 56 |
-
# Make space for the notice itself
|
| 57 |
-
if len(result.encode('utf-8')) + notice_bytes > max_bytes:
|
| 58 |
-
result = result[:len(result) - len(notice)]
|
| 59 |
-
result += notice
|
| 60 |
-
|
| 61 |
-
return result
|
| 62 |
|
| 63 |
-
# --- Your Original ModelWrapper Class ---
|
| 64 |
-
# Bringing this back as it's a good way to organize your model logic.
|
| 65 |
class ModelWrapper:
|
| 66 |
def __init__(self, model_name="Qwen/Qwen3-0.6B"):
|
| 67 |
-
print(f"Loading model: {model_name}")
|
| 68 |
self.model_name = model_name
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
|
| 71 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 72 |
model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
|
| 73 |
-
print("Model loaded successfully.")
|
| 74 |
|
| 75 |
-
def
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
with torch.no_grad():
|
| 78 |
-
|
| 79 |
**inputs,
|
| 80 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 81 |
temperature=temperature,
|
|
|
|
| 82 |
top_p=top_p,
|
|
|
|
| 83 |
pad_token_id=self.tokenizer.pad_token_id,
|
| 84 |
do_sample=True,
|
| 85 |
eos_token_id=self.tokenizer.eos_token_id
|
| 86 |
)
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
|
|
|
|
| 92 |
|
| 93 |
-
#
|
| 94 |
def compliance_check(rules_text, transcript_text, thinking):
|
| 95 |
try:
|
| 96 |
-
|
| 97 |
-
if not rules_text.strip():
|
| 98 |
-
return "Error: Please provide at least one rule."
|
| 99 |
-
if not transcript_text.strip():
|
| 100 |
-
return "Error: Please provide a transcript to analyze."
|
| 101 |
-
|
| 102 |
-
rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
|
| 103 |
inp = format_rules(rules) + format_transcript(transcript_text)
|
| 104 |
|
| 105 |
-
#
|
| 106 |
-
|
| 107 |
-
{'role': 'system', 'content': SYSTEM_PROMPT},
|
| 108 |
-
{'role': 'user', 'content': inp}
|
| 109 |
-
]
|
| 110 |
-
prompt = model_wrapper.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
|
| 111 |
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
|
| 116 |
-
out = model_wrapper.get_response(prompt)
|
| 117 |
|
| 118 |
-
if not out.strip():
|
| 119 |
-
out = "No response generated from the model."
|
| 120 |
-
|
| 121 |
except Exception as e:
|
| 122 |
-
|
| 123 |
-
|
|
|
|
| 124 |
|
| 125 |
-
# Apply safe truncation to ALL possible outputs (both success and error)
|
| 126 |
-
return safe_truncate_to_bytes(out.strip())
|
| 127 |
|
| 128 |
-
# —
|
| 129 |
demo = gr.Interface(
|
| 130 |
fn=compliance_check,
|
| 131 |
inputs=[
|
| 132 |
-
gr.Textbox(lines=5, label="Rules (one per line)",
|
| 133 |
-
gr.Textbox(lines=10, label="Transcript",
|
| 134 |
gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
|
| 135 |
],
|
| 136 |
-
outputs=gr.Textbox(label="Compliance Output", lines=10,
|
| 137 |
title="DynaGuard Compliance Checker",
|
| 138 |
description="Paste your rules & transcript, then hit Submit.",
|
| 139 |
-
|
|
|
|
| 140 |
)
|
| 141 |
|
| 142 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
os.environ["GRADIO_ENABLE_SSR"] = "0"
|
| 2 |
+
import os
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
| 5 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
from huggingface_hub import login
|
| 8 |
|
|
|
|
| 9 |
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
|
| 10 |
+
login(token=HF_READONLY_API_KEY)
|
|
|
|
| 11 |
|
| 12 |
+
COT_OPENING = "<think>"
|
| 13 |
+
EXPLANATION_OPENING = "<explanation>"
|
| 14 |
+
LABEL_OPENING = "<answer>"
|
| 15 |
+
LABEL_CLOSING = "</answer>"
|
| 16 |
+
INPUT_FIELD = "question"
|
| 17 |
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
|
|
|
|
| 18 |
|
|
|
|
| 19 |
def format_rules(rules):
|
| 20 |
formatted_rules = "<rules>\n"
|
| 21 |
for i, rule in enumerate(rules):
|
|
|
|
| 27 |
formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
|
| 28 |
return formatted_transcript
|
| 29 |
|
| 30 |
+
def get_example(
|
| 31 |
+
dataset_path="tomg-group-umd/compliance_benchmark",
|
| 32 |
+
subset="compliance",
|
| 33 |
+
split="test_handcrafted",
|
| 34 |
+
example_idx=0,
|
| 35 |
+
):
|
| 36 |
+
dataset = load_dataset(dataset_path, subset, split=split)
|
| 37 |
+
example = dataset[example_idx]
|
| 38 |
+
return example[INPUT_FIELD]
|
| 39 |
+
|
| 40 |
+
def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
|
| 41 |
+
message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
|
| 42 |
+
return message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
|
|
|
|
|
|
| 44 |
class ModelWrapper:
|
| 45 |
def __init__(self, model_name="Qwen/Qwen3-0.6B"):
|
|
|
|
| 46 |
self.model_name = model_name
|
| 47 |
+
if "nemoguard" in model_name:
|
| 48 |
+
self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
| 49 |
+
else:
|
| 50 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 51 |
self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
|
| 52 |
self.model = AutoModelForCausalLM.from_pretrained(
|
| 53 |
model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
|
|
|
|
| 54 |
|
| 55 |
+
def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
|
| 56 |
+
"""Compile sys, user, assistant inputs into the proper dictionaries"""
|
| 57 |
+
message = []
|
| 58 |
+
if system_content is not None:
|
| 59 |
+
message.append({'role': 'system', 'content': system_content})
|
| 60 |
+
if user_content is not None:
|
| 61 |
+
message.append({'role': 'user', 'content': user_content})
|
| 62 |
+
if assistant_content is not None:
|
| 63 |
+
message.append({'role': 'assistant', 'content': assistant_content})
|
| 64 |
+
if not message:
|
| 65 |
+
raise ValueError("No content provided for any role.")
|
| 66 |
+
return message
|
| 67 |
+
|
| 68 |
+
def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
|
| 69 |
+
"""Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (which differs depending on whether it is Qwen3 or not)."""
|
| 70 |
+
if assistant_content is not None:
|
| 71 |
+
# If assistant content is passed we simply use it.
|
| 72 |
+
# This works for both Qwen3 and non-Qwen3 models. With Qwen3 any time assistant_content is provided, it automatically adds the <think></think> pair before the content, which is what we want.
|
| 73 |
+
message = self.get_message_template(system_content, user_content, assistant_content)
|
| 74 |
+
prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
|
| 75 |
+
else:
|
| 76 |
+
if enable_thinking:
|
| 77 |
+
if "qwen3" in self.model_name.lower():
|
| 78 |
+
# Let the Qwen chat template handle the thinking token
|
| 79 |
+
message = self.get_message_template(system_content, user_content)
|
| 80 |
+
prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
|
| 81 |
+
# The way the Qwen3 chat template works is it adds a <think></think> pair when enable_thinking=False, but for enable_thinking=True, it adds nothing and lets the model decide. Here we force the <think> tag to be there.
|
| 82 |
+
prompt = prompt + f"\n{COT_OPENING}"
|
| 83 |
+
else:
|
| 84 |
+
message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
|
| 85 |
+
prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
|
| 86 |
+
else:
|
| 87 |
+
# This works for both Qwen3 and non-Qwen3 models.
|
| 88 |
+
# When Qwen3 gets assistant_content, it automatically adds the <think></think> pair before the content like we want. And other models ignore the enable_thinking argument.
|
| 89 |
+
message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
|
| 90 |
+
prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
|
| 91 |
+
return prompt
|
| 92 |
+
|
| 93 |
+
def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
|
| 94 |
+
"""Generate and decode the response with the recommended temperature settings for thinking and non-thinking."""
|
| 95 |
+
print("Generating response...")
|
| 96 |
+
|
| 97 |
+
if "qwen3" in self.model_name.lower() and enable_thinking:
|
| 98 |
+
# Use values from https://huggingface.co/Qwen/Qwen3-8B#switching-between-thinking-and-non-thinking-mode
|
| 99 |
+
temperature = 0.6
|
| 100 |
+
top_p = 0.95
|
| 101 |
+
top_k = 20
|
| 102 |
+
|
| 103 |
+
message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
|
| 104 |
+
inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
|
| 105 |
+
|
| 106 |
with torch.no_grad():
|
| 107 |
+
output_content = self.model.generate(
|
| 108 |
**inputs,
|
| 109 |
max_new_tokens=max_new_tokens,
|
| 110 |
+
num_return_sequences=1,
|
| 111 |
temperature=temperature,
|
| 112 |
+
top_k=top_k,
|
| 113 |
top_p=top_p,
|
| 114 |
+
min_p=0,
|
| 115 |
pad_token_id=self.tokenizer.pad_token_id,
|
| 116 |
do_sample=True,
|
| 117 |
eos_token_id=self.tokenizer.eos_token_id
|
| 118 |
)
|
| 119 |
+
|
| 120 |
+
output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
|
| 121 |
+
|
| 122 |
+
try:
|
| 123 |
+
sys_prompt_text = output_text.split("Brief explanation\n</explanation>")[0]
|
| 124 |
+
remainder = output_text.split("Brief explanation\n</explanation>")[-1]
|
| 125 |
+
rules_transcript_text = remainder.split("</transcript>")[0]
|
| 126 |
+
thinking_answer_text = remainder.split("</transcript>")[-1]
|
| 127 |
+
return thinking_answer_text
|
| 128 |
+
except:
|
| 129 |
+
# If parsing fails, return the portion after the input
|
| 130 |
+
input_length = len(message)
|
| 131 |
+
return output_text[input_length:] if len(output_text) > input_length else "No response generated."
|
| 132 |
|
| 133 |
+
# — instantiate your model —
|
| 134 |
+
MODEL_NAME = "Qwen/Qwen3-0.6B"
|
| 135 |
+
model = ModelWrapper(MODEL_NAME)
|
| 136 |
|
| 137 |
+
# — Gradio inference function —
|
| 138 |
def compliance_check(rules_text, transcript_text, thinking):
|
| 139 |
try:
|
| 140 |
+
rules = [r for r in rules_text.split("\n") if r.strip()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
inp = format_rules(rules) + format_transcript(transcript_text)
|
| 142 |
|
| 143 |
+
# Limit max tokens to prevent oversized responses
|
| 144 |
+
out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
# Clean up any malformed output and ensure it's a string
|
| 147 |
+
out = str(out).strip()
|
| 148 |
+
if not out:
|
| 149 |
+
out = "No response generated. Please try with different input."
|
| 150 |
+
|
| 151 |
+
# Ensure the response isn't too long for an HTTP response by checking byte length
|
| 152 |
+
max_bytes = 2500 # A more generous limit, in bytes
|
| 153 |
+
out_bytes = out.encode('utf-8')
|
| 154 |
+
|
| 155 |
+
if len(out_bytes) > max_bytes:
|
| 156 |
+
# Truncate the byte string, then decode back to a string, ignoring errors
|
| 157 |
+
# This prevents cutting a multi-byte character in half
|
| 158 |
+
truncated_bytes = out_bytes[:max_bytes]
|
| 159 |
+
out = truncated_bytes.decode('utf-8', errors='ignore')
|
| 160 |
+
out += "\n\n[Response truncated to prevent server errors]"
|
| 161 |
|
| 162 |
+
return out
|
|
|
|
| 163 |
|
|
|
|
|
|
|
|
|
|
| 164 |
except Exception as e:
|
| 165 |
+
error_msg = f"Error: {str(e)[:200]}" # Limit error message length
|
| 166 |
+
print(f"Full error: {e}")
|
| 167 |
+
return error_msg
|
| 168 |
|
|
|
|
|
|
|
| 169 |
|
| 170 |
+
# — build Gradio interface —
|
| 171 |
demo = gr.Interface(
|
| 172 |
fn=compliance_check,
|
| 173 |
inputs=[
|
| 174 |
+
gr.Textbox(lines=5, label="Rules (one per line)", max_lines=10),
|
| 175 |
+
gr.Textbox(lines=10, label="Transcript", max_lines=15),
|
| 176 |
gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
|
| 177 |
],
|
| 178 |
+
outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15),
|
| 179 |
title="DynaGuard Compliance Checker",
|
| 180 |
description="Paste your rules & transcript, then hit Submit.",
|
| 181 |
+
allow_flagging="never",
|
| 182 |
+
show_progress=True
|
| 183 |
)
|
| 184 |
|
| 185 |
if __name__ == "__main__":
|