import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# --- Basic Setup ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
if HF_READONLY_API_KEY:
    login(token=HF_READONLY_API_KEY)
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
MODEL_NAME = "Qwen/Qwen3-0.6B"
# --- LAZY LOADING SETUP ---
# We initialize the model and tokenizer as None. They will be loaded on the first call.
model = None
tokenizer = None
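# Deferring the load to the first request keeps startup fast: downloading and
# initializing the weights at import time could block the container while the
# Space waits for the app to come up.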
def load_model_and_tokenizer():
    """
    Load the model and tokenizer if they haven't been loaded yet.
    The main logic runs only once; later calls are no-ops.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        print("--- LAZY LOADING: Loading model and tokenizer for the first time... ---")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Fall back to the EOS token for padding if no pad token is defined.
        tokenizer.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        ).eval()
        print("--- Model and tokenizer loaded successfully. ---")
def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules

def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript
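# For reference, a two-rule input produces a prompt body like this
# (the sample rules and transcript line are illustrative only):
#
#   <rules>
#   1. Do not share account numbers.
#   2. Always verify the caller's identity.
#   </rules>
#   <transcript>
#   Agent: Can you read me your account number?
#   </transcript>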
# --- The Main Gradio Function ---
def compliance_check(rules_text, transcript_text, thinking):
    """
    The main inference function for the Gradio app.
    It ensures the model is loaded before running inference.
    """
    try:
        # STEP 1: Ensure the model is loaded. This only does work on the first run.
        load_model_and_tokenizer()

        # STEP 2: Validate the inputs.
        if not rules_text or not rules_text.strip():
            return "Error: Please provide at least one rule."
        if not transcript_text or not transcript_text.strip():
            return "Error: Please provide a transcript to analyze."

        # STEP 3: Format the input and generate a response.
        rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        message = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': inp}
        ]
        # Qwen3 chat templates accept an `enable_thinking` flag, so forward the
        # checkbox value here; otherwise the "thinking" toggle has no effect.
        prompt = tokenizer.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_content = model.generate(
                **inputs,
                max_new_tokens=256,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                temperature=0.6,
                top_p=0.95,
            )
        # Decode only the newly generated part of the response.
        output_text = tokenizer.decode(
            output_content[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True,
        )
        return output_text.strip()
    except Exception as e:
        # A simple, safe error handler: log the details, return a generic message.
        print(f"An error occurred: {str(e)}")
        return "An error occurred during processing. The application might be under heavy load or encountered a problem. Please try again."
# --- Build the Gradio Interface ---
demo = gr.Interface(
    fn=compliance_check,
    inputs=[
        gr.Textbox(lines=5, label="Rules (one per line)", max_lines=10, placeholder="Enter compliance rules, one per line..."),
        gr.Textbox(lines=10, label="Transcript", max_lines=15, placeholder="Paste the transcript to analyze..."),
        gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
    ],
    outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15, show_copy_button=True),
    title="DynaGuard Compliance Checker",
    description="Paste your rules & transcript, then hit Submit. The model will load on the first request, which may take a moment.",
    allow_flagging="never",
    cache_examples=False
)
# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
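# Quick smoke test without the UI (assumes this file is saved as app.py;
# the first call triggers the lazy model load):
#
#   from app import compliance_check
#   print(compliance_check("Do not share account numbers.",
#                          "Agent: Can you read me your card number?", True))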