```
model: opt-125m
config: Int8DynamicActivationIntxWeightConfig
config version: 2
torchao version: 0.14.dev
```
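The script below shows how the checkpoint was produced and verified: it quantizes `facebook/opt-125m` with the config above, pushes the quantized model and tokenizer to the Hub, and checks that the reloaded checkpoint generates the same output as the freshly quantized model.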
```
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.quantization import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import PerGroup

# Configure logging to see warnings and debug information
logging.basicConfig(
    level=logging.INFO, format="%(name)s - %(levelname)s - %(message)s"
)

# Enable specific loggers that might contain the serialization warnings
logging.getLogger("transformers").setLevel(logging.INFO)
logging.getLogger("torchao").setLevel(logging.INFO)
logging.getLogger("safetensors").setLevel(logging.INFO)
logging.getLogger("huggingface_hub").setLevel(logging.INFO)
model_id = "facebook/opt-125m"

version = 2
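
# Int8 dynamic activation quantization with 4-bit weights, per-group granularity
# (group size 32); config version 2, as noted above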
quant_config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    weight_granularity=PerGroup(32),
    version=version,
)
quantization_config = TorchAoConfig(quant_type=quant_config)
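
# Weights are quantized on the fly while the base checkpoint is loaded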
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Push to hub
MODEL_NAME = model_id.split("/")[-1]
save_to = f"torchao-testing/{MODEL_NAME}-Int8DynamicActivationIntxWeightConfig-v{version}-0.14.0.dev-safetensors"
quantized_model.push_to_hub(save_to, safe_serialization=True)
tokenizer.push_to_hub(save_to)

# Manual Testing
prompt = "What are we having for dinner?"
print("Prompt:", prompt)
inputs = tokenizer(
    prompt,
    return_tensors="pt",
).to("cuda")
# Setting temperature to 0 to make sure the result is deterministic
generated_ids = quantized_model.generate(**inputs, max_new_tokens=128, temperature=0)
correct_output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", correct_output_text[0][len(prompt) :])

# Load model from saved checkpoint
reloaded_model = AutoModelForCausalLM.from_pretrained(
    save_to,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
generated_ids = reloaded_model.generate(**inputs, max_new_tokens=128, temperature=0)
output_text = tokenizer.batch_decode(
    generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print("Response:", output_text[0][len(prompt) :])
assert correct_output_text == output_text
```
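
For reference, here is a minimal sketch of how an end user might load the published checkpoint for inference. The repo id is the `save_to` value from the script above, and the snippet assumes `torchao` is installed in the environment.

```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Repo id produced by the script above
repo_id = "torchao-testing/opt-125m-Int8DynamicActivationIntxWeightConfig-v2-0.14.0.dev-safetensors"

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(repo_id)

# Generate a short completion to sanity-check the loaded checkpoint
inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```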