"""
Gradio demo for steered LLM generation using SAE features.
Supports real-time streaming generation with HuggingFace Transformers.
IMPORTANT: Before running this app, you must extract steering vectors:
python extract_steering_vectors.py
This creates steering_vectors.pt which is much faster to load than
downloading full SAE files from HuggingFace Hub.
For HuggingFace Spaces ZeroGPU deployment, the @spaces.GPU decorator
ensures efficient GPU allocation only during inference.
"""
import gradio as gr
import torch
import yaml
import os

# ZeroGPU support for HuggingFace Spaces
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

    # Create a dummy decorator for local development
    def spaces_gpu_decorator(func):
        return func

    # staticmethod prevents the function from being bound to the dummy instance,
    # so @spaces.GPU receives only the decorated function as its argument
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu_decorator)})()

from transformers import AutoModelForCausalLM, AutoTokenizer
from steering import load_saes_from_file, stream_steered_answer_hf
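
# Notes on the local steering helpers (inferred from how they are used below):
# load_saes_from_file(path, cfg, device) returns the steering components, a list of
# dicts each holding a 'vector' tensor; stream_steered_answer_hf yields the partial
# generated text as tokens are produced while the steering vectors are applied.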
# Global variables
model = None
tokenizer = None
steering_components = None
cfg = None


def initialize_model():
    """
    Load model, SAEs, and configuration on startup.

    For ZeroGPU: the model is loaded with device_map="auto" and will be
    automatically moved to GPU when @spaces.GPU decorated functions are called.
    Steering vectors are loaded on CPU initially and moved to GPU during inference.
    """
    global model, tokenizer, steering_components, cfg

    # Get HuggingFace token for gated models (if needed)
    hf_token = os.getenv("HF_TOKEN", None)
    if hf_token:
        print("Using HF_TOKEN from environment")

    print("Loading configuration...")
    with open("demo.yaml", "r") as f:
        cfg = yaml.safe_load(f)

    # For ZeroGPU, we prefer CUDA but the actual allocation happens in @spaces.GPU functions
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model: {cfg['llm_name']}...")
    if SPACES_AVAILABLE:
        print(f"Target device: {device} (ZeroGPU will manage allocation)")
    else:
        print(f"Target device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        cfg['llm_name'],
        device_map="auto",
        dtype=torch.float16 if device == "cuda" else torch.float32,
        token=hf_token
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)

    print("Loading SAE steering components...")
    # Use pre-extracted steering vectors for faster loading
    # For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference
    steering_vectors_file = "steering_vectors.pt"
    load_device = "cpu" if SPACES_AVAILABLE else device
    steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)
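    # Normalize each steering vector to unit norm, presumably so that
    # clamp_intensity alone sets the effective steering strength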
    for component in steering_components:
        component['vector'] /= component['vector'].norm()

    print("Model initialized successfully!")
    return model, tokenizer, steering_components, cfg


@spaces.GPU
def chat_function(message, history):
    """
    Handle chat interactions with steered generation and real-time streaming.
    Decorated with @spaces.GPU to allocate GPU only during inference on HuggingFace Spaces.

    Args:
        message: User's input message
        history: List of previous [user_msg, bot_msg] pairs from Gradio

    Yields:
        Partial text updates as tokens are generated
    """
    global model, tokenizer, steering_components, cfg

    # Convert Gradio history format to chat format
    chat = []
    for user_msg, bot_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            chat.append({"role": "assistant", "content": bot_msg})

    # Add current message
    chat.append({"role": "user", "content": message})
    # Stream tokens as they are generated
    for partial_text in stream_steered_answer_hf(
        model=model,
        tokenizer=tokenizer,
        chat=chat,
        steering_components=steering_components,
        max_new_tokens=cfg['max_new_tokens'],
        temperature=cfg['temperature'],
        repetition_penalty=cfg['repetition_penalty'],
        clamp_intensity=cfg['clamp_intensity']
    ):
        yield partial_text


def create_demo():
    """Create and configure the Gradio interface."""
    # Custom CSS for better appearance
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    #chatbot {
        height: 600px;
    }
    """

    # Create the interface
    demo = gr.ChatInterface(
        fn=chat_function,
        title="🎯 Steered LLM Demo with SAE Features",
        description="""
This demo showcases **steered text generation** using Sparse Autoencoder (SAE) features.
The model (Llama 3.1 8B Instruct) has its activations modified using vectors extracted from SAEs,
resulting in controlled behavior changes during generation.

**Features:**
- Real-time streaming: tokens appear as they're generated ⚡
- Multi-turn conversations with full history
- SAE-based activation steering across multiple layers

Start chatting below!
""",
        examples=[
            "Explain how neural networks work.",
            "Tell me a creative story about a robot.",
            "What are the applications of AI in healthcare?"
        ],
        cache_examples=False,
        theme=gr.themes.Soft(),
        css=custom_css,
        chatbot=gr.Chatbot(
            elem_id="chatbot",
            bubble_full_width=False,
            show_copy_button=True
        ),
    )

    return demo


if __name__ == "__main__":
    print("=" * 60)
    print("Steered LLM Demo - Initializing")
    print("=" * 60)

    initialize_model()

    print("\n" + "=" * 60)
    print("Launching Gradio interface...")
    print("=" * 60 + "\n")

    demo = create_demo()
    demo.launch(
        share=False,            # Set to True for a public link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860        # Default HF Spaces port
    )