File size: 63,429 Bytes
b60e5b6 ec3b40a b60e5b6 ec3b40a b60e5b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 |
"""
Modified MIT License
Software Copyright© 2025 IQuest Research
Our only modification is that, if the Software (or any derivative works
thereof) is used for any of your commercial products or services, you shall
prominently display "IQuest Coder" on the user interface of such product or
service.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
import math
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import GenerationMixin
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_iquestloopcoder import IQuestLoopCoderConfig
# Module-level logger obtained from the transformers logging wrapper (imported above).
logger = logging.get_logger(__name__)
# Name of this model's config class as a string; presumably consumed by the
# docstring helpers imported above (e.g. replace_return_docstrings) further down
# the file — not visible here, confirm.
_CONFIG_FOR_DOC = "IQuestLoopCoderConfig"
class IQuestLoopCoderCache(Cache):
    """KV cache for IQuestLoopCoder holding a shared (global) and a local (windowed) store.

    Per layer it keeps:
    - ``shared_key_cache`` / ``shared_value_cache``: KV from Loop 1 (global context).
      Every update is appended along the sequence axis (dim 2).
    - ``local_key_cache`` / ``local_value_cache``: KV from Loop 2+. After every update
      only the most recent ``window_size`` tokens are retained (when
      ``window_size > 0``; a non-positive window disables trimming).

    Incoming tensors must agree with the cached ones on dims 0, 1 and 3 (validated
    below); the code treats them as ``[batch, heads, seq_len, head_dim]``.
    """

    def __init__(self, window_size: int, num_layers: int):
        # We intentionally don't call super().__init__ because the parent assumes static cache sizes.
        self.window_size = window_size
        self.num_layers = num_layers
        # Shared cache: stores Loop 1 KV (global context)
        self.shared_key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.shared_value_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        # Local cache: stores Loop 2+ KV (sliding window, at most window_size tokens)
        self.local_key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.local_value_cache: List[Optional[torch.Tensor]] = [None] * num_layers
        self.layers: List[Any] = []  # attribute expected by HF Cache utilities
        self._seen_tokens = 0

    def _validate_layer_idx(self, layer_idx: int) -> None:
        """Raise ``ValueError`` unless ``0 <= layer_idx < num_layers``."""
        if layer_idx < 0 or layer_idx >= self.num_layers:
            raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")

    @staticmethod
    def _validate_incoming(cached_key: torch.Tensor, key_states: torch.Tensor) -> None:
        """Raise ``ValueError`` if ``key_states`` cannot be appended to ``cached_key`` on dim 2."""
        if (
            key_states.shape[0] != cached_key.shape[0]
            or key_states.shape[1] != cached_key.shape[1]
            or key_states.shape[3] != cached_key.shape[3]
        ):
            raise ValueError(
                "Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
            )

    def update_shared(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append Loop 1 KV to the shared cache of ``layer_idx``.

        Args:
            key_states: new keys, concatenated to the cache along dim 2.
            value_states: new values, concatenated to the cache along dim 2.
            layer_idx: layer to update.
            cache_kwargs: accepted for HF ``Cache`` API compatibility; unused.

        Returns:
            The full cached ``(key, value)`` for the layer after the update.

        Raises:
            ValueError: if ``layer_idx`` is out of range or the shapes are incompatible.
        """
        self._validate_layer_idx(layer_idx)
        cached_key = self.shared_key_cache[layer_idx]
        cached_value = self.shared_value_cache[layer_idx]
        if cached_key is None:
            new_key, new_value = key_states, value_states
        else:
            self._validate_incoming(cached_key, key_states)
            assert cached_value is not None
            new_key = torch.cat([cached_key, key_states], dim=2)
            new_value = torch.cat([cached_value, value_states], dim=2)
        self.shared_key_cache[layer_idx] = new_key
        self.shared_value_cache[layer_idx] = new_value
        # Track sequence length (taken from the layer updated last).
        self._seen_tokens = new_key.shape[2]
        return new_key, new_value

    def update_local(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append Loop 2+ KV to the local cache and keep only the last ``window_size`` tokens.

        For single-token updates this is "drop the oldest token once the window is
        full". Because trimming happens after concatenation, multi-token updates
        (e.g. a prefill) also leave at most ``window_size`` tokens in the cache.

        Args:
            key_states: new keys, concatenated to the cache along dim 2.
            value_states: new values, concatenated to the cache along dim 2.
            layer_idx: layer to update.
            cache_kwargs: accepted for HF ``Cache`` API compatibility; unused.

        Returns:
            The cached ``(key, value)`` window for the layer after the update.

        Raises:
            ValueError: if ``layer_idx`` is out of range or the shapes are incompatible.
        """
        self._validate_layer_idx(layer_idx)
        cached_key = self.local_key_cache[layer_idx]
        cached_value = self.local_value_cache[layer_idx]
        if cached_key is None:
            new_key, new_value = key_states, value_states
        else:
            self._validate_incoming(cached_key, key_states)
            assert cached_value is not None
            new_key = torch.cat([cached_key, key_states], dim=2)
            new_value = torch.cat([cached_value, value_states], dim=2)
        # Enforce the window bound regardless of how many tokens arrived at once.
        if 0 < self.window_size < new_key.shape[2]:
            new_key = new_key[:, :, -self.window_size :, :]
            new_value = new_value[:, :, -self.window_size :, :]
        self.local_key_cache[layer_idx] = new_key
        self.local_value_cache[layer_idx] = new_value
        return new_key, new_value

    def get_shared(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the shared ``(key, value)`` of a layer, or ``(None, None)`` if out of range/empty."""
        if layer_idx < 0 or layer_idx >= self.num_layers:
            return None, None
        return self.shared_key_cache[layer_idx], self.shared_value_cache[layer_idx]

    def get_local(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return the local ``(key, value)`` of a layer, or ``(None, None)`` if out of range/empty."""
        if layer_idx < 0 or layer_idx >= self.num_layers:
            return None, None
        return self.local_key_cache[layer_idx], self.local_value_cache[layer_idx]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """HF ``Cache.update`` entry point; routes to the shared (Loop 1) cache."""
        return self.update_shared(key_states, value_states, layer_idx, cache_kwargs)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the shared-cache length (dim 2) of ``layer_idx`` (0 if missing/out of range)."""
        if layer_idx is None:
            layer_idx = 0
        if layer_idx < 0 or layer_idx >= len(self.shared_key_cache):
            return 0
        cached = self.shared_key_cache[layer_idx]
        if cached is None:
            return 0
        return cached.shape[2]

    def get_max_length(self) -> Optional[int]:
        """The cache is unbounded, so there is no maximum length."""
        return None

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """Return the shared-cache length; ``new_seq_length`` is ignored since the cache is unbounded."""
        return self.get_seq_length(layer_idx)

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder every cached tensor along the batch axis (dim 0) for beam search."""
        for layer_idx in range(self.num_layers):
            for key_store, value_store in (
                (self.shared_key_cache, self.shared_value_cache),
                (self.local_key_cache, self.local_value_cache),
            ):
                cached_key = key_store[layer_idx]
                if cached_key is None:
                    continue
                # Move the index once; key and value of a layer live on the same device.
                index = beam_idx.to(cached_key.device)
                key_store[layer_idx] = cached_key.index_select(0, index)
                value_store[layer_idx] = value_store[layer_idx].index_select(0, index)

    @property
    def is_compileable(self) -> bool:
        """Dynamic (growing) tensors make this cache unsuitable for torch.compile."""
        return False

    def clear(self) -> None:
        """Drop all cached tensors and reset the seen-token counter."""
        logger.debug("Clearing IQuestLoopCoderCache")
        self.shared_key_cache = [None] * self.num_layers
        self.shared_value_cache = [None] * self.num_layers
        self.local_key_cache = [None] * self.num_layers
        self.local_value_cache = [None] * self.num_layers
        self._seen_tokens = 0
class IQuestLoopCoderRMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale.

    The statistics are computed in float32 and the normalized tensor is cast back
    to the input dtype before the scale is applied. There is no mean-centering
    and no bias.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (as_fp32 * inv_rms).to(orig_dtype)
        return self.weight * normalized
class IQuestLoopCoderRotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE): produces cos/sin tables for given positions.

    Args:
        dim: rotary dimension (the per-head size).
        max_position_embeddings: stored as ``max_position_embeddings`` and
            ``max_seq_len_cached``; not used in the computation here.
        base: RoPE frequency base.
        device: device for the initial ``inv_freq`` buffer.
        scaling_factor: stored for API compatibility; not applied in this class.
    """

    def __init__(self, dim, max_position_embeddings=8192, base=500000.0, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = 1 / base ** (2i / dim), one entry per rotated pair.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` of shape ``[batch, seq_len, dim]`` in ``x.dtype``.

        Args:
            x: any tensor on the target device; only its device and dtype are used.
            position_ids: ``[batch, seq_len]`` positions.
        """
        inv_freq_expanded = (
            self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # Autocast is disabled so the angles are computed in float32. torch.autocast
        # rejects "mps" as a device type in several PyTorch releases, so use "cpu"
        # for the (disabled) context there.
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Split the last dimension into halves ``(a, b)`` and return ``concat(-b, a)``."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate ``q`` and ``k`` by the RoPE angles encoded in ``cos``/``sin``.

    ``cos``/``sin`` get a singleton axis inserted at ``unsqueeze_dim`` so they
    broadcast over the head axis. ``position_ids`` is accepted only for signature
    compatibility and is ignored.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head ``n_rep`` times along dim 1 so K/V match the query head count.

    The input must be 4-D ``[batch, kv_heads, seq_len, head_dim]``. When
    ``n_rep == 1`` the input object itself is returned.
    """
    bsz, n_kv, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        expanded = hidden_states.unsqueeze(2).expand(bsz, n_kv, n_rep, seq_len, dim)
        return expanded.reshape(bsz, n_kv * n_rep, seq_len, dim)
    return hidden_states
class IQuestLoopCoderMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    This is SwiGLU when ``config.hidden_act`` selects SiLU. All three projections
    share the ``config.mlp_bias`` setting.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class LoopGateProjection(nn.Module):
    """Per-head scalar gate for the Loop 2+ mixed attention.

    For head ``h`` the gate is ``sigmoid(q_h . w_h + b_h)``. It weighs the
    attention over Loop 1's (global) KV against the attention over the current
    loop's (local) KV.
    """

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        # One (head_dim -> 1) linear map per head, stored as a [num_heads, head_dim]
        # weight plus a [num_heads] bias. Both start at zero, so an untrained gate
        # outputs exactly 0.5.
        self.weight = nn.Parameter(torch.zeros(num_heads, head_dim))
        self.bias = nn.Parameter(torch.zeros(num_heads))

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        """Map ``query [batch, num_heads, seq_len, head_dim]`` to gates ``[batch, num_heads, seq_len, 1]``."""
        logits = torch.einsum("bhsd,hd->bhs", query, self.weight)
        logits = logits + self.bias.view(1, -1, 1)
        return torch.sigmoid(logits)[..., None]
class IQuestLoopCoderAttention(nn.Module):
    """Multi-head attention with grouped-query attention (GQA) support.

    Entry points:
    - ``forward``: standard causal self-attention, optionally through a ``Cache``.
    - ``get_qkv``: projections + RoPE only.
    - ``forward_with_external_kv``: this layer's queries attending over
      caller-supplied K/V (no output projection), with an optional sliding window.
    - ``forward_decode_loop1`` / ``forward_decode_loop2``: decode-stage helpers.
      Loop 2 mixes a global attention and a windowed local attention via a
      per-head gate.

    Sliding-window semantics (shared by every path): a query at absolute position
    ``i`` may attend to keys at positions ``[i - window + 1, i]``, i.e. ``window``
    keys including itself. When there are more keys than queries, the queries are
    taken to be the last ``q_len`` positions of the key sequence.
    """

    def __init__(self, config: IQuestLoopCoderConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each KV head (GQA group size).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.attention_dropout = config.attention_dropout
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self.rotary_emb = IQuestLoopCoderRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    # ----------------------------------------------------------------- helpers

    def _project_qkv(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project to Q/K/V, split heads and apply RoPE to Q and K.

        Returns:
            ``(query, key, value, cos, sin)``. ``query`` is
            ``[bsz, num_heads, q_len, head_dim]``; ``key`` and ``value`` are
            ``[bsz, num_kv_heads, q_len, head_dim]``.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        return query_states, key_states, value_states, cos, sin

    @staticmethod
    def _sliding_window_mask(q_len: int, kv_len: int, window: int, device: torch.device) -> torch.Tensor:
        """Boolean mask ``[1, 1, q_len, kv_len]``, True where a key is outside the causal window.

        Query row ``r`` sits at key position ``offset + r``, with
        ``offset = max(kv_len - q_len, 0)``: queries are the newest positions. When
        ``kv_len == q_len`` the offset is 0, which is the plain prefill layout.
        """
        offset = max(kv_len - q_len, 0)
        q_pos = torch.arange(q_len, device=device).unsqueeze(1) + offset
        k_pos = torch.arange(kv_len, device=device).unsqueeze(0)
        outside = (k_pos > q_pos) | (k_pos < q_pos - window + 1)
        return outside.unsqueeze(0).unsqueeze(0)

    def _attend(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        mask_offset: int = 0,
        sliding_window: Optional[int] = None,
        apply_dropout: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scaled dot-product attention of RoPE'd queries over (un-repeated) K/V.

        Args:
            query_states: ``[bsz, num_heads, q_len, head_dim]``.
            key_states / value_states: ``[bsz, num_kv_heads, kv_len, head_dim]``;
                they are expanded to ``num_heads`` here.
            attention_mask: additive 4-D mask. Columns
                ``[mask_offset, mask_offset + kv_len)`` are used, so that the
                columns line up with the key positions.
            mask_offset: first mask column corresponding to ``key_states[:, :, 0]``.
            sliding_window: if set, also mask keys outside the causal window.
            apply_dropout: apply attention dropout (active only in training mode).

        Returns:
            ``(attn_output [bsz, num_heads, q_len, head_dim], attn_weights)``.
        """
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        q_len = query_states.shape[-2]
        kv_len = key_states.shape[-2]
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask[:, :, :, mask_offset : mask_offset + kv_len]
        # A window mask only matters when some query can see more than `window` keys.
        if sliding_window is not None and (q_len > sliding_window or kv_len > sliding_window):
            window_mask = self._sliding_window_mask(q_len, kv_len, sliding_window, query_states.device)
            attn_weights = attn_weights.masked_fill(window_mask, float("-inf"))
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        if apply_dropout:
            attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        return attn_output, attn_weights

    @staticmethod
    def _merge_heads(attn_output: torch.Tensor) -> torch.Tensor:
        """``[bsz, heads, q_len, head_dim]`` -> ``[bsz, q_len, heads * head_dim]``."""
        bsz, _, q_len, _ = attn_output.shape
        return attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)

    # ------------------------------------------------------------- entry points

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Standard causal self-attention.

        If ``past_key_value`` is given, the new K/V are passed to its ``update``
        and attention runs over the returned (cached + new) K/V.

        Returns:
            ``(output [bsz, q_len, hidden_size], attn_weights or None, past_key_value)``.
        """
        query_states, key_states, value_states, cos, sin = self._project_qkv(hidden_states, position_ids)
        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        attn_output, attn_weights = self._attend(
            query_states, key_states, value_states, attention_mask=attention_mask, apply_dropout=True
        )
        attn_output = self.o_proj(self._merge_heads(attn_output))
        return attn_output, attn_weights if output_attentions else None, past_key_value

    def forward_with_external_kv(
        self,
        hidden_states: torch.Tensor,
        external_key: torch.Tensor,
        external_value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        sliding_window: Optional[int] = None,
    ) -> torch.Tensor:
        """Forward pass using external K, V (for Loop 2+ mixed attention).

        Args:
            hidden_states: Input for computing Q.
            external_key: Pre-computed K (already with RoPE applied).
            external_value: Pre-computed V.
            attention_mask: Additive causal attention mask (first ``kv_len`` columns are used).
            position_ids: Position IDs for the queries' RoPE.
            sliding_window: If set, apply sliding window attention. When there are
                more keys than queries, the queries are treated as the newest positions.

        Returns:
            Attention output ``[batch, seq_len, num_heads, head_dim]`` (no ``o_proj``).
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        cos, sin = self.rotary_emb(query_states, position_ids)
        query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))
        attn_output, _ = self._attend(
            query_states,
            external_key,
            external_value,
            attention_mask=attention_mask,
            sliding_window=sliding_window,
        )
        # Don't apply o_proj here - return raw attention output
        return attn_output.transpose(1, 2).contiguous()

    def get_qkv(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Get Q, K, V tensors with RoPE applied.

        Returns:
            query: [batch, num_heads, seq_len, head_dim]
            key: [batch, num_kv_heads, seq_len, head_dim]
            value: [batch, num_kv_heads, seq_len, head_dim]
        """
        query_states, key_states, value_states, _, _ = self._project_qkv(hidden_states, position_ids)
        return query_states, key_states, value_states

    def forward_decode_loop1(
        self,
        hidden_states: torch.Tensor,
        past_shared_key: Optional[torch.Tensor],
        past_shared_value: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass for Loop 1 in decode stage.

        Args:
            hidden_states: Current hidden states [batch, 1, hidden_size]
            past_shared_key: Past shared keys from cache [batch, num_kv_heads, past_len, head_dim]
            past_shared_value: Past shared values from cache [batch, num_kv_heads, past_len, head_dim]
            attention_mask: Additive causal attention mask
            position_ids: Position IDs
            cache_position: Cache position (accepted for API compatibility; unused)

        Returns:
            output: Attention output [batch, 1, hidden_size]
            k1: Current key [batch, num_kv_heads, 1, head_dim] (only current token)
            v1: Current value [batch, num_kv_heads, 1, head_dim] (only current token)
        """
        query_states, key_states, value_states, _, _ = self._project_qkv(hidden_states, position_ids)
        # Keep the current token's K/V (before concatenation) for the caller to cache.
        k1_current, v1_current = key_states, value_states
        if past_shared_key is not None and past_shared_value is not None:
            key_states = torch.cat([past_shared_key, key_states], dim=2)
            value_states = torch.cat([past_shared_value, value_states], dim=2)
        attn_output, _ = self._attend(
            query_states, key_states, value_states, attention_mask=attention_mask, apply_dropout=True
        )
        attn_output = self.o_proj(self._merge_heads(attn_output))
        return attn_output, k1_current, v1_current

    def forward_decode_loop2(
        self,
        hidden_states: torch.Tensor,
        k1: torch.Tensor,
        v1: torch.Tensor,
        past_shared_key: Optional[torch.Tensor],
        past_shared_value: Optional[torch.Tensor],
        past_local_key: Optional[torch.Tensor],
        past_local_value: Optional[torch.Tensor],
        gate_proj: LoopGateProjection,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        loop_window_size: int = 64,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass for Loop 2 in decode stage with gated mixed attention.

        Computes ``gate * A + (1 - gate) * B``, where:
        - A is attention of Q2 over the global context (past shared KV + current K1/V1);
        - B is attention of Q2 over the local context (past local KV + current K2/V2),
          restricted to the last ``loop_window_size`` positions;
        - ``gate = gate_proj(Q2)``.

        The local keys are the most recent positions of the sequence. Attention B
        therefore uses the mask columns that end where the global key range ends,
        rather than the first ones.

        Args:
            hidden_states: Current hidden states [batch, 1, hidden_size]
            k1: Key from Loop 1 (current token) [batch, num_kv_heads, 1, head_dim]
            v1: Value from Loop 1 (current token) [batch, num_kv_heads, 1, head_dim]
            past_shared_key: Past shared keys from cache [batch, num_kv_heads, past_len, head_dim]
            past_shared_value: Past shared values from cache [batch, num_kv_heads, past_len, head_dim]
            past_local_key: Past local keys from cache [batch, num_kv_heads, window_len, head_dim]
            past_local_value: Past local values from cache [batch, num_kv_heads, window_len, head_dim]
            gate_proj: Gate projection module
            attention_mask: Additive causal attention mask covering the global key range
            position_ids: Position IDs
            loop_window_size: Window size for sliding window attention

        Returns:
            output: Attention output [batch, 1, hidden_size]
            k2: Current key [batch, num_kv_heads, 1, head_dim]
            v2: Current value [batch, num_kv_heads, 1, head_dim]
        """
        q2, k2, v2 = self.get_qkv(hidden_states, position_ids)
        gate = gate_proj(q2)  # [batch, num_heads, q_len, 1]

        # Attention A: full global context.
        if past_shared_key is not None and past_shared_value is not None:
            k1_full = torch.cat([past_shared_key, k1], dim=2)
            v1_full = torch.cat([past_shared_value, v1], dim=2)
        else:
            k1_full, v1_full = k1, v1
        attn_A, _ = self._attend(q2, k1_full, v1_full, attention_mask=attention_mask)

        # Attention B: local context under the sliding window.
        if past_local_key is not None and past_local_value is not None:
            k2_full = torch.cat([past_local_key, k2], dim=2)
            v2_full = torch.cat([past_local_value, v2], dim=2)
        else:
            k2_full, v2_full = k2, v2
        # Local keys end at the same (current) position as the global keys.
        local_mask_offset = max(k1_full.shape[-2] - k2_full.shape[-2], 0)
        attn_B, _ = self._attend(
            q2,
            k2_full,
            v2_full,
            attention_mask=attention_mask,
            mask_offset=local_mask_offset,
            sliding_window=loop_window_size,
        )

        mixed_attn = gate * attn_A + (1 - gate) * attn_B
        attn_output = self.o_proj(self._merge_heads(mixed_attn))
        return attn_output, k2, v2
class IQuestLoopCoderDecoderLayer(nn.Module):
    """One decoder block: pre-norm self-attention and pre-norm MLP, each with a residual add.

    ``forward`` is the standard step (used by Loop 1). ``forward_loop2_mixed`` is
    the Loop 2+ step, whose attention output is a gated blend of a global
    attention over Loop-1 keys/values and a sliding-window attention over the
    current loop's keys/values.
    """

    def __init__(self, config: IQuestLoopCoderConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Registration order is kept stable: it fixes state_dict key order and init order.
        self.self_attn = IQuestLoopCoderAttention(config=config, layer_idx=layer_idx)
        self.mlp = IQuestLoopCoderMLP(config)
        self.input_layernorm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def _mlp_residual(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply post-attention RMSNorm + MLP to ``hidden_states`` and add it back as a residual."""
        return hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Standard decoder step.

        Returns:
            A tuple whose first item is the new hidden states, followed by the
            attention weights when ``output_attentions`` is truthy, followed by
            the present key/value returned by ``self_attn`` when ``use_cache``
            is truthy.
        """
        attn_out, attn_weights, present_kv = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = self._mlp_residual(hidden_states + attn_out)

        results = [hidden_states]
        if output_attentions:
            results.append(attn_weights)
        if use_cache:
            results.append(present_kv)
        return tuple(results)

    def forward_loop2_mixed(
        self,
        hidden_states: torch.Tensor,
        k1: torch.Tensor,
        v1: torch.Tensor,
        gate_proj: LoopGateProjection,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        loop_window_size: int = 64,
    ) -> Tuple[torch.Tensor, float]:
        """Decoder step for Loop 2+ using gated mixed attention.

        Q2/K2/V2 are projected from this layer's normalised input. The output
        of attention over (K1, V1) and of sliding-window attention over (K2, V2)
        are blended as ``gate * A + (1 - gate) * B`` with ``gate = gate_proj(Q2)``.

        Args:
            hidden_states: Input hidden states of this layer.
            k1: Loop-1 keys for this layer (before GQA head repetition).
            v1: Loop-1 values for this layer (before GQA head repetition).
            gate_proj: Gate module for this layer, applied to Q2.
            attention_mask: Additive causal mask, or None.
            position_ids: Position ids forwarded to ``get_qkv``.
            loop_window_size: Window size of the local (K2, V2) attention.

        Returns:
            ``(hidden_states, gate_mean)`` where ``gate_mean`` is the mean of
            the gate as a Python float (``.item()`` forces a device sync).
        """
        normed = self.input_layernorm(hidden_states)
        q2, k2, v2 = self.self_attn.get_qkv(normed, position_ids)

        gate = gate_proj(q2)
        gate_mean = gate.detach().mean().item()

        groups = self.self_attn.num_key_value_groups
        k1_rep, v1_rep, k2_rep, v2_rep = (repeat_kv(t, groups) for t in (k1, v1, k2, v2))

        global_out = self._compute_attention(q2, k1_rep, v1_rep, attention_mask)
        local_out = self._compute_attention_with_window(q2, k2_rep, v2_rep, attention_mask, loop_window_size)
        blended = gate * global_out + (1 - gate) * local_out

        # [bsz, heads, seq, head_dim] -> [bsz, seq, heads * head_dim]
        bsz, _, seq_len, _ = blended.shape
        blended = blended.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)

        hidden_states = hidden_states + self.self_attn.o_proj(blended)
        return self._mlp_residual(hidden_states), gate_mean

    def _attention_scores(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Scaled dot-product scores, plus the additive mask cropped to the key length."""
        scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(query.shape[-1])
        if attention_mask is not None:
            scores = scores + attention_mask[:, :, :, : key.shape[-2]]
        return scores

    @staticmethod
    def _softmax_weighted_sum(scores: torch.Tensor, value: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Softmax the scores in float32, cast back to ``dtype`` and weight ``value``."""
        probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(dtype)
        return torch.matmul(probs, value)

    def _compute_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Plain (masked) scaled dot-product attention."""
        scores = self._attention_scores(query, key, attention_mask)
        return self._softmax_weighted_sum(scores, value, query.dtype)

    def _compute_attention_with_window(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        window_size: int,
    ) -> torch.Tensor:
        """Attention where query row i only sees key columns in ``[i - window_size + 1, i]``.

        Falls back to plain attention when the query length fits in the window.
        """
        q_len, k_len = query.shape[2], key.shape[2]
        if q_len <= window_size:
            return self._compute_attention(query, key, value, attention_mask)

        scores = self._attention_scores(query, key, attention_mask)
        # distance[i, j] = i - j; keep only 0 <= distance <= window_size - 1.
        rows = torch.arange(q_len, device=query.device)[:, None]
        cols = torch.arange(k_len, device=query.device)[None, :]
        distance = rows - cols
        outside = (distance < 0) | (distance > window_size - 1)
        scores = scores.masked_fill(outside[None, None], float('-inf'))
        return self._softmax_weighted_sum(scores, value, query.dtype)
class IQuestLoopCoderPreTrainedModel(PreTrainedModel):
    """Base class for IQuestLoopCoder models: shared config binding and weight init."""
    config_class = IQuestLoopCoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["IQuestLoopCoderDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` / ``nn.Embedding`` weights from N(0, initializer_range).

        Linear biases and the embedding row at ``padding_idx`` (if any) are zeroed.
        Other module types are left untouched.
        """
        init_std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_(mean=0.0, std=init_std)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()
class IQuestLoopCoderModel(IQuestLoopCoderPreTrainedModel):
    """IQuestLoopCoder Transformer decoder model.

    The same stack of decoder layers is applied ``config.loop_num`` times:

    * Loop 1 is a standard causal pass. Each layer's Loop-1 keys/values
      (K1, V1), projected from that layer's input, are kept for later loops.
    * Loops 2..loop_num blend, per layer, a global attention over K1/V1 with a
      sliding-window attention over the current loop's K2/V2, weighted by a
      learned gate (``self.gate_projections[layer_idx]``).
    """

    def __init__(self, config: IQuestLoopCoderConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([
            IQuestLoopCoderDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Gate projections for Loop 2+ (one per layer)
        self.gate_projections = nn.ModuleList([
            LoopGateProjection(config.num_attention_heads, config.head_dim)
            for _ in range(config.num_hidden_layers)
        ])
        # Loop configuration
        self.loop_num = config.loop_num
        self.loop_window_size = config.loop_window_size
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder.

        A single-token step with an existing cache goes through
        ``_forward_with_cache``. Everything else (training, prefill) goes
        through ``_forward_loop``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or inputs_embeds")
            inputs_embeds = self.embed_tokens(input_ids)
        seq_length = inputs_embeds.shape[1]
        # Determine which forward path to use:
        # 1. If past_key_values exists and seq_length == 1: autoregressive generation step
        #    -> Use the decode path with KV cache
        # 2. Otherwise (prefill or training): use the full loop mechanism
        is_generation_step = past_key_values is not None and seq_length == 1
        if is_generation_step:
            return self._forward_with_cache(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
            )
        return self._forward_loop(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
        )

    def _forward_loop(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.LongTensor],
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Forward with the loop mechanism (training and prefill).

        - Loop 1: each layer's standard forward. K1/V1 are projected from the
          layer input and kept for Loops 2+ whether or not ``use_cache`` is set.
        - Loops 2..loop_num: ``forward_loop2_mixed`` with the kept K1/V1.

        With ``use_cache`` a fresh ``IQuestLoopCoderCache`` receives K1/V1
        (shared) and the Loop-2 K2/V2 (local). K2/V2 are projected from the
        layer *input*, which is what ``forward_loop2_mixed`` attends over.
        Any incoming cache is not read here.
        """
        seq_length = inputs_embeds.shape[1]
        device = inputs_embeds.device
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0)
        if cache_position is None:
            cache_position = torch.arange(seq_length, device=device)
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, output_attentions)

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = IQuestLoopCoderCache(self.loop_window_size, len(self.layers)) if use_cache else None

        # ============ Loop 1: standard forward, keep K1/V1 per layer ============
        loop1_kv = []
        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            # K1/V1 come from this layer's input (after input layernorm).
            _, k1, v1 = decoder_layer.self_attn.get_qkv(decoder_layer.input_layernorm(hidden_states), position_ids)
            # Keep them independently of the cache: Loops 2+ need the Loop-1
            # keys/values; recomputing them later from the current hidden
            # states would give K2/V2 instead.
            loop1_kv.append((k1, v1))
            if use_cache:
                next_decoder_cache.update_shared(k1, v1, layer_idx)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=output_attentions,
                use_cache=False,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # ============ Loop 2 to loop_num: mixed attention ============
        for loop_idx in range(2, self.loop_num + 1):
            for layer_idx, decoder_layer in enumerate(self.layers):
                k1, v1 = loop1_kv[layer_idx]
                layer_input = hidden_states
                hidden_states, _ = decoder_layer.forward_loop2_mixed(
                    layer_input,
                    k1=k1,
                    v1=v1,
                    gate_proj=self.gate_projections[layer_idx],
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    loop_window_size=self.loop_window_size,
                )
                # Only Loop 2's K2/V2 are cached. Project them from the layer
                # input so they match what this loop attended over and what the
                # decode path stores for new tokens.
                if use_cache and loop_idx == 2:
                    _, k2, v2 = decoder_layer.self_attn.get_qkv(
                        decoder_layer.input_layernorm(layer_input), position_ids
                    )
                    next_decoder_cache.update_local(k2, v2, layer_idx)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _as_loop_cache(self, past_key_values: Optional[Cache]) -> "IQuestLoopCoderCache":
        """Return ``past_key_values`` as an ``IQuestLoopCoderCache``.

        An ``IQuestLoopCoderCache`` is returned as-is. Otherwise a fresh cache
        is created, and the per-layer ``key_cache``/``value_cache`` entries of
        the given cache (when it has them) are copied into its shared slots.
        The copy is best-effort: a layer that cannot be read or stored is left
        empty.
        """
        if isinstance(past_key_values, IQuestLoopCoderCache):
            return past_key_values
        loop_cache = IQuestLoopCoderCache(self.loop_window_size, len(self.layers))
        if past_key_values is None:
            return loop_cache
        for layer_idx in range(len(self.layers)):
            try:
                past_k = past_key_values.key_cache[layer_idx] if hasattr(past_key_values, 'key_cache') else None
                past_v = past_key_values.value_cache[layer_idx] if hasattr(past_key_values, 'value_cache') else None
                if past_k is not None and past_v is not None:
                    loop_cache.update_shared(past_k, past_v, layer_idx)
            except Exception:
                # Deliberate best-effort conversion (foreign cache layouts vary);
                # skip this layer rather than failing the whole step.
                continue
        return loop_cache

    def _forward_with_cache(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        position_ids: Optional[torch.LongTensor],
        past_key_values: Optional[Cache],
        use_cache: bool,
        output_attentions: bool,
        output_hidden_states: bool,
        return_dict: bool,
        cache_position: Optional[torch.LongTensor],
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Decode step with the loop mechanism (generation).

        Loop 1: standard attention over the shared cache plus the current token.
        Loop 2+: mixed attention over the shared cache (global) and the local
        cache (sliding window).

        The cached context is needed to compute the step at all, so the cache
        is read and updated (via ``update_shared``/``update_local``) regardless
        of ``use_cache``. ``use_cache`` only controls whether it is returned.
        """
        seq_length = inputs_embeds.shape[1]
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)

        cache = self._as_loop_cache(past_key_values)

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        # ============ Loop 1: standard attention, store in shared cache ============
        for layer_idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_shared_key, past_shared_value = cache.get_shared(layer_idx)
            attn_output, k1, v1 = decoder_layer.self_attn.forward_decode_loop1(
                hidden_states=decoder_layer.input_layernorm(hidden_states),
                past_shared_key=past_shared_key,
                past_shared_value=past_shared_value,
                attention_mask=causal_mask,
                position_ids=position_ids,
                cache_position=cache_position,
            )
            cache.update_shared(k1, v1, layer_idx)
            hidden_states = hidden_states + attn_output
            # MLP
            residual = hidden_states
            hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
            hidden_states = decoder_layer.mlp(hidden_states)
            hidden_states = residual + hidden_states
            if output_attentions:
                all_self_attns += (None,)  # attention weights are not returned in the decode loop

        # Per layer: current token's Loop-1 KV plus the full shared KV, and a
        # snapshot of the local cache *before* this step. The snapshot matters
        # for loop_num > 2: Loop 2 appends the current token's K2/V2 to the
        # local cache, and forward_decode_loop2 appends the current K2/V2
        # itself, so re-reading the cache in Loop 3+ would count it twice.
        loop1_kv = []
        past_local_kv = []
        for layer_idx in range(len(self.layers)):
            k1_full, v1_full = cache.get_shared(layer_idx)
            if k1_full is not None and v1_full is not None:
                loop1_kv.append((k1_full[:, :, -1:, :], v1_full[:, :, -1:, :], k1_full, v1_full))
            else:
                loop1_kv.append((None, None, None, None))
            past_local_kv.append(cache.get_local(layer_idx))

        # ============ Loop 2 to loop_num: mixed attention, store in local cache ============
        for loop_idx in range(2, self.loop_num + 1):
            for layer_idx, decoder_layer in enumerate(self.layers):
                k1_current, v1_current, k1_full, v1_full = loop1_kv[layer_idx]
                if k1_current is None or v1_current is None:
                    continue
                past_local_key, past_local_value = past_local_kv[layer_idx]
                attn_output, k2, v2 = decoder_layer.self_attn.forward_decode_loop2(
                    hidden_states=decoder_layer.input_layernorm(hidden_states),
                    k1=k1_current,
                    v1=v1_current,
                    # Previous tokens only; the current token is passed as k1/v1.
                    past_shared_key=k1_full[:, :, :-1, :] if k1_full.shape[2] > 1 else None,
                    past_shared_value=v1_full[:, :, :-1, :] if v1_full.shape[2] > 1 else None,
                    past_local_key=past_local_key,
                    past_local_value=past_local_value,
                    gate_proj=self.gate_projections[layer_idx],
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    loop_window_size=self.loop_window_size,
                )
                # Only Loop 2's K2/V2 are cached (same convention as prefill).
                if loop_idx == 2:
                    cache.update_local(k2, v2, layer_idx)
                hidden_states = hidden_states + attn_output
                # MLP
                residual = hidden_states
                hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
                hidden_states = decoder_layer.mlp(hidden_states)
                hidden_states = residual + hidden_states

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        next_cache = cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build an additive causal mask of shape ``[batch, 1, seq_len, target_len]``.

        Allowed positions are 0 and blocked positions are the dtype minimum.
        ``target_len`` is ``past_key_values.get_seq_length() + seq_len`` when a
        cache is given, else the last dim of ``attention_mask``, else ``seq_len``.
        When a 2D ``attention_mask`` is given, causally-allowed positions where
        it is 0 are also blocked. ``output_attentions`` is unused.
        """
        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # Determine target length for attention
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
            target_length = past_length + sequence_length
        elif attention_mask is not None:
            target_length = attention_mask.shape[-1]
        else:
            target_length = sequence_length
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Block only key positions after each query's absolute cache position.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # expand() returns a view; clone before writing
            mask_length = attention_mask.shape[-1]
            if mask_length <= target_length:
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
        return causal_mask
class IQuestLoopCoderForCausalLM(IQuestLoopCoderPreTrainedModel, GenerationMixin):
    """IQuestLoopCoder decoder with a linear causal-language-modeling head."""
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = IQuestLoopCoderModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and project to vocabulary logits (returned as float32).

        When ``labels`` are given, the loss is the mean next-token cross-entropy
        between ``logits[..., :-1, :]`` and ``labels[..., 1:]``. Label -100 is
        ignored, which is the ``nn.CrossEntropyLoss`` default.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        loss = None
        if labels is not None:
            # Token t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """Select the inputs for the next ``generate()`` step.

        Tokens already covered by the cache are dropped from ``input_ids``.
        ``inputs_embeds`` is fed only while nothing is cached yet. This is
        tested via the cached length rather than ``past_key_values is None``,
        because ``generate()`` may hand over an *empty* cache object on the
        first step. ``position_ids`` are derived from ``cache_position``.
        """
        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

        use_embeds = inputs_embeds is not None and past_length == 0
        # Number of positions actually fed this step (input_ids can be empty
        # when generation starts from inputs_embeds).
        num_new_tokens = inputs_embeds.shape[1] if use_embeds else input_ids.shape[1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + num_new_tokens, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-num_new_tokens:]
        position_ids = cache_position.unsqueeze(0)

        if use_embeds:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
|