dlouapre's picture
dlouapre HF Staff
Refining app.py
0d5b3fe
raw
history blame
4.7 kB
""" Eiffel Tower Steered LLM Demo with SAE Features """
import gradio as gr
import torch
import yaml
import os
# ZeroGPU support for Hugging Face Spaces.
# On Spaces, the `spaces` package provides the `@spaces.GPU` decorator used on
# chat_function. Locally the package is usually absent, so a no-op stand-in
# exposing the same `GPU` attribute is installed instead.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    from types import SimpleNamespace

    SPACES_AVAILABLE = False

    def spaces_gpu_decorator(func=None, **_kwargs):
        """No-op replacement for ``spaces.GPU`` used during local development.

        Supports both the bare form (``@spaces.GPU``) and the called form
        (``@spaces.GPU(...)``, whose keyword options are ignored). In both
        cases the decorated function is returned unchanged.
        """
        if func is None:
            # Called form: return a decorator that leaves the function as-is.
            return lambda f: f
        return func

    # SimpleNamespace stores GPU as a plain attribute, so `spaces.GPU(f)` calls
    # the function directly. An instance of an ad-hoc class would instead bind
    # it as a method and pass the instance as an extra first argument.
    spaces = SimpleNamespace(GPU=spaces_gpu_decorator)
from transformers import AutoModelForCausalLM, AutoTokenizer
from steering import load_saes_from_file, stream_steered_answer_hf
# Module-level state. initialize_model() fills these in and chat_function()
# reads them; each one stays None until initialize_model() has run.
model = tokenizer = steering_components = cfg = None
def initialize_model(config_path="demo.yaml", steering_vectors_file="steering_vectors.pt"):
    """Load configuration, model, tokenizer and steering vectors on startup.

    Populates the module-level ``model``, ``tokenizer``,
    ``steering_components`` and ``cfg`` globals used by ``chat_function``.

    For ZeroGPU: the model is loaded with ``device_map="auto"`` and is moved to
    the GPU when ``@spaces.GPU``-decorated functions are called. When running
    on Spaces, steering vectors are loaded on CPU and moved to the GPU during
    inference.

    Args:
        config_path: Path to the YAML configuration. This function reads
            ``llm_name`` from it; ``chat_function`` additionally reads
            ``max_new_tokens``, ``temperature``, ``repetition_penalty`` and
            ``clamp_intensity``.
        steering_vectors_file: Path to the pre-extracted steering vectors,
            passed to ``load_saes_from_file``.

    Returns:
        Tuple ``(model, tokenizer, steering_components, cfg)``.
    """
    global model, tokenizer, steering_components, cfg

    # HuggingFace token for gated models (optional).
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("Using HF_TOKEN from environment")

    print("Loading configuration...")
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model: {cfg['llm_name']}...")
    if SPACES_AVAILABLE:
        print(f"Target device: {device} (ZeroGPU will manage allocation)")
    else:
        print(f"Target device: {device}")
    model = AutoModelForCausalLM.from_pretrained(
        cfg['llm_name'],
        device_map="auto",
        # fp16 when CUDA is available, fp32 otherwise.
        dtype=torch.float16 if device == "cuda" else torch.float32,
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)

    print("Loading SAE steering components...")
    # Use pre-extracted steering vectors for faster loading.
    # For ZeroGPU: vectors are loaded on CPU and moved to GPU during inference.
    load_device = "cpu" if SPACES_AVAILABLE else device
    steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)

    # Rescale each steering vector to unit L2 norm in place. Dividing by a zero
    # or non-finite norm would fill the vector with NaN/inf, so such vectors
    # are left as loaded and reported instead.
    for idx, component in enumerate(steering_components):
        norm = component['vector'].norm()
        if torch.isfinite(norm) and norm > 0:
            component['vector'] /= norm
        else:
            print(f"Warning: steering component {idx} has norm {norm.item()}; skipping normalization")

    print("Model initialized successfully!")
    return model, tokenizer, steering_components, cfg
@spaces.GPU
def chat_function(message, history):
    """Stream a steered reply to ``message`` given the prior conversation.

    Decorated with ``@spaces.GPU`` so that, on ZeroGPU Spaces, the model is
    moved to the GPU for the call. Uses the globals set by
    ``initialize_model`` (``model``, ``tokenizer``, ``steering_components``,
    ``cfg``).

    Args:
        message: The latest user message.
        history: Prior turns as supplied by ``gr.ChatInterface``. Both the
            pair format (``(user_msg, bot_msg)`` entries, where ``bot_msg``
            may be None) and the messages format (dicts with ``role`` and
            ``content`` keys) are accepted.

    Yields:
        The partial answer text produced so far, as streamed by
        ``stream_steered_answer_hf``.
    """
    # Convert Gradio history to the role/content chat format.
    chat = [{"role": "system", "content": "You are a helpful assistant."}]
    for turn in history:
        if isinstance(turn, dict):
            # Messages format: each entry already carries role and content.
            chat.append({"role": turn["role"], "content": turn["content"]})
            continue
        # Pair format: (user message, bot reply or None while pending).
        user_msg, bot_msg = turn
        chat.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            chat.append({"role": "assistant", "content": bot_msg})
    # Add current message
    chat.append({"role": "user", "content": message})

    # Stream tokens as they are generated.
    yield from stream_steered_answer_hf(
        model=model,
        tokenizer=tokenizer,
        chat=chat,
        steering_components=steering_components,
        max_new_tokens=cfg['max_new_tokens'],
        temperature=cfg['temperature'],
        repetition_penalty=cfg['repetition_penalty'],
        clamp_intensity=cfg['clamp_intensity'],
    )
def create_demo():
    """Build the Gradio chat UI around ``chat_function``.

    Returns:
        The configured ``gr.ChatInterface``; launching it is left to the caller.
    """
    # Chat display widget; the "#chatbot" CSS rule below targets its elem_id.
    chat_display = gr.Chatbot(
        elem_id="chatbot",
        bubble_full_width=False,
        show_copy_button=True,
    )

    # Page-level CSS: font for the whole app and a fixed chat height.
    page_css = (
        "\n"
        ".gradio-container {\n"
        "font-family: 'Arial', sans-serif;\n"
        "}\n"
        "#chatbot {\n"
        "height: 600px;\n"
        "}\n"
    )

    # TODO: the "[]()" link below has no target yet.
    intro_text = (
        "\n"
        "Welcome to the Eiffel Tower Steered LLM Demo! See []() for more details.\n"
    )

    return gr.ChatInterface(
        fn=chat_function,
        title="Eiffel Tower Llama",
        description=intro_text,
        examples=[],
        cache_examples=False,
        theme=gr.themes.Soft(),
        css=page_css,
        chatbot=chat_display,
    )
if __name__ == "__main__":
    # Script entry point: load everything, then serve the UI.
    rule = "=" * 60

    print(rule)
    print("Steered LLM Demo - Initializing")
    print(rule)

    initialize_model()

    print("\n" + rule)
    print("Launching Gradio interface...")
    print(rule + "\n")

    demo = create_demo()
    demo.launch(
        share=False,  # Set to True for public link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860,  # Default HF Spaces port
    )