---
{}
---
```
model: opt-125m
config: ModuleFqnToConfig
with Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig and IntxWeightOnlyConfig
config version: 1
torchao version: 0.14.0.dev
```
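The checkpoint is produced with `ModuleFqnToConfig`, which maps module fully-qualified names (FQNs), either exact names or `re:`-prefixed regex patterns, to per-module quantization configs. As the comments in the script below indicate, an exact FQN entry takes priority over regex entries, a regex only applies when it fully matches the FQN, and anything unmatched falls back to `_default`. A minimal illustrative sketch of that resolution order (not torchao's actual implementation):
```
import re

def resolve_config(fqn, qconfig_dict):
    # exact FQN entries have the highest priority
    if fqn in qconfig_dict:
        return qconfig_dict[fqn]
    # `re:`-prefixed patterns apply only on a full match of the FQN
    for key, config in qconfig_dict.items():
        if key.startswith("re:") and re.fullmatch(key[len("re:"):], fqn):
            return config
    # everything else falls back to the default config
    return qconfig_dict.get("_default")

# toy entries mirroring the real qconfig_dict below
toy = {
    "model.decoder.layers.3.self_attn.q_proj": "int4wo",
    r"re:model\.decoder\.layers\..+\.self_attn\.q_proj": "float8dyn",
    r"re:model\.decoder\.layers\..+\.self_attn\.out_pro": "float8dyn",  # not a full match for out_proj
    "_default": "intxwo",
}

print(resolve_config("model.decoder.layers.3.self_attn.q_proj", toy))    # int4wo (exact match wins)
print(resolve_config("model.decoder.layers.5.self_attn.q_proj", toy))    # float8dyn (regex match)
print(resolve_config("model.decoder.layers.5.self_attn.out_proj", toy))  # intxwo (regex is not a full match)
```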
# Generate Quantized Model
```
import logging
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)
# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)
model_id = "facebook/opt-125m"
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Int4WeightOnlyConfig,
    IntxWeightOnlyConfig,
    PerRow,
    PerAxis,
    ModuleFqnToConfig,
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)
float8dyn = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
int4wo = Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d")
intxwo = IntxWeightOnlyConfig(weight_dtype=torch.int8, granularity=PerAxis(0))
qconfig_dict = {
    # exact FQN entries, highest priority
    "model.decoder.layers.3.self_attn.q_proj": int4wo,
    "model.decoder.layers.3.self_attn.k_proj": int4wo,
    "model.decoder.layers.3.self_attn.v_proj": int4wo,
    # fused qkv module used by vLLM
    "model.decoder.layers.3.self_attn.qkv_proj": int4wo,
    r"re:model\.decoder\.layers\..+\.self_attn\.q_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.k_proj": float8dyn,
    r"re:model\.decoder\.layers\..+\.self_attn\.v_proj": float8dyn,
    # this entry should not take effect and out_proj falls back to _default,
    # since the regex is not a full match (missing the trailing `j`)
    r"re:model\.decoder\.layers\..+\.self_attn\.out_pro": float8dyn,
    # fused qkv module used by vLLM
    r"re:model\.decoder\.layers\..+\.self_attn\.qkv_proj": float8dyn,
    "_default": intxwo,
}
quant_config = ModuleFqnToConfig(qconfig_dict)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)
# Push to hub
MODEL_NAME = model_id.split("/")[-1]
save_to = f"torchao-testing/{MODEL_NAME}-ModuleFqnToConfig-v1-regex-0.14.0.dev"
quantized_model.push_to_hub(save_to, safe_serialization=False)
tokenizer.push_to_hub(save_to)
# quantized_model.save_pretrained(save_to, safe_serialization=False)
# tokenizer.save_pretrained(save_to)
# Manual Testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to("cuda")
# set temperature to 0 so the result is deterministic
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)
correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt) :])
# Load model from saved checkpoint
reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
    # quantization_config=quantization_config,
)
generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt) :])
assert correct_output_text == output_text
```
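When a quantized model is saved with `push_to_hub`/`save_pretrained`, transformers serializes the quantization settings into the checkpoint's `config.json`, which is why the reload above works without passing `quantization_config` again. A quick, hedged way to double-check what was stored (the exact attribute layout may vary across transformers versions):
```
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained(save_to)
# prints the serialized TorchAoConfig, including the per-module / regex entries
print(cfg.quantization_config)
```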
# Test Loading
```
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    TorchAoConfig,
)
from torchao.quantization import (
    Float8Tensor,
    Int4TilePackedTo4dTensor,
    IntxUnpackedToInt8Tensor,
)
import torch
model_name = "torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev"
device = "cuda"
input_text = "What are we having for dinner?"
max_new_tokens = 10
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    dtype=torch.bfloat16,
)
for i in range(12):
    if i == 3:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Int4TilePackedTo4dTensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Int4TilePackedTo4dTensor)
    else:
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.q_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.k_proj.weight, Float8Tensor)
        assert isinstance(quantized_model.model.decoder.layers[i].self_attn.v_proj.weight, Float8Tensor)
    assert isinstance(quantized_model.model.decoder.layers[i].self_attn.out_proj.weight, IntxUnpackedToInt8Tensor)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer(input_text, return_tensors="pt").to(device)
output = quantized_model.generate(**input_ids, max_new_tokens=max_new_tokens)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
Output:
```
What are we having for dinner?
A nice dinner with a friend.
I
```
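The `qkv_proj` entries in `qconfig_dict` exist for runtimes such as vLLM that fuse the q/k/v projections into a single module. A hedged sketch of serving the same checkpoint with vLLM (assuming a vLLM build that can load torchao-quantized checkpoints):
```
from vllm import LLM, SamplingParams

# assumption: this vLLM version can load the torchao-quantized checkpoint directly
llm = LLM(model="torchao-testing/opt-125m-ModuleFqnToConfig-v1-regex-0.14.0.dev")
params = SamplingParams(temperature=0, max_tokens=10)
outputs = llm.generate(["What are we having for dinner?"], params)
print(outputs[0].outputs[0].text)
```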