Update modeling_llama.py
modeling_llama.py  +55 -21  CHANGED
@@ -32,19 +32,52 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_clex import CLEXLlamaConfig
 from .clex_layer import LlamaCLEXScalingRotaryEmbedding
-
-
 from einops import rearrange
-
-
-from flash_attn.bert_padding import unpad_input, pad_input
+import importlib.metadata
+import importlib.util


 logger = logging.get_logger(__name__)

+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+    # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    package_version = "N/A"
+    if package_exists:
+        try:
+            package_version = importlib.metadata.version(pkg_name)
+            package_exists = True
+        except importlib.metadata.PackageNotFoundError:
+            package_exists = False
+    logger.info(f"Detected {pkg_name} version {package_version}")
+    if return_version:
+        return package_exists, package_version
+    else:
+        return package_exists
+
+def is_flash_attn_available():
+    if not _is_package_available("torch", return_version=True):
+        return False
+
+    # Let's add an extra check to see if cuda is available
+    import torch
+
+    return _is_package_available("flash_attn") and torch.cuda.is_available()
+
+if is_flash_attn_available():
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
+    # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+    from flash_attn.bert_padding import unpad_input, pad_input
+
+
+
+
 _CONFIG_FOR_DOC = "CLEXLlamaConfig"


+
+
+
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
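The hunk above replaces the unconditional flash-attn imports with a runtime availability check, so the file still imports on machines without flash-attn or without CUDA. A minimal standalone sketch of the same guarded-import pattern; the helper name _package_available and the fallback to None are illustrative, not part of this file:

import importlib.metadata
import importlib.util

def _package_available(pkg_name: str) -> bool:
    # A package counts as available only if find_spec sees it *and* importlib.metadata
    # can report a version, which filters out a stray directory shadowing the name.
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        importlib.metadata.version(pkg_name)
        return True
    except importlib.metadata.PackageNotFoundError:
        return False

# Guarded import: fall back to None so callers can branch at run time.
if _package_available("flash_attn"):
    from flash_attn.bert_padding import pad_input, unpad_input  # noqa: F401
else:
    pad_input = unpad_input = None

print("flash_attn available:", _package_available("flash_attn"))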
@@ -137,13 +170,13 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
+def apply_rotary_pos_emb(q, k, cos, sin, q_len, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    q_embed = (q * cos) + (rotate_half(q) * sin)
+    q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :])
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

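In the new apply_rotary_pos_emb signature, cos and sin are gathered for every key position (kv_seq_len of them), while the query only covers the last q_len positions, hence the [:, :, -q_len:, :] slice on the query side. A shape-only sketch of that slicing, with illustrative tensor sizes:

import torch

bs, n_heads, head_dim = 2, 4, 8
kv_seq_len, q_len = 16, 1          # decode step: 15 cached positions + 1 new token

q = torch.randn(bs, n_heads, q_len, head_dim)
k = torch.randn(bs, n_heads, kv_seq_len, head_dim)

# cos/sin indexed for every key position: [bs, 1, kv_seq_len, head_dim]
cos = torch.randn(bs, 1, kv_seq_len, head_dim)
sin = torch.randn(bs, 1, kv_seq_len, head_dim)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Keys are rotated with the full table; queries only with the trailing q_len rows.
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = (q * cos[:, :, -q_len:, :]) + (rotate_half(q) * sin[:, :, -q_len:, :])
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 1, 8]) torch.Size([2, 4, 16, 8])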
@@ -247,19 +280,17 @@ class LlamaAttention(nn.Module):
         value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        # [bsz, nh, t, hd]

         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
             key_states = torch.cat([past_key_value[0], key_states], dim=2)

         if pack_cos_sin is not None:
-            cos, sin = pack_cos_sin
+            cos, sin = pack_cos_sin.to(query_states.device)
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
+        key_position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=position_ids.device).unsqueeze(0).view(-1, kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, q_len, key_position_ids)

         if past_key_value is not None:
             # reuse k, v, self_attention
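Because cached keys are concatenated before the rotary embedding is applied, the position ids passed to apply_rotary_pos_emb must span the full cache length rather than just the current query. A small sketch of what that index tensor looks like during a single decode step (values are illustrative, no model state involved):

import torch

past_len, q_len = 15, 1                 # 15 cached tokens, 1 new token
kv_seq_len = past_len + q_len

# Positions for every key in the (cache + current) sequence: shape [1, kv_seq_len]
key_position_ids = torch.arange(kv_seq_len, dtype=torch.long).unsqueeze(0).view(-1, kv_seq_len)
print(key_position_ids)                 # tensor([[ 0,  1, ..., 15]])

# The current query occupies the last q_len of those positions.
query_position_ids = key_position_ids[:, -q_len:]
print(query_position_ids)               # tensor([[15]])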
@@ -267,12 +298,13 @@

         past_key_value = (key_states, value_states) if use_cache else None

+        use_flashatn = self.config.use_flashattn and is_flash_attn_available()

         if self.log_scale:
             log_n = torch.log(torch.tensor(kv_seq_len*1.0)).to(query_states.device, dtype=query_states.dtype) / \
                 torch.log(torch.tensor(self.config.max_position_embeddings)).to(query_states.device, dtype=query_states.dtype)
             query_states = query_states * log_n
-        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or
+        if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2] or use_flashatn:
             attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

             if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
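The log_scale branch multiplies the queries by log(kv_seq_len) / log(max_position_embeddings), a log-n scaling factor that equals 1.0 at the training length and grows slowly beyond it. A quick numeric sketch, assuming 4096 as an example training context length:

import math

max_position_embeddings = 4096          # example value, not read from any config here

for kv_seq_len in (1024, 4096, 16384, 65536):
    log_n = math.log(kv_seq_len) / math.log(max_position_embeddings)
    print(kv_seq_len, round(log_n, 3))
# 1024 0.833, 4096 1.0, 16384 1.167, 65536 1.333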
@@ -308,6 +340,7 @@ class LlamaAttention(nn.Module):
             attn_weights = None

             return attn_output, attn_weights, past_key_value
+        # use flash attention
        elif past_key_value is not None:
            output = flash_attn_with_kvcache(
                query_states.transpose(1, 2),
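The elif branch above hands the single-step query and the cached key/value tensors to flash_attn_with_kvcache. As a reference for what that step computes, here is a plain-PyTorch sketch of one-token attention against a cache; shapes are illustrative and this is not the flash-attn API itself:

import math
import torch

bs, n_heads, head_dim = 2, 4, 8
past_len = 15

q       = torch.randn(bs, n_heads, 1, head_dim)             # one new query position
k_cache = torch.randn(bs, n_heads, past_len + 1, head_dim)  # cache already extended with the new key
v_cache = torch.randn(bs, n_heads, past_len + 1, head_dim)

# Single-row attention: the new token may attend to every cached position,
# so no causal mask is needed when q_len is 1.
scores = torch.matmul(q, k_cache.transpose(2, 3)) / math.sqrt(head_dim)
probs = torch.softmax(scores, dim=-1)
out = torch.matmul(probs, v_cache)                          # [bs, n_heads, 1, head_dim]
print(out.shape)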
@@ -614,13 +647,15 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
         # embed positions
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
-            )
-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+        # if attention_mask is None:
+        #     attention_mask = torch.ones(
+        #         (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+        #     )
+        # attention_mask = self._prepare_decoder_attention_mask(
+        #     attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        # )
+        attention_mask = None
+

         hidden_states = inputs_embeds

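Setting attention_mask = None relies on causality being enforced inside the attention kernel rather than through an explicit additive mask. As an analogy only (this file does not call it), PyTorch's built-in scaled dot-product attention shows the same idea with is_causal=True:

import torch
import torch.nn.functional as F

bs, n_heads, seq_len, head_dim = 1, 2, 5, 8
q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn(bs, n_heads, seq_len, head_dim)
v = torch.randn(bs, n_heads, seq_len, head_dim)

# No explicit [bs, 1, seq_len, seq_len] additive mask is built;
# the kernel applies the causal constraint internally.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 5, 8])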
@@ -802,7 +837,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
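For context, the loss computed just above is the standard next-token objective: logits are shifted left, labels are shifted right, and cross-entropy is taken over the flattened pairs. A minimal self-contained sketch with made-up sizes:

import torch
from torch.nn import CrossEntropyLoss

vocab_size, bs, seq_len = 11, 2, 6
logits = torch.randn(bs, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bs, seq_len))

# Predict token t+1 from position t: drop the last logit and the first label.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())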