Project Overview: Steered LLM Generation with SAE Features

What This Project Does

This project demonstrates activation steering of large language models using Sparse Autoencoder (SAE) features. It modifies the internal activations of Llama 3.1 8B Instruct during text generation to control the model's behavior and output characteristics.

Core Concept

Sparse Autoencoders (SAEs) decompose neural network activations into interpretable features. By extracting specific feature vectors from SAEs and adding them to the model's hidden states during generation, we can "steer" the model toward desired behaviors without fine-tuning.

Architecture

User Input → Tokenizer → Model with Forward Hooks → Steered Generation → Output
                                    ↑
                            Steering Vectors
                         (from pre-trained SAEs)

Key Components

1. Steering Vectors (steering.py, extract_steering_vectors.py)

Source: SAE decoder weights from andyrdt/saes-llama-3.1-8b-instruct

Extraction Process:

  • SAEs are trained to reconstruct model activations: x ≈ decoder @ encoder(x)
  • Each decoder column represents a feature direction in activation space
  • We extract specific columns (features) that produce desired behaviors
  • Vectors are normalized and stored in steering_vectors.pt
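A minimal sketch of the extraction step. It assumes the SAE decoder weight is a torch tensor W_dec stored as [d_sae, d_model], so each feature's direction is a row (with a [d_model, d_sae] layout the same direction would be the column W_dec[:, feature_idx]). extract_feature_vector is a hypothetical helper, not a function from extract_steering_vectors.py, and random weights stand in for a real SAE.

```python
import torch


def extract_feature_vector(W_dec: torch.Tensor, feature_idx: int) -> torch.Tensor:
    """Return the unit-norm decoder direction for one SAE feature.

    Hypothetical helper: assumes W_dec has shape [d_sae, d_model], so each
    row is one feature's direction in the model's activation space.
    """
    direction = W_dec[feature_idx].detach().float()
    # Normalize so that "strength" alone controls the steering magnitude
    return direction / direction.norm()


# Toy stand-in for a real SAE decoder (random weights, small sizes)
torch.manual_seed(0)
d_sae, d_model = 1024, 64
W_dec = torch.randn(d_sae, d_model)

vec = extract_feature_vector(W_dec, feature_idx=42)
print(vec.shape, round(vec.norm().item(), 4))
```

Normalizing at extraction time means a strength of 1.0 always corresponds to a shift of unit length, which makes strengths comparable across features.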

Functions:

  • load_saes(): Downloads SAE files from HuggingFace Hub and extracts features
  • load_saes_from_file(): Fast loading from pre-extracted vectors (preferred)
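The "extract once, load fast" split can be sketched with torch.save and torch.load. The {(layer, feature_idx): vector} layout and both helper names are assumptions for illustration, not the actual format of steering_vectors.pt; the key (11, 74457) reuses the example from demo.yaml, with a random tensor in place of the real feature.

```python
import os
import tempfile

import torch


def save_steering_vectors(vectors: dict, path: str) -> None:
    """Hypothetical helper: store {(layer, feature_idx): vector} on disk."""
    torch.save({key: v.cpu() for key, v in vectors.items()}, path)


def load_steering_vectors(path: str, device: str = "cpu") -> dict:
    """Hypothetical helper: reload vectors without touching the SAE files."""
    # weights_only=True restricts unpickling to tensors and basic containers
    return torch.load(path, map_location=device, weights_only=True)


vectors = {(11, 74457): torch.nn.functional.normalize(torch.randn(64), dim=0)}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "steering_vectors.pt")
    save_steering_vectors(vectors, path)
    loaded = load_steering_vectors(path)

print(list(loaded.keys()))
```

Loading with map_location="cpu" keeps the vectors off the GPU until a hook needs them, which fits the Spaces setup described under Device Management.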

2. Steering Implementation (steering.py)

Two Backends:

A. NNsight Backend (for research/analysis)

  • Uses generate_steered_answer() with NNsight's intervention API
  • Modifies activations during generation using context managers
  • Good for: experimentation, debugging, understanding interventions

B. Transformers Backend (for production/deployment)

  • Uses stream_steered_answer_hf() with PyTorch forward hooks
  • Direct hook registration on transformer layers
  • Good for: deployment, streaming, efficiency

Steering Mechanism (create_steering_hook()):

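def hook(module, input, output):
    # layer_components and clamp_intensity come from the enclosing
    # create_steering_hook() closure.
    # Depending on the transformers version, a decoder layer may return a
    # tuple (hidden_states, ...) or a bare hidden_states tensor.
    is_tuple = isinstance(output, tuple)
    hidden_states = output[0] if is_tuple else output  # Shape: [batch, seq_len, hidden_dim]

    for steering_component in layer_components:
        vector = steering_component['vector']      # Unit-norm direction to steer, [hidden_dim]
        strength = steering_component['strength']  # How much to steer
        vector = vector.to(dtype=hidden_states.dtype, device=hidden_states.device)

        # Broadcasting adds the same offset to every token in the sequence
        amount = strength * vector

        if clamp_intensity:
            # Remove each token's existing projection onto vector to prevent
            # over-steering: the result then has exactly `strength` along vector
            projection = (hidden_states @ vector).unsqueeze(-1) * vector
            amount = amount - projection

        hidden_states = hidden_states + amount

    # Returning a value from a forward hook replaces the layer's output
    return (hidden_states,) + tuple(output[1:]) if is_tuple else hidden_states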
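The hook above relies on variables from its enclosing closure, so the following self-contained sketch reproduces the same additive rule on a toy layer. make_toy_steering_hook and ToyDecoderLayer are hypothetical names, not functions from steering.py, and a random unit vector stands in for an SAE feature. It also checks the clamping algebra: for a unit-norm vector v and strength s, the steered activation is h' = h + s·v - (h·v)v, so h'·v = (h·v) + s - (h·v) = s, while the part of h orthogonal to v is unchanged.

```python
import torch
import torch.nn as nn


def make_toy_steering_hook(components, clamp_intensity=False):
    """Hypothetical stand-in for create_steering_hook(); same additive rule."""
    def hook(module, inputs, output):
        is_tuple = isinstance(output, tuple)
        hidden = output[0] if is_tuple else output  # [batch, seq_len, hidden_dim]
        for comp in components:
            v = comp["vector"].to(dtype=hidden.dtype, device=hidden.device)
            amount = comp["strength"] * v  # broadcasts over batch and sequence
            if clamp_intensity:
                # Drop each token's existing component along v (v is unit-norm)
                amount = amount - (hidden @ v).unsqueeze(-1) * v
            hidden = hidden + amount
        return (hidden,) + tuple(output[1:]) if is_tuple else hidden
    return hook


class ToyDecoderLayer(nn.Module):
    """Mimics a decoder layer that returns a tuple (hidden_states, extra)."""
    def forward(self, x):
        return (x, None)


torch.manual_seed(0)
d_model = 16
v = torch.nn.functional.normalize(torch.randn(d_model), dim=0)
x = torch.randn(2, 5, d_model)  # [batch, seq_len, hidden_dim]

layer = ToyDecoderLayer()
handle = layer.register_forward_hook(
    make_toy_steering_hook([{"vector": v, "strength": 4.0}], clamp_intensity=True)
)
steered, _ = layer(x)
handle.remove()

# With clamping, every token's projection onto v equals the strength
print(steered @ v)
```

Without clamping, the same hook would shift each token by 4·v regardless of how strongly the feature was already present, which is why clamping helps avoid over-steering tokens that already carry the feature.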

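Key Insight: Each hook is registered as a forward hook on a specific decoder layer. It runs right after that layer's forward pass, and because its return value replaces the layer's output, the modified activations propagate to all subsequent layers.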

3. Configuration (demo.yaml)

features:
  - [layer, feature_idx, strength]
  # Example: [11, 74457, 1.03]
  # Applies feature 74457 from layer 11 with strength 1.03

Parameters:

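  • layer: Transformer layer at which to apply steering (0-31 for Llama 3.1 8B, which has 32 layers)
  • feature_idx: Which SAE feature to use (0-131071 for a 128k-feature SAE)
  • strength: Scalar that multiplies the normalized steering vector before it is added
  • clamp_intensity: If true, first removes the activation's existing projection onto the steering vector, so the steered activation sits at exactly strength along that direction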
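A minimal sketch of how the features list might be turned into per-layer steering components, assuming demo.yaml has already been parsed (for example with PyYAML's yaml.safe_load) into a list of [layer, feature_idx, strength] triples. group_features_by_layer is a hypothetical helper; the range checks use the 32 layers and 131072 features quoted above, and the second feature in the example input is made up for illustration.

```python
from collections import defaultdict


def group_features_by_layer(features, num_layers=32, d_sae=131072):
    """Hypothetical helper: turn [layer, feature_idx, strength] triples into
    {layer: [(feature_idx, strength), ...]} for hook registration."""
    by_layer = defaultdict(list)
    for layer, feature_idx, strength in features:
        if not 0 <= layer < num_layers:
            raise ValueError(f"layer {layer} out of range 0-{num_layers - 1}")
        if not 0 <= feature_idx < d_sae:
            raise ValueError(f"feature_idx {feature_idx} out of range 0-{d_sae - 1}")
        by_layer[layer].append((feature_idx, float(strength)))
    return dict(by_layer)


# Same shape as the parsed features list; the second entry is hypothetical
features = [[11, 74457, 1.03], [20, 1234, 0.5]]
print(group_features_by_layer(features))
# {11: [(74457, 1.03)], 20: [(1234, 0.5)]}
```

Grouping by layer means each layer needs only one hook, which applies all of that layer's components in sequence, matching the loop inside the hook.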

4. Applications

A. Console Demo (demo.py)

  • Interactive chat interface in terminal
  • Supports both NNsight and Transformers backends (configurable via BACKEND)
  • Real-time streaming with the Transformers backend
  • Color-coded output for better UX

B. Web App (app.py)

  • Gradio interface for web deployment
  • Streaming generation with TextIteratorStreamer
  • Multi-turn conversation support
  • ZeroGPU compatible for HuggingFace Spaces

Implementation Details

Device Management

ZeroGPU Compatible:

from transformers import AutoModelForCausalLM

# Model loaded with device_map="auto" (requires the accelerate package)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Steering vectors stay on CPU initially in Spaces mode (ZeroGPU)
load_device = "cpu" if SPACES_AVAILABLE else device

# Inside the hook, each vector is moved to the activations' device and dtype at inference time
vector = vector.to(dtype=hidden_states.dtype, device=hidden_states.device)

Streaming Generation

Uses threading to enable real-time token streaming:

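from threading import Thread

from transformers import TextIteratorStreamer

# Inside a generator function (e.g. the Gradio callback):
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
thread = Thread(target=lambda: model.generate(..., streamer=streamer))
thread.start()

for token_text in streamer:
    yield token_text  # Send each decoded text chunk to the UI as it arrives

thread.join()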
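The threading pattern can be tried without a model. In this sketch, ToyStreamer is a hypothetical stand-in that loosely mirrors the queue-based design of TextIteratorStreamer (it accepts text directly rather than token ids), and fake_generate plays the role of model.generate running in a background thread.

```python
import queue
import threading
import time


class ToyStreamer:
    """Hypothetical stand-in for TextIteratorStreamer: the producer thread
    puts text into a queue, the consumer iterates until a stop sentinel."""
    _STOP = object()

    def __init__(self):
        self._queue = queue.Queue()

    def put(self, text):
        self._queue.put(text)

    def end(self):
        self._queue.put(self._STOP)

    def __iter__(self):
        while True:
            item = self._queue.get()  # blocks until the producer sends something
            if item is self._STOP:
                return
            yield item


def fake_generate(streamer, pieces):
    """Plays the role of model.generate(..., streamer=streamer)."""
    for piece in pieces:
        time.sleep(0.01)  # simulate per-token compute
        streamer.put(piece)
    streamer.end()


def stream_answer(pieces):
    streamer = ToyStreamer()
    thread = threading.Thread(target=fake_generate, args=(streamer, pieces))
    thread.start()
    for text in streamer:
        yield text  # forwarded to the UI as soon as it is produced
    thread.join()


print("".join(stream_answer(["Steered ", "text ", "arrives ", "in ", "pieces"])))
```

Generation runs in the background thread because model.generate blocks until it finishes; the main thread stays free to forward each chunk as soon as it lands in the queue.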

Hook Registration

# Register hooks on specific layers
hook_handles = []
for layer_idx in layers_to_steer:
    hook_fn = create_steering_hook(layer_idx, steering_components)
    handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
    hook_handles.append(handle)

try:
    # Generate with steering
    model.generate(...)
finally:
    # Clean up: always remove hooks so later calls run unsteered
    for handle in hook_handles:
        handle.remove()
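The registration and cleanup steps can also be wrapped in a context manager so hooks are removed even if generation raises. This is a sketch under assumptions: steering_hooks, ToyModel and add_offset are hypothetical names, and identity layers stand in for model.model.layers.

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def steering_hooks(layers, hook_fns):
    """Hypothetical helper: register {layer_idx: hook_fn} forward hooks and
    guarantee their removal on exit, even if generation raises."""
    handles = [layers[idx].register_forward_hook(fn) for idx, fn in hook_fns.items()]
    try:
        yield handles
    finally:
        for handle in handles:
            handle.remove()


class ToyModel(nn.Module):
    """Stack of identity 'layers' standing in for model.model.layers."""
    def __init__(self, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([nn.Identity() for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def add_offset(offset):
    # Identity layers return a bare tensor, so the hook returns a tensor too
    return lambda module, inputs, output: output + offset


model = ToyModel()
x = torch.zeros(1, 3, 8)
with steering_hooks(model.layers, {1: add_offset(1.0), 3: add_offset(2.0)}):
    steered = model(x)
unsteered = model(x)
print(steered[0, 0, 0].item(), unsteered[0, 0, 0].item())  # 3.0 0.0
```

Offsets from hooks on different layers accumulate as the activations flow forward, and once the with block exits the same model runs unsteered again.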

Technical Advantages

  1. No Fine-tuning Required: Steers pre-trained models without retraining
  2. Interpretable: SAE features are more interpretable than raw activations
  3. Composable: Multiple steering vectors can be combined
  4. Efficient: Only modifies forward pass, no backward pass needed
  5. Dynamic: Different steering per generation, configurable at runtime

Limitations

  1. SAE Dependency: Requires pre-trained SAEs for the target model
  2. Manual Feature Selection: Finding effective features requires experimentation
  3. Strength Tuning: Steering strength needs calibration per feature
  4. Computational Overhead: Small overhead from hook execution during generation

File Structure

eiffel-demo/
├── app.py                        # Gradio web interface
├── demo.py                       # Console chat interface
├── steering.py                   # Core steering implementation
├── extract_steering_vectors.py   # SAE feature extraction
├── demo.yaml                     # Configuration (features, params)
├── steering_vectors.pt           # Pre-extracted vectors (generated)
├── print_utils.py                # Terminal formatting utilities
├── requirements.txt              # Dependencies
├── README.md                     # User documentation
└── PROJECT.md                    # This file

Dependencies

Core:

  • transformers: Model loading and generation
  • torch: Neural network operations
  • gradio: Web interface
  • nnsight: Alternative intervention framework (optional)
  • sae-lens: SAE utilities (for extraction only)

Deployment:

  • spaces: HuggingFace Spaces ZeroGPU support
  • hf-transfer: Fast model downloads

Usage Flow

  1. Setup: Extract steering vectors once

    python extract_steering_vectors.py
    
  2. Configure: Edit demo.yaml to select features and strengths

  3. Run: Launch console or web interface

    python demo.py          # Console
    python app.py           # Web app
    
  4. Deploy: Upload to HuggingFace Spaces with ZeroGPU

References

  • SAE Repository: andyrdt/saes-llama-3.1-8b-instruct
  • Base Model: meta-llama/Llama-3.1-8B-Instruct
  • Technique: Activation steering via learned SAE features