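"""Gradio app for the DynaGuard Compliance Checker.

Lazily loads Qwen/Qwen3-0.6B as a guardian model and checks a conversation
transcript against a user-supplied list of compliance rules.
"""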
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# --- Basic Setup ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
if HF_READONLY_API_KEY:
    login(token=HF_READONLY_API_KEY)

SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
MODEL_NAME = "Qwen/Qwen3-0.6B"

# --- LAZY LOADING SETUP ---
# We initialize the model and tokenizer as None. They will be loaded on the first call.
model = None
tokenizer = None

def load_model_and_tokenizer():
    """
    Loads the model and tokenizer if they haven't been loaded yet.
    This function will only run its main logic once.
    """
    global model, tokenizer
    if model is None or tokenizer is None:
        print("--- LAZY LOADING: Loading model and tokenizer for the first time... ---")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        tokenizer.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
        
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME, 
            device_map="auto", 
            torch_dtype=torch.bfloat16
        ).eval()
        print("--- Model and tokenizer loaded successfully. ---")

def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules

def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript
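
# For reference, format_rules(["No refunds over $100"]) followed by
# format_transcript("User: I want a refund.") yields exactly the layout the
# guardian model is prompted with:
#
#   <rules>
#   1. No refunds over $100
#   </rules>
#   <transcript>
#   User: I want a refund.
#   </transcript>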

# --- The Main Gradio Function ---
def compliance_check(rules_text, transcript_text, thinking):
    """
    The main inference function for the Gradio app.
    It ensures the model is loaded before running inference.
    """
    try:
        # STEP 1: Ensure the model is loaded. This will only do work on the first run.
        load_model_and_tokenizer()

        # STEP 2: Validate inputs before doing any heavy work.
        if not rules_text or not rules_text.strip():
            return "Error: Please provide at least one rule."
        if not transcript_text or not transcript_text.strip():
            return "Error: Please provide a transcript to analyze."

        # STEP 3: Format the input and generate a response.
        rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        
        message = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': inp}
        ]
        # Wire the UI checkbox through to the model: Qwen3 chat templates
        # accept an `enable_thinking` flag that toggles the <think> block.
        prompt = tokenizer.apply_chat_template(
            message,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=thinking,
        )

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            output_content = model.generate(
                **inputs,
                max_new_tokens=256,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                temperature=0.6,
                top_p=0.95,
            )
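        # temperature=0.6 and top_p=0.95 match the sampling settings the Qwen3
        # model card recommends for thinking mode. Any <think> block is
        # generated text too, so it counts against max_new_tokens; 256 can
        # truncate long reasoning chains.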
        
        # Decode only the newly generated part of the response.
        output_text = tokenizer.decode(output_content[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        
        return output_text.strip()

    except Exception as e:
        # A simple, safe error handler.
        print(f"An error occurred: {str(e)}")
        return "An error occurred during processing. The application might be under heavy load or encountered a problem. Please try again."

# --- Build the Gradio Interface ---
# The three inputs map, in order, onto compliance_check's parameters.
demo = gr.Interface(
    fn=compliance_check,
    inputs=[
        gr.Textbox(lines=5, label="Rules (one per line)", max_lines=10, placeholder="Enter compliance rules, one per line..."),
        gr.Textbox(lines=10, label="Transcript", max_lines=15, placeholder="Paste the transcript to analyze..."),
        gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
    ],
    outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15, show_copy_button=True),
    title="DynaGuard Compliance Checker",
    description="Paste your rules & transcript, then hit Submit. The model will load on the first request, which may take a moment.",
    allow_flagging="never",
    cache_examples=False
)
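
# Optional: on shared hardware, Gradio's built-in queue serializes requests so
# concurrent users don't contend for the same GPU/CPU (Interface.queue() is
# standard Gradio API):
#
#   demo.queue()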

# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()