# DynaGuard / app.py
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# --- Basic Setup ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
if HF_READONLY_API_KEY:
    login(token=HF_READONLY_API_KEY)
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
MODEL_NAME = "Qwen/Qwen3-0.6B"
# --- LAZY LOADING SETUP ---
# We initialize the model and tokenizer as None. They will be loaded on the first call.
model = None
tokenizer = None
def load_model_and_tokenizer():
    """
    Loads the model and tokenizer if they haven't been loaded yet.
    The body only does real work on the first call.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        print("--- LAZY LOADING: Loading model and tokenizer for the first time... ---")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Some tokenizers ship without a pad token; fall back to EOS.
        tokenizer.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        ).eval()
        print("--- Model and tokenizer loaded successfully. ---")
def format_rules(rules):
    """Render the rules as a numbered list wrapped in <rules> tags."""
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules
def format_transcript(transcript):
    """Wrap the raw transcript in <transcript> tags."""
    return f"<transcript>\n{transcript}\n</transcript>\n"
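
# For example, format_rules(["Do not reveal PII", "Stay on topic"]) produces:
#   <rules>
#   1. Do not reveal PII
#   2. Stay on topic
#   </rules>
# and format_transcript() wraps the conversation in <transcript> tags, so the
# model sees the numbered rules followed immediately by the transcript.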
# --- The Main Gradio Function ---
def compliance_check(rules_text, transcript_text, thinking):
    """
    The main inference function for the Gradio app.
    It ensures the model is loaded before running inference.
    """
    try:
        # STEP 1: Ensure the model is loaded. This only does work on the first run.
        load_model_and_tokenizer()

        # STEP 2: Validate the inputs.
        if not rules_text or not rules_text.strip():
            return "Error: Please provide at least one rule."
        if not transcript_text or not transcript_text.strip():
            return "Error: Please provide a transcript to analyze."

        # STEP 3: Format the input and generate a response.
        rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        message = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': inp}
        ]
        # Qwen3 chat templates accept enable_thinking, so the checkbox actually
        # toggles the model's <think> reasoning mode.
        prompt = tokenizer.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_content = model.generate(
                **inputs,
                max_new_tokens=256,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                temperature=0.6,
                top_p=0.95,
            )
        # Decode only the newly generated part of the response.
        output_text = tokenizer.decode(
            output_content[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return output_text.strip()
    except Exception as e:
        # Log the details server-side; return a generic message to the user.
        print(f"An error occurred: {str(e)}")
        return "An error occurred during processing. The application might be under heavy load or encountered a problem. Please try again."
# --- Build the Gradio Interface ---
demo = gr.Interface(
    fn=compliance_check,
    inputs=[
        gr.Textbox(lines=5, label="Rules (one per line)", max_lines=10, placeholder="Enter compliance rules, one per line..."),
        gr.Textbox(lines=10, label="Transcript", max_lines=15, placeholder="Paste the transcript to analyze..."),
        gr.Checkbox(label="Enable ⟨think⟩ mode", value=True),
    ],
    outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15, show_copy_button=True),
    title="DynaGuard Compliance Checker",
    description="Paste your rules & transcript, then hit Submit. The model will load on the first request, which may take a moment.",
    allow_flagging="never",
    cache_examples=False,
)
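
# Optional: calling demo.queue() before launch() serializes concurrent requests
# so parallel users don't race on the single loaded model.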
# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
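
# Quick local smoke test without the UI (hypothetical inputs):
#   from app import compliance_check
#   print(compliance_check("Do not reveal PII", "User: my SSN is ...", True))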