import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread


def load_saes_from_file(file_path, cfg, device):
    """
    Load pre-extracted steering vectors from a local file.

    This is much faster than load_saes() since it doesn't download large SAE files.
    The file should be created using extract_steering_vectors.py script.

    Args:
        file_path: Path to the .pt file containing steering vectors
        cfg: Configuration dict with 'features' list
        device: Device to load tensors on ('cuda' or 'cpu')

    Returns:
        List of steering component dicts with keys: 'layer', 'feature', 'strength', 'vector'
    """
    import os

    if not os.path.exists(file_path):
        raise FileNotFoundError(
            f"Steering vectors file not found: {file_path}\n"
            f"Please run: python extract_steering_vectors.py"
        )

    print(f"Loading pre-extracted steering vectors from {file_path}...")

    # Load the dictionary of vectors
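    # (expected format, as the lookups below assume: {(layer_idx, feature_idx): 1-D vector tensor})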
    steering_vectors_dict = torch.load(file_path, map_location="cpu")

    if not cfg.get('features'):
        print("No features specified in config.")
        return []

    steering_components = []
    features = cfg['features']
    reduced_strengths = cfg.get('reduced_strengths', False)

    for feature in features:
        layer_idx, feature_idx = feature[0], feature[1]
        strength = feature[2] if len(feature) > 2 else 0.0

        if reduced_strengths:
            strength *= layer_idx

        # Look up the pre-extracted vector
        key = (layer_idx, feature_idx)
        if key not in steering_vectors_dict:
            raise KeyError(
                f"Vector for layer {layer_idx}, feature {feature_idx} not found in {file_path}.\n"
                f"Please re-run: python extract_steering_vectors.py"
            )

        vec = steering_vectors_dict[key].to(device, non_blocking=True)

        # Display the strength, plus the per-layer ("reduced") strength when defined
        reduced_str = f"[{strength / layer_idx:.2f}]" if layer_idx > 0 else "[N/A]"
        print(f"Loaded feature: layer={layer_idx}, feature={feature_idx}, "
              f"strength={strength:.2f} {reduced_str}")

        steering_components.append({
            'layer': layer_idx,
            'feature': feature_idx,
            'strength': strength,
            'vector': vec  # Already normalized in the file
        })

    print(f"Loaded {len(steering_components)} steering vector(s) from local file")
    return steering_components
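
# Illustrative sketch (not part of this module) of a compatible vectors file and
# config; "steering_vectors.pt" and the layer/feature/strength numbers are
# placeholders, not values this code prescribes:
#
#   vectors = {(10, 4242): torch.nn.functional.normalize(torch.randn(4096), dim=0)}
#   torch.save(vectors, "steering_vectors.pt")
#
#   cfg = {"features": [[10, 4242, 8.0]], "reduced_strengths": False}
#   components = load_saes_from_file("steering_vectors.pt", cfg, device="cuda")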



def create_steering_hook(layer_idx, steering_components, clamp_intensity=False):
    """
    Create a forward hook for a specific layer that applies steering.

    Args:
        layer_idx: Which layer this hook is for
        steering_components: List of steering components (all layers)
        clamp_intensity: Whether to clamp steering intensity

    Returns:
        Forward hook function
    """
    layer_components = [sc for sc in steering_components if sc['layer'] == layer_idx]

    if not layer_components:
        return None

    def hook(module, input, output):
        """Forward hook that modifies the output hidden states."""
        # Handle different output formats (tuple vs tensor)
        if isinstance(output, tuple):
            hidden_states = output[0]
            rest_of_output = output[1:]
        else:
            hidden_states = output
            rest_of_output = None

        # Handle different shapes during generation
        original_shape = hidden_states.shape
        if len(original_shape) == 2:
            # During generation: [batch, hidden_dim] -> add seq_len dimension
            hidden_states = hidden_states.unsqueeze(1)  # [batch, 1, hidden_dim]

        for sc in layer_components:
            strength = sc['strength']
            vector = sc['vector']  # Already normalized

            # Ensure vector matches hidden_states dtype and device
            vector = vector.to(dtype=hidden_states.dtype, device=hidden_states.device)

            # Match nnsight's expansion pattern exactly
            seq_len = hidden_states.shape[1]
            amount = (strength * vector).unsqueeze(0).expand(seq_len, -1).unsqueeze(0)  # [1, seq_len, hidden_dim]

            if clamp_intensity:
                # Subtract the hidden state's current projection onto the steering
                # direction, so its final component along `vector` is exactly `strength`
                # (prevents over-steering when the model already activates the feature)
                projection_scalars = torch.einsum('bsh,h->bs', hidden_states, vector).unsqueeze(-1)
                projection_vectors = projection_scalars * vector.view(1, 1, -1)
                amount = amount - projection_vectors

            hidden_states = hidden_states + amount

        # Restore original shape if we added a dimension
        if len(original_shape) == 2:
            hidden_states = hidden_states.squeeze(1)  # [batch, hidden_dim]

        # Return in the same format as input
        if rest_of_output is not None:
            return (hidden_states,) + rest_of_output
        else:
            return hidden_states

    return hook
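
# Minimal registration sketch, assuming a Llama-style model whose decoder blocks
# live at model.model.layers (the same layout stream_steered_answer_hf assumes):
#
#   hook_fn = create_steering_hook(10, components, clamp_intensity=True)
#   if hook_fn is not None:
#       handle = model.model.layers[10].register_forward_hook(hook_fn)
#       try:
#           _ = model(input_ids)
#       finally:
#           handle.remove()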


def stream_steered_answer_hf(model: AutoModelForCausalLM,
                             tokenizer: AutoTokenizer,
                             chat,
                             steering_components,
                             max_new_tokens=128,
                             temperature=0.0,
                             repetition_penalty=1.0,
                             clamp_intensity=False,
                             stream=True):
    """
    Generate a steered answer using pure HuggingFace Transformers with streaming.

    Args:
        model: HuggingFace transformers model
        tokenizer: Tokenizer instance
        chat: Chat history in OpenAI format
        steering_components: List of dicts with 'layer', 'strength', 'vector'
        max_new_tokens: Maximum number of tokens to generate
        temperature: Sampling temperature (0 = greedy decoding)
        repetition_penalty: Repetition penalty
        clamp_intensity: Whether to clamp steering intensity
        stream: Unused in this implementation; output is always streamed

    Yields:
        The accumulated generated text, updated as each new token arrives
    """

    input_ids = tokenizer.apply_chat_template(
        chat, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Register steering hooks
    hook_handles = []
    layers_to_steer = set(sc['layer'] for sc in steering_components)

    for layer_idx in layers_to_steer:
        hook_fn = create_steering_hook(layer_idx, steering_components, clamp_intensity)
        if hook_fn:
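            # Assumes a Llama-style module tree (decoder blocks at model.model.layers)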
            layer_module = model.model.layers[layer_idx]
            handle = layer_module.register_forward_hook(hook_fn)
            hook_handles.append(handle)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature if temperature > 0 else 1.0,
        "do_sample": temperature > 0,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
        "pad_token_id": tokenizer.eos_token_id,
    }

    thread = Thread(target=lambda: model.generate(**generation_kwargs))
    thread.start()

    try:
        generated_text = ""
        for token_text in streamer:
            generated_text += token_text
            yield generated_text
    finally:
        thread.join()
        # Remove the hooks even if the consumer stops iterating the generator
        # early or generation fails, so steering never leaks into later calls
        for handle in hook_handles:
            handle.remove()
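

if __name__ == "__main__":
    # Illustrative smoke test; the model name, vectors file, and feature numbers
    # below are placeholders, not values prescribed by this module.
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.bfloat16, device_map="auto"
    )

    cfg = {"features": [[10, 4242, 8.0]]}
    components = load_saes_from_file("steering_vectors.pt", cfg, model.device)

    chat = [{"role": "user", "content": "Tell me about yourself."}]
    answer = ""
    for answer in stream_steered_answer_hf(model, tokenizer, chat, components,
                                           max_new_tokens=64):
        pass  # each yielded value is the full text generated so far
    print(answer)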