{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "900b542d-0249-453c-a915-a061b80af69f",
   "metadata": {},
   "source": [
    "# PyTorch AO (torchao) with int8_dynamic_activation_int8_weight\n",
    "\n",
    "LoRA fine-tuning of a torchao int8-quantized `google/gemma-2-2b` sequence classifier on the GLUE MRPC task."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10e1acc3-50b8-4d40-bdf3-0133c113cc4b",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a9935ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from peft import (\n",
    "    get_peft_config,\n",
    "    get_peft_model,\n",
    "    get_peft_model_state_dict,\n",
    "    set_peft_model_state_dict,\n",
    "    LoraConfig,\n",
    "    PeftType,\n",
    "    PrefixTuningConfig,\n",
    "    PromptEncoderConfig,\n",
    ")\n",
    "\n",
    "import evaluate\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, TorchAoConfig, get_linear_schedule_with_warmup, set_seed\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eafdd532-b1eb-4aac-8077-3386a84c7cdb",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b13308",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 16\n",
    "model_name_or_path = \"google/gemma-2-2b\"\n",
    "task = \"mrpc\"\n",
    "num_epochs = 5\n",
    "lr = 2e-5\n",
    "seed = 42\n",
    "\n",
    "lora_rank = 16\n",
    "lora_alpha = 32\n",
    "lora_dropout = 0.1\n",
    "\n",
    "# `torch.accelerator.current_accelerator()` returns None when no accelerator is present,\n",
    "# so check availability before reading `.type`.\n",
    "if hasattr(torch, \"accelerator\") and torch.accelerator.is_available():\n",
    "    device = torch.accelerator.current_accelerator().type\n",
    "else:\n",
    "    # PyTorch without `torch.accelerator`, or no accelerator on this machine.\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Seed the RNGs up front so data shuffling, LoRA initialisation and dropout are reproducible.\n",
    "set_seed(seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7fb69bf-0182-4111-b715-e2e659b42b1d",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d2f4d25e-30b9-431f-95c3-adb390dc6fcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints whose name contains \"gpt\", \"opt\" or \"bloom\" are padded on the left, all others on the right.\n",
    "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
    "    padding_side = \"left\"\n",
    "else:\n",
    "    padding_side = \"right\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
    "# Reuse the EOS token for padding when the tokenizer does not define a pad token.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "datasets = load_dataset(\"glue\", task)\n",
    "metric = evaluate.load(\"glue\", task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1ea852bc-a040-4244-8fd3-516307cecd14",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize the `sentence1`/`sentence2` pairs of a batch with truncation enabled.\"\"\"\n",
    "    # max_length=None => use the model max length (it's actually the default)\n",
    "    outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cf5ef289-f42f-4582-bd5e-9852ad8beff2",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = datasets.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
    ")\n",
    "\n",
    "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
    "# transformers library\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "739b3655-9db0-48bc-8542-308c6d5e0b8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(examples):\n",
    "    \"\"\"Pad a list of tokenized examples to the longest one and return PyTorch tensors.\"\"\"\n",
    "    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0288f311-8475-4a0e-99af-e4b909d10e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate dataloaders.\n",
    "train_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"train\"],\n",
    "    shuffle=True,\n",
    "    collate_fn=collate_fn,\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"],\n",
    "    shuffle=False,\n",
    "    collate_fn=collate_fn,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcaf6f9e-c9d1-445a-9f08-18ef462f67ce",
"metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 8, "id": "e5dfff56-ea80-4561-aeaf-43216bbb9af7", "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2ac42f98e60d412496fe77ed7eb5c6df", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/3 [00:00, weight=AffineQuantizedTensor(shape=torch.Size([2048, 2304]), block_size=(1, 2304), device=cuda:0, layout_type=PlainLayoutType(), layout_tensor_dtype=torch.int8, quant_min=None, quant_max=None)))\n", " (lora_dropout): ModuleDict(\n", " (default): Dropout(p=0.1, inplace=False)\n", " )\n", " (lora_A): ModuleDict(\n", " (default): Linear(in_features=2304, out_features=16, bias=False)\n", " )\n", " (lora_B): ModuleDict(\n", " (default): Linear(in_features=16, out_features=2048, bias=False)\n", " )\n", " (lora_embedding_A): ParameterDict()\n", " (lora_embedding_B): ParameterDict()\n", " (lora_magnitude_vector): ModuleDict()\n", " )\n", " (k_proj): Linear(in_features=2304, out_features=1024, weight=LinearActivationQuantizedTensor(activation=, weight=AffineQuantizedTensor(shape=torch.Size([1024, 2304]), block_size=(1, 2304), device=cuda:0, layout_type=PlainLayoutType(), layout_tensor_dtype=torch.int8, quant_min=None, quant_max=None)))\n", " (v_proj): lora.TorchaoLoraLinear(\n", " (base_layer): Linear(in_features=2304, out_features=1024, weight=LinearActivationQuantizedTensor(activation=, weight=AffineQuantizedTensor(shape=torch.Size([1024, 2304]), block_size=(1, 2304), device=cuda:0, layout_type=PlainLayoutType(), layout_tensor_dtype=torch.int8, quant_min=None, quant_max=None)))\n", " (lora_dropout): ModuleDict(\n", " (default): Dropout(p=0.1, inplace=False)\n", " )\n", " (lora_A): ModuleDict(\n", " (default): Linear(in_features=2304, out_features=16, bias=False)\n", " )\n", " (lora_B): ModuleDict(\n", " (default): Linear(in_features=16, out_features=1024, bias=False)\n", " )\n", " 
(lora_embedding_A): ParameterDict()\n", " (lora_embedding_B): ParameterDict()\n", " (lora_magnitude_vector): ModuleDict()\n", " )\n", " (o_proj): Linear(in_features=2048, out_features=2304, weight=LinearActivationQuantizedTensor(activation=, weight=AffineQuantizedTensor(shape=torch.Size([2304, 2048]), block_size=(1, 2048), device=cuda:0, layout_type=PlainLayoutType(), layout_tensor_dtype=torch.int8, quant_min=None, quant_max=None)))\n", " (rotary_emb): Gemma2RotaryEmbedding()\n", " )\n", " (mlp): Gemma2MLP(\n", " (gate_proj): Linear(in_features=2304, out_features=9216, weight=LinearActivationQuantizedTensor(activation=, weight=AffineQuantizedTensor(shape=torch.Size([9216, 2304]), block_size=(1, 2304), device=cuda:0, layout_type=PlainLayoutType(), layout_tensor_dtype=torch.int8, quant_min=None, quant_max=None)))\n", " (up_proj): Linear(in_features=2304, out_features=9216, weight=LinearActivationQuantizedTensor(activation=, weight=AffineQuantizedTensor(shape=torch.Size([9216, 2304]), block_size=(1, 2304), device=cuda:0, layout_type=PlainLayoutType(), layout_tensor_dtype=torch.int8, quant_min=None, quant_max=None)))\n", " (down_proj): Linear(in_features=9216, out_features=2304, weight=LinearActivationQuantizedTensor(activation=, weight=AffineQuantizedTensor(shape=torch.Size([2304, 9216]), block_size=(1, 9216), device=cuda:0, layout_type=PlainLayoutType(), layout_tensor_dtype=torch.int8, quant_min=None, quant_max=None)))\n", " (act_fn): PytorchGELUTanh()\n", " )\n", " (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", " (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", " (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", " (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n", " )\n", " )\n", " (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n", " )\n", " (score): ModulesToSaveWrapper(\n", " (original_module): Linear(in_features=2304, out_features=2, bias=False)\n", " (modules_to_save): ModuleDict(\n", " (default): 
Linear(in_features=2304, out_features=2, bias=False)\n", " )\n", " )\n", " )\n", " )\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.config.use_cache = False\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": 13, "id": "fa0e73be", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/230 [00:00