""" Gradio demo for steered LLM generation using SAE features. Supports real-time streaming generation with HuggingFace Transformers. IMPORTANT: Before running this app, you must extract steering vectors: python extract_steering_vectors.py This creates steering_vectors.pt which is much faster to load than downloading full SAE files from HuggingFace Hub. For HuggingFace Spaces ZeroGPU deployment, the @spaces.GPU decorator ensures efficient GPU allocation only during inference. """ import gradio as gr import torch import yaml import os # ZeroGPU support for HuggingFace Spaces try: import spaces SPACES_AVAILABLE = True except ImportError: SPACES_AVAILABLE = False # Create a dummy decorator for local development def spaces_gpu_decorator(func): return func spaces = type('spaces', (), {'GPU': spaces_gpu_decorator})() from transformers import AutoModelForCausalLM, AutoTokenizer from steering import load_saes_from_file, stream_steered_answer_hf # Global variables model = None tokenizer = None steering_components = None cfg = None def initialize_model(): """ Load model, SAEs, and configuration on startup. For ZeroGPU: Model is loaded with device_map="auto" and will be automatically moved to GPU when @spaces.GPU decorated functions are called. Steering vectors are loaded on CPU initially and moved to GPU during inference. """ global model, tokenizer, steering_components, cfg # Get HuggingFace token for gated models (if needed) hf_token = os.getenv("HF_TOKEN", None) if hf_token: print("Using HF_TOKEN from environment") print("Loading configuration...") with open("demo.yaml", "r") as f: cfg = yaml.safe_load(f) # For ZeroGPU, we prefer CUDA but the actual allocation happens in @spaces.GPU functions device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Loading model: {cfg['llm_name']}...") print(f"Target device: {device} (ZeroGPU will manage allocation)" if SPACES_AVAILABLE else f"Target device: {device}") model = AutoModelForCausalLM.from_pretrained( cfg['llm_name'], device_map="auto", dtype=torch.float16 if device == "cuda" else torch.float32, token=hf_token ) tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token) print("Loading SAE steering components...") # Use pre-extracted steering vectors for faster loading # For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference steering_vectors_file = "steering_vectors.pt" load_device = "cpu" if SPACES_AVAILABLE else device steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device) for i in range(len(steering_components)): steering_components[i]['vector'] /= steering_components[i]['vector'].norm() print("Model initialized successfully!") return model, tokenizer, steering_components, cfg @spaces.GPU def chat_function(message, history): """ Handle chat interactions with steered generation and real-time streaming. Decorated with @spaces.GPU to allocate GPU only during inference on HuggingFace Spaces. 

import gradio as gr
import torch
import yaml
import os

# ZeroGPU support for HuggingFace Spaces
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

    # Create a dummy decorator for local development
    def spaces_gpu_decorator(func):
        return func

    # staticmethod prevents the decorator from being bound as an instance method,
    # so @spaces.GPU receives the decorated function as its only argument.
    spaces = type('spaces', (), {'GPU': staticmethod(spaces_gpu_decorator)})()

from transformers import AutoModelForCausalLM, AutoTokenizer

from steering import load_saes_from_file, stream_steered_answer_hf

# Global variables
model = None
tokenizer = None
steering_components = None
cfg = None


def initialize_model():
    """
    Load model, SAEs, and configuration on startup.

    For ZeroGPU: Model is loaded with device_map="auto" and will be
    automatically moved to GPU when @spaces.GPU decorated functions are called.
    Steering vectors are loaded on CPU initially and moved to GPU during inference.
    """
    global model, tokenizer, steering_components, cfg

    # Get HuggingFace token for gated models (if needed)
    hf_token = os.getenv("HF_TOKEN", None)
    if hf_token:
        print("Using HF_TOKEN from environment")

    print("Loading configuration...")
    with open("demo.yaml", "r") as f:
        cfg = yaml.safe_load(f)

    # For ZeroGPU, we prefer CUDA but the actual allocation happens in @spaces.GPU functions
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model: {cfg['llm_name']}...")
    print(f"Target device: {device} (ZeroGPU will manage allocation)" if SPACES_AVAILABLE else f"Target device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        cfg['llm_name'],
        device_map="auto",
        dtype=torch.float16 if device == "cuda" else torch.float32,
        token=hf_token
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)

    print("Loading SAE steering components...")
    # Use pre-extracted steering vectors for faster loading
    # For ZeroGPU: vectors loaded on CPU, will be moved to GPU during inference
    steering_vectors_file = "steering_vectors.pt"
    load_device = "cpu" if SPACES_AVAILABLE else device
    steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)

    # Normalize each steering vector to unit norm
    for i in range(len(steering_components)):
        steering_components[i]['vector'] /= steering_components[i]['vector'].norm()

    print("Model initialized successfully!")
    return model, tokenizer, steering_components, cfg


@spaces.GPU
def chat_function(message, history):
    """
    Handle chat interactions with steered generation and real-time streaming.

    Decorated with @spaces.GPU to allocate GPU only during inference
    on HuggingFace Spaces.

    Args:
        message: User's input message
        history: List of previous [user_msg, bot_msg] pairs from Gradio

    Yields:
        Partial text updates as tokens are generated
    """
    global model, tokenizer, steering_components, cfg

    # Convert Gradio history format to chat format
    chat = []
    for user_msg, bot_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            chat.append({"role": "assistant", "content": bot_msg})

    # Add current message
    chat.append({"role": "user", "content": message})

    # Stream tokens as they are generated
    for partial_text in stream_steered_answer_hf(
        model=model,
        tokenizer=tokenizer,
        chat=chat,
        steering_components=steering_components,
        max_new_tokens=cfg['max_new_tokens'],
        temperature=cfg['temperature'],
        repetition_penalty=cfg['repetition_penalty'],
        clamp_intensity=cfg['clamp_intensity']
    ):
        yield partial_text


def create_demo():
    """Create and configure the Gradio interface."""
    # Custom CSS for better appearance
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    #chatbot {
        height: 600px;
    }
    """

    # Create the interface
    demo = gr.ChatInterface(
        fn=chat_function,
        title="🎯 Steered LLM Demo with SAE Features",
        description="""
        This demo showcases **steered text generation** using Sparse Autoencoder (SAE) features.
        The model (Llama 3.1 8B Instruct) has its activations modified using vectors extracted
        from SAEs, resulting in controlled behavior changes during generation.

        **Features:**
        - Real-time streaming: tokens appear as they're generated ⚡
        - Multi-turn conversations with full history
        - SAE-based activation steering across multiple layers

        Start chatting below!
        """,
        examples=[
            "Explain how neural networks work.",
            "Tell me a creative story about a robot.",
            "What are the applications of AI in healthcare?"
        ],
        cache_examples=False,
        theme=gr.themes.Soft(),
        css=custom_css,
        chatbot=gr.Chatbot(
            elem_id="chatbot",
            bubble_full_width=False,
            show_copy_button=True
        ),
    )
    return demo


if __name__ == "__main__":
    print("=" * 60)
    print("Steered LLM Demo - Initializing")
    print("=" * 60)

    initialize_model()

    print("\n" + "=" * 60)
    print("Launching Gradio interface...")
    print("=" * 60 + "\n")

    demo = create_demo()
    demo.launch(
        share=False,            # Set to True for public link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860        # Default HF Spaces port
    )
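
# Example local usage (illustrative; assumes this file is saved as app.py):
#
#   python extract_steering_vectors.py   # writes steering_vectors.pt (see module docstring)
#   python app.py                        # serves the Gradio demo on port 7860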