File size: 6,277 Bytes
f472b08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
@dataclass
class ScriptArguments(SFTConfig):
    """Command-line options for the PiSSA fine-tuning script.

    Extends ``SFTConfig`` so that a single parsed object carries both the
    trainer configuration and the script-specific model, adapter, and dataset
    options declared below.
    """

    # model configs
    base_model_name_or_path: Optional[str] = field(
        default=None, metadata={"help": "The name or path of the fp32/16 base model."}
    )
    residual_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name or path of the fp32/16 residual model. (`['fxmeng/pissa-llama-2-7b-r16-alpha-16']`)"
        },
    )
    bits: str = field(default="fp32", metadata={"help": "(`['fp4', 'nf4', 'int8', 'bf16', 'fp16', fp32]`)"})
    init_lora_weights: str = field(default="pissa", metadata={"help": "(`['gaussian', 'pissa', 'pissa_niter_4']`)"})
    # LoRA hyper-parameters; only read when the adapter is initialized from a base model.
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.0
    # Post-training export switches.
    convert_pissa_to_lora: bool = False
    merge_and_save: bool = False

    # dataset configs
    data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
    dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"})
    # Two column names: index 0 is the input (prompt), index 1 the output (response).
    # Defaults to None, so the annotation is Optional.
    dataset_field: Optional[list[str]] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
# Parse the command line into a ScriptArguments instance and echo it so the
# effective configuration shows up in the run log.
parser = HfArgumentParser(ScriptArguments)
script_args, *_ = parser.parse_args_into_dataclasses()
print(script_args)
# ---------------------------------------------------------------------------
# Build the trainable PEFT model (`peft_model`) and its `tokenizer`.
#
# Three setups are handled:
#   1. quantized bits (nf4 / fp4 / int8): load the residual model through
#      bitsandbytes, prepare it for k-bit training, attach the `pissa_init` adapter;
#   2. residual model given, non-quantized bits: load it in fp16 / bf16 / fp32
#      and attach the `pissa_init` adapter;
#   3. only a base model given: load it and initialize a new adapter from the
#      LoRA arguments.
# Any other combination raises a ValueError instead of failing later with an
# unrelated error.
# ---------------------------------------------------------------------------
print(f"Load pre-processed residual model in {script_args.bits} bits.")

residual_path = script_args.residual_model_name_or_path
is_quantized = script_args.bits in ("nf4", "fp4", "int8")
# Any value other than fp16 / bf16 selects fp32.
torch_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(script_args.bits, torch.float32)

if is_quantized and residual_path is None:
    raise ValueError(
        f"`--bits {script_args.bits}` requires `--residual_model_name_or_path`: "
        "quantized training starts from a pre-processed residual model."
    )

if residual_path is not None:
    if is_quantized:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=script_args.bits in ("nf4", "fp4"),
            load_in_8bit=script_args.bits == "int8",
            bnb_4bit_quant_type=script_args.bits,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        res_model = AutoModelForCausalLM.from_pretrained(
            residual_path, quantization_config=quantization_config, low_cpu_mem_usage=True
        )
        res_model = prepare_model_for_kbit_training(res_model)
    else:
        res_model = AutoModelForCausalLM.from_pretrained(
            residual_path,
            torch_dtype=torch_dtype,
            device_map="auto",
        )
    print("Wrapping the residual model with PiSSA.")
    # The adapter is read from the `pissa_init` subfolder of the residual model.
    peft_model = PeftModel.from_pretrained(res_model, residual_path, subfolder="pissa_init", is_trainable=True)
    tokenizer = AutoTokenizer.from_pretrained(residual_path)
elif script_args.base_model_name_or_path is not None:
    print(
        f"No available pre-processed model, manually initialize a PiSSA using {script_args.base_model_name_or_path}."
    )
    model = AutoModelForCausalLM.from_pretrained(
        script_args.base_model_name_or_path,
        torch_dtype=torch_dtype,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path)
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token_id = tokenizer.eos_token_id
    lora_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        init_lora_weights=script_args.init_lora_weights,
        lora_dropout=script_args.lora_dropout,
        target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(model, lora_config)
else:
    # Previously this case skipped every branch and crashed below with a
    # NameError on `peft_model`.
    raise ValueError(
        "No model to train: pass `--residual_model_name_or_path` (pre-processed residual model) "
        "or `--base_model_name_or_path` (model to initialize PiSSA from)."
    )

print(peft_model)
peft_model.print_trainable_parameters()
# Build the training dataset: each example gets a "text" column that combines
# the input and output columns named by `--dataset_field`.
#
# `dataset_field` defaults to None, so validate it before loading anything;
# otherwise the failure surfaces as a TypeError/IndexError inside `map`.
if script_args.dataset_field is None or len(script_args.dataset_field) < 2:
    raise ValueError(
        "`--dataset_field` must name two dataset columns: the input column followed by the output column "
        f"(got {script_args.dataset_field!r})."
    )
# Only the first two entries are used, matching the original indexing.
input_column, output_column = script_args.dataset_field[:2]

print(f"Training PiSSA with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.")
dataset = load_dataset(script_args.data_path, split=script_args.dataset_split)
dataset = dataset.map(
    lambda example: {"text": f"### USER: {example[input_column]}\n### ASSISTANT: {example[output_column]}"}
)
# Train the adapter, then save it (optionally converted to LoRA format and/or
# merged into the underlying model).
#
# The PiSSA -> LoRA conversion reads the initial adapter from
# `<residual_model_name_or_path>/pissa_init`. Check for the path before
# training: without it the script would only fail at save time, after the
# whole run has completed and before any adapter has been written.
if script_args.convert_pissa_to_lora and script_args.residual_model_name_or_path is None:
    raise ValueError(
        "`--convert_pissa_to_lora` requires `--residual_model_name_or_path`, "
        "whose `pissa_init` subfolder holds the initial adapter used for the conversion."
    )

trainer = SFTTrainer(
    model=peft_model,
    args=script_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
trainer.save_state()

############################## Upon training completion, convert and save PiSSA in LoRA format ##############################
if script_args.convert_pissa_to_lora:
    peft_model.save_pretrained(
        os.path.join(script_args.output_dir, "pissa_lora"),
        path_initial_model_for_weight_conversion=os.path.join(script_args.residual_model_name_or_path, "pissa_init"),
    )
else:
    # Keep the adapter in its fine-tuned PiSSA form.
    peft_model.save_pretrained(os.path.join(script_args.output_dir, "pissa_ft"))

if script_args.merge_and_save:
    # Merged model and tokenizer go to the same directory.
    merged_dir = os.path.join(script_args.output_dir, "pissa_merged")
    model = peft_model.merge_and_unload()
    model.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)
|