dlouapre (HF Staff) committed
Commit e91abfb · 1 Parent(s): 26ba1f6

Improving look

Files changed (4)
  1. PROJECT.md +0 -214
  2. app.py +20 -1
  3. pyproject.toml +2 -1
  4. uv.lock +14 -18
PROJECT.md DELETED
@@ -1,214 +0,0 @@
1
- # Project Overview: Steered LLM Generation with SAE Features
2
-
3
- ## What This Project Does
4
-
5
- This project demonstrates **activation steering** of large language models using Sparse Autoencoder (SAE) features. It modifies the internal activations of Llama 3.1 8B Instruct during text generation to control the model's behavior and output characteristics.
6
-
7
- ## Core Concept
8
-
9
- Sparse Autoencoders (SAEs) decompose neural network activations into interpretable features. By extracting specific feature vectors from SAEs and adding them to the model's hidden states during generation, we can "steer" the model toward desired behaviors without fine-tuning.
10
-
11
- ## Architecture
12
-
13
- ```
14
- User Input → Tokenizer → Model with Forward Hooks → Steered Generation → Output
15
-                                     ↑
16
-                             Steering Vectors
17
-                          (from pre-trained SAEs)
18
- ```
19
-
20
- ## Key Components
21
-
22
- ### 1. **Steering Vectors** (`steering.py`, `extract_steering_vectors.py`)
23
-
24
- **Source**: SAE decoder weights from `andyrdt/saes-llama-3.1-8b-instruct`
25
-
26
- **Extraction Process**:
27
- - SAEs are trained to reconstruct model activations: `x ≈ decoder @ encoder(x)`
28
- - Each decoder column represents a feature direction in activation space
29
- - We extract specific columns (features) that produce desired behaviors
30
- - Vectors are normalized and stored in `steering_vectors.pt`
31
-
32
- **Functions**:
33
- - `load_saes()`: Downloads SAE files from HuggingFace Hub and extracts features
34
- - `load_saes_from_file()`: Fast loading from pre-extracted vectors (preferred)
35
-
36
- ### 2. **Steering Implementation** (`steering.py`)
37
-
38
- **Two Backends**:
39
-
40
- #### A. **NNsight Backend** (for research/analysis)
41
- - Uses `generate_steered_answer()` with NNsight's intervention API
42
- - Modifies activations during generation using context managers
43
- - Good for: experimentation, debugging, understanding interventions
44
-
45
- #### B. **Transformers Backend** (for production/deployment)
46
- - Uses `stream_steered_answer_hf()` with PyTorch forward hooks
47
- - Direct hook registration on transformer layers
48
- - Good for: deployment, streaming, efficiency
49
-
50
- **Steering Mechanism** (`create_steering_hook()`):
51
-
52
- ```python
53
- def hook(module, input, output):
54
-     hidden_states = output[0]  # Shape: [batch, seq_len, hidden_dim]
55
-
56
-     for steering_component in layer_components:
57
-         vector = steering_component['vector']  # Direction to steer
58
-         strength = steering_component['strength']  # How much to steer
59
-
60
-         # Add steering to each token in sequence
61
-         amount = (strength * vector).unsqueeze(0).expand(seq_len, -1).unsqueeze(0)
62
-
63
-         if clamp_intensity:
64
-             # Remove existing projection to prevent over-steering
65
-             projection = (hidden_states @ vector) @ vector
66
-             amount = amount - projection
67
-
68
-         hidden_states = hidden_states + amount
69
-
70
-     return (hidden_states,) + rest_of_output
71
- ```
72
-
73
- **Key Insight**: Hooks are applied at specific layers during the forward pass, modifying activations before they propagate to subsequent layers.
74
-
75
- ### 3. **Configuration** (`demo.yaml`)
76
-
77
- ```yaml
78
- features:
79
- - [layer, feature_idx, strength]
80
- # Example: [11, 74457, 1.03]
81
- # Applies feature 74457 from layer 11 with strength 1.03
82
- ```
83
-
84
- **Parameters**:
85
- - `layer`: Which transformer layer to apply steering (0-31 for Llama 8B)
86
- - `feature_idx`: Which SAE feature to use (0-131071 for 128k SAE)
87
- - `strength`: Multiplicative factor for steering intensity
88
- - `clamp_intensity`: If true, removes existing projection before adding steering
89
-
90
- ### 4. **Applications**
91
-
92
- #### A. **Console Demo** (`demo.py`)
93
- - Interactive chat interface in terminal
94
- - Supports both NNsight and Transformers backends (configurable via `BACKEND`)
95
- - Real-time streaming with transformers backend
96
- - Color-coded output for better UX
97
-
98
- #### B. **Web App** (`app.py`)
99
- - Gradio interface for web deployment
100
- - Streaming generation with `TextIteratorStreamer`
101
- - Multi-turn conversation support
102
- - ZeroGPU compatible for HuggingFace Spaces
103
-
104
- ## Implementation Details
105
-
106
- ### Device Management
107
-
108
- **ZeroGPU Compatible**:
109
- ```python
110
- # Model loaded with device_map="auto"
111
- model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
112
-
113
- # Steering vectors on CPU initially (Spaces mode)
114
- load_device = "cpu" if SPACES_AVAILABLE else device
115
-
116
- # Hooks automatically move vectors to GPU during inference
117
- vector = vector.to(dtype=hidden_states.dtype, device=hidden_states.device)
118
- ```
119
-
120
- ### Streaming Generation
121
-
122
- Uses threading to enable real-time token streaming:
123
- ```python
124
- streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
125
- thread = Thread(target=lambda: model.generate(..., streamer=streamer))
126
- thread.start()
127
-
128
- for token_text in streamer:
129
-     yield token_text  # Send to UI as tokens arrive
130
- ```
131
-
132
- ### Hook Registration
133
-
134
- ```python
135
- # Register hooks on specific layers
136
- for layer_idx in layers_to_steer:
137
-     hook_fn = create_steering_hook(layer_idx, steering_components)
138
-     handle = model.model.layers[layer_idx].register_forward_hook(hook_fn)
139
-     hook_handles.append(handle)
140
-
141
- # Generate with steering
142
- model.generate(...)
143
-
144
- # Clean up
145
- for handle in hook_handles:
146
-     handle.remove()
147
- ```
148
-
149
- ## Technical Advantages
150
-
151
- 1. **No Fine-tuning Required**: Steers pre-trained models without retraining
152
- 2. **Interpretable**: SAE features are more interpretable than raw activations
153
- 3. **Composable**: Multiple steering vectors can be combined
154
- 4. **Efficient**: Only modifies forward pass, no backward pass needed
155
- 5. **Dynamic**: Different steering per generation, configurable at runtime
156
-
157
- ## Limitations
158
-
159
- 1. **SAE Dependency**: Requires pre-trained SAEs for the target model
160
- 2. **Manual Feature Selection**: Finding effective features requires experimentation
161
- 3. **Strength Tuning**: Steering strength needs calibration per feature
162
- 4. **Computational Overhead**: Small overhead from hook execution during generation
163
-
164
- ## File Structure
165
-
166
- ```
167
- eiffel-demo/
168
- ├── app.py # Gradio web interface
169
- ├── demo.py # Console chat interface
170
- ├── steering.py # Core steering implementation
171
- ├── extract_steering_vectors.py # SAE feature extraction
172
- ├── demo.yaml # Configuration (features, params)
173
- ├── steering_vectors.pt # Pre-extracted vectors (generated)
174
- ├── print_utils.py # Terminal formatting utilities
175
- ├── requirements.txt # Dependencies
176
- ├── README.md # User documentation
177
- └── PROJECT.md # This file
178
- ```
179
-
180
- ## Dependencies
181
-
182
- **Core**:
183
- - `transformers`: Model loading and generation
184
- - `torch`: Neural network operations
185
- - `gradio`: Web interface
186
- - `nnsight`: Alternative intervention framework (optional)
187
- - `sae-lens`: SAE utilities (for extraction only)
188
-
189
- **Deployment**:
190
- - `spaces`: HuggingFace Spaces ZeroGPU support
191
- - `hf-transfer`: Fast model downloads
192
-
193
- ## Usage Flow
194
-
195
- 1. **Setup**: Extract steering vectors once
196
- ```bash
197
- python extract_steering_vectors.py
198
- ```
199
-
200
- 2. **Configure**: Edit `demo.yaml` to select features and strengths
201
-
202
- 3. **Run**: Launch console or web interface
203
- ```bash
204
- python demo.py # Console
205
- python app.py # Web app
206
- ```
207
-
208
- 4. **Deploy**: Upload to HuggingFace Spaces with ZeroGPU
209
-
210
- ## References
211
-
212
- - SAE Repository: `andyrdt/saes-llama-3.1-8b-instruct`
213
- - Base Model: `meta-llama/Llama-3.1-8B-Instruct`
214
- - Technique: Activation steering via learned SAE features
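
Note on the removed PROJECT.md: the clamped steering step it documents can be checked on a single hidden-state vector. The following is a minimal pure-Python sketch, not the code from `steering.py`. The helper names `dot` and `steer` are hypothetical, and the sketch assumes the steering vector has unit norm, as the doc states that vectors are normalized. It writes the projection as (h · v) v. For batched tensors, the doc's `(hidden_states @ vector) @ vector` would probably need to be `(hidden_states @ vector).unsqueeze(-1) * vector` to produce a tensor of the same shape as `hidden_states`.

```python
def dot(a, b):
    """Inner product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


def steer(hidden, vector, strength, clamp_intensity=False):
    """Return hidden + strength * vector, optionally removing the
    existing component of `hidden` along `vector` first.

    Hypothetical helper; assumes `vector` has unit norm.
    """
    amount = [strength * v for v in vector]
    if clamp_intensity:
        # Projection of hidden onto the steering direction: (h . v) v
        coeff = dot(hidden, vector)
        amount = [a - coeff * v for a, v in zip(amount, vector)]
    return [h + a for h, a in zip(hidden, amount)]


# Unit-norm steering direction and a hidden state already aligned with it.
v = [0.6, 0.8]
h = [3.0, 4.0]

plain = steer(h, v, strength=1.5)
clamped = steer(h, v, strength=1.5, clamp_intensity=True)

# Component along v: without clamping it is (h . v) + strength,
# with clamping it is reset to strength.
print(dot(plain, v), dot(clamped, v))
```

Under the unit-norm assumption, clamping makes the final component along the steering direction independent of what the hidden state already contained, which is how it "prevents over-steering".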
app.py CHANGED
@@ -111,6 +111,24 @@ def create_demo():
111
  #chatbot {
112
  height: 600px;
113
  }
114
  """
115
 
116
  # Create the interface
@@ -126,7 +144,8 @@ def create_demo():
126
  chatbot=gr.Chatbot(
127
  elem_id="chatbot",
128
  bubble_full_width=False,
129
- show_copy_button=True
 
130
  ),
131
  )
132
 
 
111
  #chatbot {
112
  height: 600px;
113
  }
114
+ /* Improve chat bubble contrast */
115
+ #chatbot .message {
116
+ border: 1px solid rgba(0, 0, 0, 0.15) !important;
117
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
118
+ }
119
+ #chatbot .user {
120
+ background-color: rgba(33, 150, 243, 0.08) !important;
121
+ }
122
+ #chatbot .bot {
123
+ background-color: rgba(0, 0, 0, 0.03) !important;
124
+ }
125
+ /* Ensure input is visible and properly positioned */
126
+ .input-container {
127
+ margin-top: 1rem;
128
+ padding: 1rem;
129
+ background: white;
130
+ border-top: 1px solid rgba(0, 0, 0, 0.1);
131
+ }
132
  """
133
 
134
  # Create the interface
 
144
  chatbot=gr.Chatbot(
145
  elem_id="chatbot",
146
  bubble_full_width=False,
147
+ show_copy_button=True,
148
+ show_label=False
149
  ),
150
  )
151
 
pyproject.toml CHANGED
@@ -9,5 +9,6 @@ dependencies = [
9
  "gradio>=4.0.0",
10
  "pyyaml>=6.0",
11
  "accelerate>=0.20.0",
12
- "spaces==0.28.3"
 
13
  ]
 
9
  "gradio>=4.0.0",
10
  "pyyaml>=6.0",
11
  "accelerate>=0.20.0",
12
+ "spaces==0.28.3",
13
+ "numpy<2",
14
  ]
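
Note on the pin: the new `numpy<2` constraint is what moves the `uv.lock` entry below from 2.2.6 to 1.26.4. As a rough illustration only, the constraint amounts to a test on the major version. The `satisfies_pin` helper is hypothetical and handles plain `X.Y.Z` release strings; a real check would rely on the resolver or the `packaging` library.

```python
def satisfies_pin(version: str, max_major_exclusive: int = 2) -> bool:
    """True if a plain 'X.Y.Z' release string has a major version
    below `max_major_exclusive` (hypothetical helper)."""
    major = int(version.split(".")[0])
    return major < max_major_exclusive


# The two numpy versions that appear in the uv.lock diff.
for candidate in ("2.2.6", "1.26.4"):
    print(candidate, satisfies_pin(candidate))
```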
uv.lock CHANGED
@@ -148,6 +148,7 @@ source = { virtual = "." }
148
  dependencies = [
149
  { name = "accelerate" },
150
  { name = "gradio" },
 
151
  { name = "pyyaml" },
152
  { name = "spaces" },
153
  { name = "torch" },
@@ -158,6 +159,7 @@ dependencies = [
158
  requires-dist = [
159
  { name = "accelerate", specifier = ">=0.20.0" },
160
  { name = "gradio", specifier = ">=4.0.0" },
 
161
  { name = "pyyaml", specifier = ">=6.0" },
162
  { name = "spaces", specifier = "==0.28.3" },
163
  { name = "torch", specifier = "==2.2.0" },
@@ -435,24 +437,18 @@ wheels = [
435
 
436
  [[package]]
437
  name = "numpy"
438
- version = "2.2.6"
439
- source = { registry = "https://pypi.org/simple" }
440
- sdist = { url = "https://files.pythonhosted.org/packages/76/21/7d2a95e4bba9dc13d043ee156a356c0a8f0c6309dff6b21b4d71a073b8a8/numpy-2.2.6.tar.gz", hash = "sha256:e29554e2bef54a90aa5cc07da6ce955accb83f21ab5de01a62c8478897b264fd", size = 20276440, upload-time = "2025-05-17T22:38:04.611Z" }
441
- wheels = [
442
- { url = "https://files.pythonhosted.org/packages/9a/3e/ed6db5be21ce87955c0cbd3009f2803f59fa08df21b5df06862e2d8e2bdd/numpy-2.2.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b412caa66f72040e6d268491a59f2c43bf03eb6c96dd8f0307829feb7fa2b6fb", size = 21165245, upload-time = "2025-05-17T21:27:58.555Z" },
443
- { url = "https://files.pythonhosted.org/packages/22/c2/4b9221495b2a132cc9d2eb862e21d42a009f5a60e45fc44b00118c174bff/numpy-2.2.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8e41fd67c52b86603a91c1a505ebaef50b3314de0213461c7a6e99c9a3beff90", size = 14360048, upload-time = "2025-05-17T21:28:21.406Z" },
444
- { url = "https://files.pythonhosted.org/packages/fd/77/dc2fcfc66943c6410e2bf598062f5959372735ffda175b39906d54f02349/numpy-2.2.6-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:37e990a01ae6ec7fe7fa1c26c55ecb672dd98b19c3d0e1d1f326fa13cb38d163", size = 5340542, upload-time = "2025-05-17T21:28:30.931Z" },
445
- { url = "https://files.pythonhosted.org/packages/7a/4f/1cb5fdc353a5f5cc7feb692db9b8ec2c3d6405453f982435efc52561df58/numpy-2.2.6-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:5a6429d4be8ca66d889b7cf70f536a397dc45ba6faeb5f8c5427935d9592e9cf", size = 6878301, upload-time = "2025-05-17T21:28:41.613Z" },
446
- { url = "https://files.pythonhosted.org/packages/eb/17/96a3acd228cec142fcb8723bd3cc39c2a474f7dcf0a5d16731980bcafa95/numpy-2.2.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:efd28d4e9cd7d7a8d39074a4d44c63eda73401580c5c76acda2ce969e0a38e83", size = 14297320, upload-time = "2025-05-17T21:29:02.78Z" },
447
- { url = "https://files.pythonhosted.org/packages/b4/63/3de6a34ad7ad6646ac7d2f55ebc6ad439dbbf9c4370017c50cf403fb19b5/numpy-2.2.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc7b73d02efb0e18c000e9ad8b83480dfcd5dfd11065997ed4c6747470ae8915", size = 16801050, upload-time = "2025-05-17T21:29:27.675Z" },
448
- { url = "https://files.pythonhosted.org/packages/07/b6/89d837eddef52b3d0cec5c6ba0456c1bf1b9ef6a6672fc2b7873c3ec4e2e/numpy-2.2.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:74d4531beb257d2c3f4b261bfb0fc09e0f9ebb8842d82a7b4209415896adc680", size = 15807034, upload-time = "2025-05-17T21:29:51.102Z" },
449
- { url = "https://files.pythonhosted.org/packages/01/c8/dc6ae86e3c61cfec1f178e5c9f7858584049b6093f843bca541f94120920/numpy-2.2.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8fc377d995680230e83241d8a96def29f204b5782f371c532579b4f20607a289", size = 18614185, upload-time = "2025-05-17T21:30:18.703Z" },
450
- { url = "https://files.pythonhosted.org/packages/5b/c5/0064b1b7e7c89137b471ccec1fd2282fceaae0ab3a9550f2568782d80357/numpy-2.2.6-cp310-cp310-win32.whl", hash = "sha256:b093dd74e50a8cba3e873868d9e93a85b78e0daf2e98c6797566ad8044e8363d", size = 6527149, upload-time = "2025-05-17T21:30:29.788Z" },
451
- { url = "https://files.pythonhosted.org/packages/a3/dd/4b822569d6b96c39d1215dbae0582fd99954dcbcf0c1a13c61783feaca3f/numpy-2.2.6-cp310-cp310-win_amd64.whl", hash = "sha256:f0fd6321b839904e15c46e0d257fdd101dd7f530fe03fd6359c1ea63738703f3", size = 12904620, upload-time = "2025-05-17T21:30:48.994Z" },
452
- { url = "https://files.pythonhosted.org/packages/9e/3b/d94a75f4dbf1ef5d321523ecac21ef23a3cd2ac8b78ae2aac40873590229/numpy-2.2.6-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0b605b275d7bd0c640cad4e5d30fa701a8d59302e127e5f79138ad62762c3e3d", size = 21040391, upload-time = "2025-05-17T21:44:35.948Z" },
453
- { url = "https://files.pythonhosted.org/packages/17/f4/09b2fa1b58f0fb4f7c7963a1649c64c4d315752240377ed74d9cd878f7b5/numpy-2.2.6-pp310-pypy310_pp73-macosx_14_0_x86_64.whl", hash = "sha256:7befc596a7dc9da8a337f79802ee8adb30a552a94f792b9c9d18c840055907db", size = 6786754, upload-time = "2025-05-17T21:44:47.446Z" },
454
- { url = "https://files.pythonhosted.org/packages/af/30/feba75f143bdc868a1cc3f44ccfa6c4b9ec522b36458e738cd00f67b573f/numpy-2.2.6-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce47521a4754c8f4593837384bd3424880629f718d87c5d44f8ed763edd63543", size = 16643476, upload-time = "2025-05-17T21:45:11.871Z" },
455
- { url = "https://files.pythonhosted.org/packages/37/48/ac2a9584402fb6c0cd5b5d1a91dcf176b15760130dd386bbafdbfe3640bf/numpy-2.2.6-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d042d24c90c41b54fd506da306759e06e568864df8ec17ccc17e9e884634fd00", size = 12812666, upload-time = "2025-05-17T21:45:31.426Z" },
456
  ]
457
 
458
  [[package]]
 
148
  dependencies = [
149
  { name = "accelerate" },
150
  { name = "gradio" },
151
+ { name = "numpy" },
152
  { name = "pyyaml" },
153
  { name = "spaces" },
154
  { name = "torch" },
 
159
  requires-dist = [
160
  { name = "accelerate", specifier = ">=0.20.0" },
161
  { name = "gradio", specifier = ">=4.0.0" },
162
+ { name = "numpy", specifier = "<2" },
163
  { name = "pyyaml", specifier = ">=6.0" },
164
  { name = "spaces", specifier = "==0.28.3" },
165
  { name = "torch", specifier = "==2.2.0" },
 
437
 
438
  [[package]]
439
  name = "numpy"
440
+ version = "1.26.4"
441
+ source = { registry = "https://pypi.org/simple" }
442
+ sdist = { url = "https://files.pythonhosted.org/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010", size = 15786129, upload-time = "2024-02-06T00:26:44.495Z" }
443
+ wheels = [
444
+ { url = "https://files.pythonhosted.org/packages/a7/94/ace0fdea5241a27d13543ee117cbc65868e82213fb31a8eb7fe9ff23f313/numpy-1.26.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0", size = 20631468, upload-time = "2024-02-05T23:48:01.194Z" },
445
+ { url = "https://files.pythonhosted.org/packages/20/f7/b24208eba89f9d1b58c1668bc6c8c4fd472b20c45573cb767f59d49fb0f6/numpy-1.26.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a", size = 13966411, upload-time = "2024-02-05T23:48:29.038Z" },
446
+ { url = "https://files.pythonhosted.org/packages/fc/a5/4beee6488160798683eed5bdb7eead455892c3b4e1f78d79d8d3f3b084ac/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4", size = 14219016, upload-time = "2024-02-05T23:48:54.098Z" },
447
+ { url = "https://files.pythonhosted.org/packages/4b/d7/ecf66c1cd12dc28b4040b15ab4d17b773b87fa9d29ca16125de01adb36cd/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f", size = 18240889, upload-time = "2024-02-05T23:49:25.361Z" },
448
+ { url = "https://files.pythonhosted.org/packages/24/03/6f229fe3187546435c4f6f89f6d26c129d4f5bed40552899fcf1f0bf9e50/numpy-1.26.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a", size = 13876746, upload-time = "2024-02-05T23:49:51.983Z" },
449
+ { url = "https://files.pythonhosted.org/packages/39/fe/39ada9b094f01f5a35486577c848fe274e374bbf8d8f472e1423a0bbd26d/numpy-1.26.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2", size = 18078620, upload-time = "2024-02-05T23:50:22.515Z" },
450
+ { url = "https://files.pythonhosted.org/packages/d5/ef/6ad11d51197aad206a9ad2286dc1aac6a378059e06e8cf22cd08ed4f20dc/numpy-1.26.4-cp310-cp310-win32.whl", hash = "sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07", size = 5972659, upload-time = "2024-02-05T23:50:35.834Z" },
451
+ { url = "https://files.pythonhosted.org/packages/19/77/538f202862b9183f54108557bfda67e17603fc560c384559e769321c9d92/numpy-1.26.4-cp310-cp310-win_amd64.whl", hash = "sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5", size = 15808905, upload-time = "2024-02-05T23:51:03.701Z" },
452
  ]
453
 
454
  [[package]]