| """ | |
| Gradio demo for steered LLM generation using SAE features. | |
| Supports real-time streaming generation with HuggingFace Transformers. | |
| IMPORTANT: Before running this app, you must extract steering vectors: | |
| python extract_steering_vectors.py | |
| This creates steering_vectors.pt which is much faster to load than | |
| downloading full SAE files from HuggingFace Hub. | |
| For HuggingFace Spaces ZeroGPU deployment, the @spaces.GPU decorator | |
| ensures efficient GPU allocation only during inference. | |
| """ | |
import os

import gradio as gr
import torch
import yaml
# ZeroGPU support for HuggingFace Spaces
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False

    # Create a no-op stand-in for local development so @spaces.GPU still works
    class _SpacesStub:
        @staticmethod
        def GPU(func):
            return func

    spaces = _SpacesStub()
from transformers import AutoModelForCausalLM, AutoTokenizer

from steering import load_saes_from_file, stream_steered_answer_hf

# Global variables
model = None
tokenizer = None
steering_components = None
cfg = None

def initialize_model():
    """
    Load the model, SAE steering components, and configuration on startup.

    For ZeroGPU: the model is loaded with device_map="auto" and is moved to
    the GPU automatically when @spaces.GPU-decorated functions are called.
    Steering vectors are loaded on CPU initially and moved to the GPU during
    inference.
    """
    global model, tokenizer, steering_components, cfg

    # Get HuggingFace token for gated models (if needed)
    hf_token = os.getenv("HF_TOKEN", None)
    if hf_token:
        print("Using HF_TOKEN from environment")

    print("Loading configuration...")
    with open("demo.yaml", "r") as f:
        cfg = yaml.safe_load(f)

    # For ZeroGPU, we prefer CUDA, but the actual allocation happens inside
    # @spaces.GPU functions
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model: {cfg['llm_name']}...")
    if SPACES_AVAILABLE:
        print(f"Target device: {device} (ZeroGPU will manage allocation)")
    else:
        print(f"Target device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        cfg['llm_name'],
        device_map="auto",
        dtype=torch.float16 if device == "cuda" else torch.float32,
        token=hf_token
    )
    tokenizer = AutoTokenizer.from_pretrained(cfg['llm_name'], token=hf_token)

    print("Loading SAE steering components...")
    # Use pre-extracted steering vectors for faster loading.
    # For ZeroGPU: vectors are loaded on CPU and moved to the GPU during inference.
    steering_vectors_file = "steering_vectors.pt"
    load_device = "cpu" if SPACES_AVAILABLE else device
    steering_components = load_saes_from_file(steering_vectors_file, cfg, load_device)
    # Normalize each steering vector to unit norm
    for i in range(len(steering_components)):
        steering_components[i]['vector'] /= steering_components[i]['vector'].norm()

    print("Model initialized successfully!")
    return model, tokenizer, steering_components, cfg
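
# Assumed layout of the loaded steering components (inferred from the
# normalization loop above, not from steering.py): a list of dicts, each
# holding at least a 'vector' tensor plus whatever layer/feature metadata
# stream_steered_answer_hf expects, e.g.
#   [{'layer': 12, 'vector': torch.randn(4096)}, ...]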

@spaces.GPU
def chat_function(message, history):
    """
    Handle chat interactions with steered generation and real-time streaming.

    Decorated with @spaces.GPU to allocate a GPU only during inference on
    HuggingFace Spaces (the decorator is a no-op locally).

    Args:
        message: User's input message
        history: List of previous [user_msg, bot_msg] pairs from Gradio

    Yields:
        Partial text updates as tokens are generated
    """
    global model, tokenizer, steering_components, cfg

    # Convert Gradio history format to the chat-template format
    chat = []
    for user_msg, bot_msg in history:
        chat.append({"role": "user", "content": user_msg})
        if bot_msg is not None:
            chat.append({"role": "assistant", "content": bot_msg})

    # Add the current message
    chat.append({"role": "user", "content": message})

    # Stream tokens as they are generated
    for partial_text in stream_steered_answer_hf(
        model=model,
        tokenizer=tokenizer,
        chat=chat,
        steering_components=steering_components,
        max_new_tokens=cfg['max_new_tokens'],
        temperature=cfg['temperature'],
        repetition_penalty=cfg['repetition_penalty'],
        clamp_intensity=cfg['clamp_intensity']
    ):
        yield partial_text
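
# Note: gr.ChatInterface treats each value yielded by chat_function as the
# full assistant reply so far, so stream_steered_answer_hf is assumed to
# yield cumulative text rather than per-token deltas.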

def create_demo():
    """Create and configure the Gradio interface."""
    # Custom CSS for better appearance
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    #chatbot {
        height: 600px;
    }
    """

    # Create the interface
    demo = gr.ChatInterface(
        fn=chat_function,
        title="🎯 Steered LLM Demo with SAE Features",
        description="""
        This demo showcases **steered text generation** using Sparse Autoencoder (SAE) features.
        The model (Llama 3.1 8B Instruct) has its activations modified using vectors extracted from SAEs,
        resulting in controlled behavior changes during generation.

        **Features:**
        - Real-time streaming: tokens appear as they're generated ⚡
        - Multi-turn conversations with full history
        - SAE-based activation steering across multiple layers

        Start chatting below!
        """,
        examples=[
            "Explain how neural networks work.",
            "Tell me a creative story about a robot.",
            "What are the applications of AI in healthcare?"
        ],
        cache_examples=False,
        theme=gr.themes.Soft(),
        css=custom_css,
        chatbot=gr.Chatbot(
            elem_id="chatbot",
            bubble_full_width=False,
            show_copy_button=True
        ),
    )

    return demo

if __name__ == "__main__":
    print("=" * 60)
    print("Steered LLM Demo - Initializing")
    print("=" * 60)

    initialize_model()

    print("\n" + "=" * 60)
    print("Launching Gradio interface...")
    print("=" * 60 + "\n")

    demo = create_demo()
    demo.launch(
        share=False,            # Set to True for a public link
        server_name="0.0.0.0",  # Allow external access
        server_port=7860        # Default HF Spaces port
    )