fix(modeling_phi): Fixes cached generation when above maximum context length.
modeling_phi.py  CHANGED  (+13 -22)
@@ -170,11 +170,11 @@ def _apply_rotary_emb_qkv(
 
 class RotaryEmbedding(nn.Module):
     """Rotary positional embedding (RoPE).
-
+
     Reference:
        RoFormer: Enhanced Transformer with Rotary Position Embedding.
        https://arxiv.org/pdf/2104.09864.pdf.
-
+
     """
 
     def __init__(
@@ -261,32 +261,30 @@ class RotaryEmbedding(nn.Module):
         seqlen_offset: int = 0,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        seq_start = seqlen_offset
-        seq_end = seq_start + qkv.shape[1]
-
         if (
-            self._cos_cached.device != qkv.device
+            self._seq_len_cached < qkv.shape[1] + seqlen_offset
+            or self._cos_cached.device != qkv.device
             or self._cos_cached.dtype != qkv.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
-            self._update_cos_sin_cache(self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype)
+            self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
         else:
             q = _apply_rotary_emb(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
             kv = _apply_rotary_emb_kv(
                 kv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
 
             return q, kv
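The rotary hunk above replaces the fixed-size cos/sin lookup (sliced with the old seq_start/seq_end bounds) with a cache that is regrown whenever qkv.shape[1] + seqlen_offset outruns _seq_len_cached, and that is then sliced from seqlen_offset so cached tokens keep their absolute positions. Below is a minimal, self-contained sketch of that behaviour, assuming nothing beyond PyTorch; ToyRotaryCache and its get() method are invented here for illustration and are not the model's RotaryEmbedding.

import torch

class ToyRotaryCache:
    """Illustrative stand-in for the cos/sin caching logic above."""

    def __init__(self, dim: int, base: float = 10000.0):
        self.dim = dim
        self.base = base
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen: int, device=None, dtype=None) -> None:
        # Rebuild the tables up to `seqlen` positions.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached, self._sin_cached = freqs.cos().to(dtype), freqs.sin().to(dtype)
        self._seq_len_cached = seqlen

    def get(self, seqlen: int, seqlen_offset: int, device, dtype):
        # Mirrors the length check added in the patch: grow the cache once the
        # last requested position exceeds what has been cached so far.
        if self._seq_len_cached < seqlen + seqlen_offset:
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=device, dtype=dtype)
        # Slice from the offset so previously cached tokens keep their positions.
        return self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]

cache = ToyRotaryCache(dim=32)
cos, sin = cache.get(seqlen=1, seqlen_offset=2048, device="cpu", dtype=torch.float32)
print(cos.shape)  # torch.Size([1, 16]): one new position past a 2048-token prefix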
@@ -498,9 +496,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
     sequence_end = sequence_start + kv.shape[1]
 
     # When the current sequence length is equal to or larger than the maximum sequence length,
-    # we need to
+    # we need to concatenate the current `kv` with the cached `kv` to expand its length
     if sequence_end >= inference_params.max_seqlen:
-        inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx]
+        inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
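The _update_kv_cache hunk handles the companion problem on the attention side: the preallocated KV buffer only holds max_seqlen positions, so once sequence_end reaches it, the cache is grown by concatenating the incoming kv along the sequence dimension before the usual in-place write. A toy sketch with plain tensors follows; the shapes and the absence of InferenceParams are simplifications for illustration.

import torch

batch, max_seqlen, heads, head_dim = 1, 4, 2, 8

# Preallocated buffer with batch and sequence as the leading dimensions,
# matching the slicing in the hunk above (the trailing dims are made up).
cache = torch.zeros(batch, max_seqlen, 2, heads, head_dim)

seqlen_offset = 4                                # tokens already written == max_seqlen
kv = torch.randn(batch, 1, 2, heads, head_dim)   # one new decoding step

sequence_start = seqlen_offset
sequence_end = sequence_start + kv.shape[1]      # 5, i.e. past the end of the buffer

if sequence_end >= max_seqlen:
    # Without growing the buffer, the slice below would be empty and the
    # in-place assignment would fail with a size mismatch.
    cache = torch.cat((cache, kv), dim=1)

cache[:, sequence_start:sequence_end] = kv
print(cache.shape)  # torch.Size([1, 5, 2, 2, 8])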
@@ -864,13 +862,6 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
-        # the maximum sequence length
-        if input_ids.shape[1] > self.config.n_positions:
-            input_ids = input_ids[:, -self.config.n_positions :]
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, -self.config.n_positions :]
-
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
                 max_seqlen=self.config.n_positions,
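Together with the removal of the input truncation in prepare_inputs_for_generation, the intent is that cached generation can simply continue past config.n_positions. A usage sketch, assuming a Phi checkpoint that ships this modeling_phi.py as remote code; the model id and the lengths below are placeholders rather than part of the commit.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "microsoft/phi-1_5"  # assumption: any checkpoint that uses this custom modeling file
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")

# With the fix, the rotary tables and the KV cache keep growing past
# config.n_positions instead of the prompt being truncated, so a generation
# that crosses the context limit no longer fails mid-stream.
outputs = model.generate(**inputs, max_length=model.config.n_positions + 64, use_cache=True)
print(tokenizer.decode(outputs[0]))

Whether the model still produces sensible text that far beyond its trained context is a separate question; the commit only fixes the caching mechanics.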