
# Trainable Tokens

The Trainable Tokens method provides a way to target specific token embeddings for fine-tuning without resorting to training the full embedding matrix or using an adapter on the embedding matrix. It is based on the initial implementation from here.

The method targets only the specific tokens you provide and selectively trains the embedding rows at those token indices. Consequently, the required RAM is lower, and the disk footprint is significantly smaller than storing a fully fine-tuned embedding matrix.
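The idea behind the savings can be illustrated with a minimal PyTorch sketch (this is a toy illustration, not PEFT's actual implementation): freeze the full embedding matrix and learn only a small delta for the selected token rows.

```python
import torch

# Toy sketch: freeze the full embedding matrix and train only a small
# delta for a handful of selected token rows.
vocab_size, dim = 1000, 64
base = torch.nn.Embedding(vocab_size, dim)
base.weight.requires_grad_(False)  # the full matrix stays frozen

token_indices = [7, 42, 999]  # the only rows we want to train
delta = torch.nn.Parameter(torch.zeros(len(token_indices), dim))

def forward(input_ids):
    out = base(input_ids)
    # Add the trainable delta wherever an input id matches a target token.
    for row, tok in enumerate(token_indices):
        mask = (input_ids == tok).unsqueeze(-1)  # broadcast over the embedding dim
        out = out + mask * delta[row]
    return out

# Only 3 * 64 = 192 parameters are trainable instead of 1000 * 64 = 64000,
# and only the 3 delta rows need to be stored on disk.
```

The same contrast applies on disk: a checkpoint needs to hold only `len(token_indices) * dim` values rather than the whole `vocab_size * dim` matrix.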

Some preliminary benchmarks, acquired with this script, suggest that for gemma-2-2b (which has a rather large embedding matrix) you can save ~4 GiB of VRAM with Trainable Tokens compared to fully fine-tuning the embedding matrix. While LoRA uses comparable amounts of VRAM, it might also modify tokens you don't want changed. Note that these are only indications, and differing embedding matrix sizes might skew these numbers somewhat.

Note that this method does not add tokens for you; you have to add new tokens to the tokenizer yourself and resize the model's embedding matrix accordingly. Trainable Tokens will then re-train only the embeddings of the tokens you specify. It can also be used in conjunction with LoRA layers! See the LoRA developer guide.

Saving the model with [`~PeftModel.save_pretrained`] or retrieving the state dict with [`get_peft_model_state_dict`] after adding new tokens may save the full embedding matrix instead of only the difference, as a precaution, because the embedding matrix was resized. To save space you can disable this behavior by setting `save_embedding_layers=False` when calling `save_pretrained`. This is safe to do as long as you don't modify the embedding matrix through other means as well, since such changes will not be tracked by Trainable Tokens.

## TrainableTokensConfig

[[autodoc]] tuners.trainable_tokens.config.TrainableTokensConfig

## TrainableTokensModel

[[autodoc]] tuners.trainable_tokens.model.TrainableTokensModel