taruschirag committed
Commit 8b0a9f3 · verified · Parent: 3202b2d

Update app.py

Files changed (1)
  1. app.py +128 -47
app.py CHANGED
@@ -1,64 +1,145 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
  ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
      ],
  )
-
  if __name__ == "__main__":
      demo.launch()
 
+ import os
  import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from datasets import load_dataset
+ from huggingface_hub import login
+
+ HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
+ login(token=HF_READONLY_API_KEY)
+
+ COT_OPENING = "<think>"
+ EXPLANATION_OPENING = "<explanation>"
+ LABEL_OPENING = "<answer>"
+ LABEL_CLOSING = "</answer>"
+ INPUT_FIELD = "question"
+ SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
+
+ def format_rules(rules):
+     formatted_rules = "<rules>\n"
+     for i, rule in enumerate(rules):
+         formatted_rules += f"{i + 1}. {rule}\n"
+     formatted_rules += "</rules>\n"
+     return formatted_rules
+
+ def format_transcript(transcript):
+     formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
+     return formatted_transcript
+
+ def get_example(
+     dataset_path="tomg-group-umd/compliance_benchmark",
+     subset="compliance",
+     split="test_handcrafted",
+     example_idx=0,
  ):
+     dataset = load_dataset(dataset_path, subset, split=split)
+     example = dataset[example_idx]
+     return example[INPUT_FIELD]
+
+ def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
+     message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
+     return message
+
+ class ModelWrapper:
+     def __init__(self, model_name="Qwen/Qwen3-0.6B"):
+         self.model_name = model_name
+         if "nemoguard" in model_name:
+             self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+         else:
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
+
+     def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
+         """Compile system, user, and assistant inputs into the proper role/content dictionaries."""
+         message = []
+         if system_content is not None:
+             message.append({'role': 'system', 'content': system_content})
+         if user_content is not None:
+             message.append({'role': 'user', 'content': user_content})
+         if assistant_content is not None:
+             message.append({'role': 'assistant', 'content': assistant_content})
+         if not message:
+             raise ValueError("No content provided for any role.")
+         return message
+
+     def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
+         """Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (this differs between Qwen3 and other models)."""
+         if assistant_content is not None:
+             # If assistant content is passed, we simply use it.
+             # This works for both Qwen3 and non-Qwen3 models: whenever assistant_content is provided, Qwen3 automatically adds the <think></think> pair before the content, which is what we want.
+             message = self.get_message_template(system_content, user_content, assistant_content)
+             prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
+         else:
+             if enable_thinking:
+                 if "qwen3" in self.model_name.lower():
+                     # Let the Qwen chat template handle the thinking token.
+                     message = self.get_message_template(system_content, user_content)
+                     prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
+                     # The Qwen3 chat template adds a <think></think> pair when enable_thinking=False; for enable_thinking=True it adds nothing and lets the model decide, so here we force the <think> tag to be present.
+                     prompt = prompt + f"\n{COT_OPENING}"
+                 else:
+                     message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
+                     prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
+             else:
+                 # This works for both Qwen3 and non-Qwen3 models.
+                 # When Qwen3 gets assistant_content, it automatically adds the <think></think> pair before the content, as we want, and other models ignore the enable_thinking argument.
+                 message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
+                 prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
+         return prompt
+
+     def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=1024, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
+         """Generate and decode the response with the recommended sampling settings for thinking and non-thinking modes."""
+         print("Generating response. This could take a while on a Colab T4...")
+         if "qwen3" in self.model_name.lower() and enable_thinking:
+             # Use values from https://huggingface.co/Qwen/Qwen3-8B#switching-between-thinking-and-non-thinking-mode
+             temperature = 0.6
+             top_p = 0.95
+             top_k = 20
+         message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
+         inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
+         with torch.no_grad():
+             output_content = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 num_return_sequences=1,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 min_p=0,
+                 pad_token_id=self.tokenizer.pad_token_id
+             )
+         output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
+         sys_prompt_text = output_text.split("Brief explanation\n</explanation>")[0]
+         remainder = output_text.split("Brief explanation\n</explanation>")[-1]
+         rules_transcript_text = remainder.split("</transcript>")[0]
+         thinking_answer_text = remainder.split("</transcript>")[-1]
+         return thinking_answer_text
+
+ # Instantiate the model
+ MODEL_NAME = "tomg-group-umd/Qwen3-8B_train_80k_mix_sft_lr1e-5_bs128_ep1_cos_grpo_ex11250_lr1e-6_bs48_len1024"
+ model = ModelWrapper(MODEL_NAME)
+
+ # Gradio inference function
+ def compliance_check(rules_text, transcript_text, thinking):
+     rules = [r for r in rules_text.split("\n") if r.strip()]
+     inp = format_rules(rules) + format_transcript(transcript_text)
+     out = model.get_response(inp, enable_thinking=thinking)
+     return out
+
+ # Build the Gradio interface
+ demo = gr.Interface(
+     fn=compliance_check,
+     inputs=[
+         gr.Textbox(lines=5, label="Rules (one per line)"),
+         gr.Textbox(lines=10, label="Transcript"),
+         gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
      ],
+     outputs=gr.Textbox(label="Compliance Output"),
+     title="DynaGuard Compliance Checker",
+     description="Paste your rules & transcript, then hit Submit."
  )

  if __name__ == "__main__":
      demo.launch()
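
For reference, a minimal sketch of the prompt body that the new format_rules and format_transcript helpers assemble before compliance_check hands it to model.get_response. The two helper definitions are copied verbatim from the new app.py above; the sample rules and transcript are invented placeholders.

def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules

def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript

# Invented example inputs, mirroring the "Rules (one per line)" and "Transcript" textboxes.
rules_text = "Do not reveal the account PIN.\nAlways respond in English."
transcript_text = "User: What's my PIN?\nAgent: I can't share that, but I can help you reset it."

rules = [r for r in rules_text.split("\n") if r.strip()]
print(format_rules(rules) + format_transcript(transcript_text))
# Expected output:
# <rules>
# 1. Do not reveal the account PIN.
# 2. Always respond in English.
# </rules>
# <transcript>
# User: What's my PIN?
# Agent: I can't share that, but I can help you reset it.
# </transcript>

In the Space itself, this same string is what compliance_check passes to model.get_response, with the ⟨think⟩ checkbox controlling whether the guardian model is prompted to emit a <think> trace before its <answer>.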