---
{}
---
```
model: opt-125m
config: ModuleFqnToConfig
with Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig and IntxWeightOnlyConfig
config version: 1
torchao version: 0.14.0.dev
```
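
The quantization settings above are stored in the checkpoint's `config.json`, so they can be inspected without downloading the weights. A minimal sketch, assuming the quantization metadata is serialized into `config.json` the way transformers does for TorchAo-quantized checkpoints (the repo name matches the one pushed below):

```
from transformers import AutoConfig

# only config.json is read here; no weights are loaded
cfg = AutoConfig.from_pretrained(
    "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev"
)
# serialized TorchAoConfig / ModuleFqnToConfig mapping
print(cfg.quantization_config)
```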

# Generate Quantized Model
```
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)

# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)

model_id = "facebook/opt-125m"

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    ModuleFqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)

float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))

qconfig_dict = {
    # highest priority
    "model.decoder.layers.3.self_attn.q_proj": int4wo,
    "model.decoder.layers.3.self_attn.k_proj": int4wo,
    "model.decoder.layers.3.self_attn.v_proj": int4wo,
    # vllm
    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,

    "re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
    "re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
    "re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
    # this should not take effect, and out_proj will fall back to `_default`,
    # since the pattern does not fully match the FQN (missing the trailing `j`)
    "re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
    # vllm
    "re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,

    "_default": intxwo,
}
quant_config = ModuleFqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)

# Push to hub
MODEL_NAME = model_id.split("/")[-1]
save_to = f"torchao-testing/{MODEL_NAME}-ModuleFqnToConfig-v1-regex-0.14.0.dev"
quantized_model.push_to_hub(save_to, safe_serialization=False)
tokenizer.push_to_hub(save_to)
# quantized_model.save_pretrained(save_to, safe_serialization=False)
# tokenizer.save_pretrained(save_to)


# Manual Testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to("cuda")
# set temperature to 0 so the generation is deterministic
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)

correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt) :])


# Load model from saved checkpoint
reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    # quantization_config=quantization_config,
)

generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt) :])

assert correct_output_text == output_text
```
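
The per-layer assertions above follow from the order in which `ModuleFqnToConfig` resolves a module's config: an exact module-FQN entry takes priority, then `re:`-prefixed patterns that fully match the FQN, and anything left unmatched falls back to `_default`. A minimal sketch of that resolution order (`resolve_config` is an illustrative helper, not the torchao implementation):

```
import re

def resolve_config(fqn, qconfig_dict):
    # 1. an exact module-FQN entry wins
    if fqn in qconfig_dict:
        return qconfig_dict[fqn]
    # 2. "re:"-prefixed keys are regexes and must fully match the FQN
    for key, cfg in qconfig_dict.items():
        if key.startswith("re:") and re.fullmatch(key[len("re:"):], fqn):
            return cfg
    # 3. everything else falls back to "_default"
    return qconfig_dict.get("_default")

# layer 3 q/k/v_proj hit the exact-FQN entries -> Int4WeightOnlyConfig
# other layers' q/k/v_proj hit the regex entries -> Float8DynamicActivationFloat8WeightConfig
# out_proj matches nothing (the regex drops the trailing "j") -> _default (IntxWeightOnlyConfig)
```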


# Test Loading
```
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import (
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)
import torch

model_name = "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev"
device = "cuda"
input_text = "What are we having for dinner?"
max_new_tokens = 10

quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    dtype=torch.bfloat16,
)
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)

tokenizer = AutoTokenizer.from_pretrained(model_name)

input_ids = tokenizer(input_text, return_tensors="pt").to(device)

output = quantized_model.generate(**input_ids, max_new_tokens=max_new_tokens)
print(tokenizer.decode(output[0], skip_special_tokens=True))

```

Output:

```
What are we having for dinner?
A nice dinner with a friend.
I
```
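
The `qkv_proj` entries in `qconfig_dict` are there because vLLM fuses the q/k/v projections into a single `qkv_proj` module. A minimal sketch of serving the same checkpoint with vLLM, assuming a vLLM build that can load torchao-quantized checkpoints:

```
from vllm import LLM, SamplingParams

# vLLM picks up the quantization settings from the checkpoint's config.json
llm = LLM(model="torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev")
outputs = llm.generate(
    ["What are we having for dinner?"],
    SamplingParams(temperature=0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```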