# Project Overview: Steered LLM Generation with SAE Features
## What This Project Does
This project demonstrates **activation steering** of large language models using Sparse Autoencoder (SAE) features. It modifies the internal activations of Llama 3.1 8B Instruct during text generation to control the model's behavior and output characteristics.
## Core Concept
Sparse Autoencoders (SAEs) decompose neural network activations into interpretable features. By extracting specific feature vectors from SAEs and adding them to the model's hidden states during generation, we can "steer" the model toward desired behaviors without fine-tuning.
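In code, the intervention is a single vector addition per token. The toy snippet below is illustrative only (4096 is Llama 3.1 8B's hidden size; the values are made up):

```python
import torch

# Toy illustration of activation steering: h' = h + alpha * v
hidden = torch.randn(4096)    # one token's activation (Llama 3.1 8B: d_model = 4096)
v = torch.randn(4096)
v = v / v.norm()              # unit-norm SAE feature direction
alpha = 1.03                  # steering strength
steered = hidden + alpha * v  # nudge the activation along the feature direction
```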
## Architecture
```
User Input → Tokenizer → Model with Forward Hooks → Steered Generation → Output
                                    ↑
                            Steering Vectors
                        (from pre-trained SAEs)
```
## Key Components
### 1. **Steering Vectors** (`steering.py`, `extract_steering_vectors.py`)
**Source**: SAE decoder weights from `andyrdt/saes-llama-3.1-8b-instruct`
**Extraction Process**:
- SAEs are trained to reconstruct model activations: `x ≈ decoder @ encoder(x)`
- Each decoder column represents a feature direction in activation space
- We extract specific columns (features) that produce desired behaviors
- Vectors are normalized and stored in `steering_vectors.pt` (see the sketch below)
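As a hedged sketch of that extraction (the tensor name `W_dec`, its `[n_features, d_model]` layout, and the per-layer file name follow the common sae_lens convention and are assumptions, not necessarily what `extract_steering_vectors.py` does):

```python
import torch

# Hypothetical: one decoder row is one feature direction in activation space.
state = torch.load("sae_layer11.pt", map_location="cpu")  # assumed file name
vector = state["W_dec"][74457]           # feature 74457's direction, [d_model]
vector = vector / vector.norm()          # normalize before storing
torch.save({(11, 74457): vector}, "steering_vectors.pt")
```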
**Functions**:
- `load_saes()`: Downloads SAE files from HuggingFace Hub and extracts features
- `load_saes_from_file()`: Fast loading from pre-extracted vectors (preferred; sketched below)
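A minimal sketch of the fast path, assuming the keyed-dictionary layout from the extraction sketch above (only the file name comes from the repo):

```python
import torch

def load_saes_from_file(path="steering_vectors.pt", device="cpu"):
    # Pre-extracted vectors, e.g. {(layer, feature_idx): tensor[d_model]}
    return torch.load(path, map_location=device)
```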
### 2. **Steering Implementation** (`steering.py`)
**Two Backends**:
#### A. **NNsight Backend** (for research/analysis)
- Uses `generate_steered_answer()` with NNsight's intervention API
- Modifies activations during generation using context managers
- Good for: experimentation, debugging, understanding interventions (see the sketch below)
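A hedged sketch of the pattern (nnsight's tracing API differs across versions, so treat the exact calls as an illustration rather than the repo's `generate_steered_answer()`):

```python
from nnsight import LanguageModel

model = LanguageModel("meta-llama/Llama-3.1-8B-Instruct", device_map="auto")

with model.generate("Tell me about Paris.", max_new_tokens=64):
    # layers[11].output is a tuple; [0] holds the hidden states.
    # `vector` is a [d_model] steering direction loaded earlier.
    # (Applying the edit at every generated token may need version-specific APIs.)
    model.model.layers[11].output[0][:] += 1.03 * vector
    out = model.generator.output.save()

print(model.tokenizer.decode(out[0], skip_special_tokens=True))
```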
#### B. **Transformers Backend** (for production/deployment)
- Uses `stream_steered_answer_hf()` with PyTorch forward hooks
- Direct hook registration on transformer layers
- Good for: deployment, streaming, efficiency
**Steering Mechanism** (`create_steering_hook()`):
```python
def hook(module, input, output):
    hidden_states = output[0]  # Shape: [batch, seq_len, hidden_dim]
    for steering_component in layer_components:
        vector = steering_component['vector']      # Unit-norm direction to steer along
        strength = steering_component['strength']  # How much to steer
        # Broadcasts over batch and sequence, so every token is steered
        amount = strength * vector
        if clamp_intensity:
            # Remove the existing projection onto the feature direction
            # to prevent over-steering tokens already aligned with it
            projection = (hidden_states @ vector).unsqueeze(-1) * vector
            amount = amount - projection
        hidden_states = hidden_states + amount
    return (hidden_states,) + output[1:]
```
**Key Insight**: Hooks are applied at specific layers during the forward pass, modifying activations before they propagate to subsequent layers.
### 3. **Configuration** (`demo.yaml`)
```yaml
features:
- [layer, feature_idx, strength]
# Example: [11, 74457, 1.03]
# Applies feature 74457 from layer 11 with strength 1.03
```
**Parameters**:
- `layer`: The transformer layer at which steering is applied (0-31 for Llama 3.1 8B)
- `feature_idx`: The SAE feature to use (0-131071 for the 128k-feature SAE)
- `strength`: Multiplicative factor for steering intensity
- `clamp_intensity`: If true, removes the existing projection onto the feature direction before adding steering
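To make the mapping from config to hook inputs concrete, here is a minimal sketch (the grouping helper and the `{(layer, feature_idx): vector}` layout are assumptions consistent with the sketches above):

```python
import torch
import yaml

with open("demo.yaml") as f:
    cfg = yaml.safe_load(f)

steering_vectors = torch.load("steering_vectors.pt", map_location="cpu")

# Group entries by layer into the component dicts the steering hook consumes
components_by_layer = {}
for layer, feature_idx, strength in cfg["features"]:
    components_by_layer.setdefault(layer, []).append({
        "vector": steering_vectors[(layer, feature_idx)],
        "strength": strength,
    })
```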
### 4. **Applications**
#### A. **Console Demo** (`demo.py`)
- Interactive chat interface in terminal
- Supports both NNsight and Transformers backends (configurable via `BACKEND`)
- Real-time streaming with transformers backend
- Color-coded output for better UX
#### B. **Web App** (`app.py`)
- Gradio interface for web deployment
- Streaming generation with `TextIteratorStreamer`
- Multi-turn conversation support
- ZeroGPU compatible for HuggingFace Spaces (wiring sketched below)
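The wiring is roughly as follows (a sketch, not `app.py` verbatim; the signature of `stream_steered_answer_hf()` is assumed):

```python
import gradio as gr

def respond(message, history):
    # Accumulate streamed tokens; Gradio re-renders the reply on every yield
    partial = ""
    for token_text in stream_steered_answer_hf(message, history):
        partial += token_text
        yield partial

gr.ChatInterface(respond).launch()
```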
## Implementation Details
### Device Management
**ZeroGPU Compatible**:
```python
# Model loaded with device_map="auto"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
# Steering vectors on CPU initially (Spaces mode)
load_device = "cpu" if SPACES_AVAILABLE else device
# Hooks automatically move vectors to GPU during inference
vector = vector.to(dtype=hidden_states.dtype, device=hidden_states.device)
```
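On Spaces, the GPU-bound entry point is additionally wrapped with the `spaces.GPU` decorator so ZeroGPU attaches a GPU only for the duration of the call (a sketch; the wrapped function and `duration` value are illustrative):

```python
import spaces

@spaces.GPU(duration=120)  # GPU is allocated only while this call runs
def generate(message, history):
    yield from stream_steered_answer_hf(message, history)
```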
### Streaming Generation
Uses threading to enable real-time token streaming:
```python
from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
thread = Thread(target=lambda: model.generate(..., streamer=streamer))
thread.start()
for token_text in streamer:
    yield token_text  # Send to UI as tokens arrive
```
### Hook Registration
```python
# Register hooks on the configured layers
hook_handles = []
for layer_idx in layers_to_steer:
    hook_fn = create_steering_hook(layer_idx, steering_components)
    handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
    hook_handles.append(handle)
try:
    # Generate with steering active
    model.generate(...)
finally:
    # Always remove hooks, even if generation raises
    for handle in hook_handles:
        handle.remove()
```
## Technical Advantages
1. **No Fine-tuning Required**: Steers pre-trained models without retraining
2. **Interpretable**: SAE features are more interpretable than raw activations
3. **Composable**: Multiple steering vectors can be combined
4. **Efficient**: Only modifies forward pass, no backward pass needed
5. **Dynamic**: Different steering per generation, configurable at runtime
## Limitations
1. **SAE Dependency**: Requires pre-trained SAEs for the target model
2. **Manual Feature Selection**: Finding effective features requires experimentation
3. **Strength Tuning**: Steering strength needs calibration per feature
4. **Computational Overhead**: Small overhead from hook execution during generation
## File Structure
```
eiffel-demo/
├── app.py                        # Gradio web interface
├── demo.py                       # Console chat interface
├── steering.py                   # Core steering implementation
├── extract_steering_vectors.py   # SAE feature extraction
├── demo.yaml                     # Configuration (features, params)
├── steering_vectors.pt           # Pre-extracted vectors (generated)
├── print_utils.py                # Terminal formatting utilities
├── requirements.txt              # Dependencies
├── README.md                     # User documentation
└── PROJECT.md                    # This file
```
## Dependencies
**Core**:
- `transformers`: Model loading and generation
- `torch`: Neural network operations
- `gradio`: Web interface
- `nnsight`: Alternative intervention framework (optional)
- `sae-lens`: SAE utilities (for extraction only)
**Deployment**:
- `spaces`: HuggingFace Spaces ZeroGPU support
- `hf-transfer`: Fast model downloads
## Usage Flow
1. **Setup**: Extract steering vectors once
```bash
python extract_steering_vectors.py
```
2. **Configure**: Edit `demo.yaml` to select features and strengths
3. **Run**: Launch console or web interface
```bash
python demo.py # Console
python app.py # Web app
```
4. **Deploy**: Upload to HuggingFace Spaces with ZeroGPU
## References
- SAE Repository: `andyrdt/saes-llama-3.1-8b-instruct`
- Base Model: `meta-llama/Llama-3.1-8B-Instruct`
- Technique: Activation steering via learned SAE features