import importlib.metadata

from packaging import version
from torch import nn
from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
)
from transformers.utils import logging
from transformers.utils.import_utils import _is_package_available

logger = logging.get_logger(__name__)


def is_transformers_attn_greater_or_equal_4_56_2():
    """Return True if the installed transformers version is >= 4.56.2."""
    if not _is_package_available("transformers"):
        return False
    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
        "4.56.2"
    )


class ModifiedLlamaAttention(LlamaAttention):
    """LlamaAttention with causal masking disabled, i.e. bidirectional attention."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False


class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    """LlamaDecoderLayer that swaps in ModifiedLlamaAttention.

    LlamaDecoderLayer.__init__ is intentionally skipped (only the
    GradientCheckpointingLayer / nn.Module initializer is called) so that the
    causal attention module it would build can be replaced with the
    bidirectional variant.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        GradientCheckpointingLayer.__init__(self)
        self.hidden_size = config.hidden_size

        self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )


class LlamaEncoderModel(LlamaModel):
    """LlamaModel rebuilt with bidirectional (non-causal) decoder layers.

    LlamaModel.__init__ is intentionally skipped (only
    LlamaPreTrainedModel.__init__ is called) so the layer stack can be built
    from ModifiedLlamaDecoderLayer; everything else mirrors modeling_llama.py
    of transformers >= 4.56.2.
    """

    def __init__(self, config):
        if not is_transformers_attn_greater_or_equal_4_56_2():
            raise ValueError(
                "The current implementation of LlamaEncoderModel follows "
                "modeling_llama.py of transformers version >= 4.56.2"
            )
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                ModifiedLlamaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        if not self._use_flash_attention_2:
            raise ValueError(
                "The current implementation of LlamaEncoderModel only supports "
                "flash_attention_2 as the attention implementation"
            )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
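

# Usage sketch (illustrative only, not part of the module's public API): how one
# might load a pretrained Llama checkpoint into the bidirectional encoder. The
# checkpoint name below is an example/assumption; any Llama-architecture
# checkpoint should work, assuming transformers >= 4.56.2, the flash-attn
# package, and a CUDA device are available.
if __name__ == "__main__":
    import torch

    model = LlamaEncoderModel.from_pretrained(
        "meta-llama/Llama-3.2-1B",  # example checkpoint (assumption)
        attn_implementation="flash_attention_2",  # required by this model
        torch_dtype=torch.bfloat16,  # flash attention 2 needs fp16/bf16
    ).to("cuda")
    model.eval()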