Update model_slm.py

model_slm.py  CHANGED  (+99 -119)
@@ -54,7 +54,7 @@ class MixtureOfRecursionsConfig(PretrainedConfig):
         self.max_position_embeddings = max_position_embeddings or max_seq_len

 # ============================================================================
-# EMBEDDINGS MODULE
+# EMBEDDINGS MODULE
 # ============================================================================

 DEFAULT_BASE = 10000.0
@@ -400,66 +400,113 @@ class RecursiveTransformerLayer(nn.Module):
             active_batches &= (steps > step)
         return x, computation_loss

-class MixtureOfRecursions(nn.Module):
+# ============================================================================
+# PRETRAINED MODEL WRAPPER
+# ============================================================================
+
+class MixtureOfRecursionsPreTrainedModel(PreTrainedModel):
+    """PreTrainedModel wrapper for MixtureOfRecursions."""
+
+    config_class = MixtureOfRecursionsConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize weights."""
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.d_model ** -0.5)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+class MixtureOfRecursions(MixtureOfRecursionsPreTrainedModel):
     """Transformer model with mixture of recursive layers for technical content."""

-    def __init__(
-        self,
-        vocab_size: int = DEFAULT_VOCAB_SIZE,
-        d_model: int = DEFAULT_D_MODEL,
-        n_layers: int = DEFAULT_N_LAYERS,
-        n_heads: int = DEFAULT_N_HEADS,
-        max_steps: int = DEFAULT_MAX_STEPS,
-        dim_feedforward: int = DEFAULT_DIM_FEEDFORWARD,
-        dropout: float = DEFAULT_DROPOUT,
-        max_seq_len: int = DEFAULT_MAX_SEQ_LEN,
-        router_type: str = DEFAULT_ROUTER_TYPE,
-        padding_idx: int = DEFAULT_PADDING_IDX,
-        pos_encoding: str = "learned"
-    ):
-        super().__init__()
-        self.d_model = d_model
-        self.vocab_size = vocab_size
-        self.padding_idx = padding_idx
+    def __init__(self, config: MixtureOfRecursionsConfig):
+        super().__init__(config)
+        self.config = config
+        self.d_model = config.d_model
+        self.vocab_size = config.vocab_size
+        self.padding_idx = config.padding_idx
+
         self.embeddings = TechEmbeddingLayer(
-            vocab_size=vocab_size,
-            d_model=d_model,
-            max_seq_len=max_seq_len,
-            dropout=dropout,
-            padding_idx=padding_idx,
-            pos_encoding=pos_encoding
+            vocab_size=config.vocab_size,
+            d_model=config.d_model,
+            max_seq_len=config.max_seq_len,
+            dropout=config.dropout,
+            padding_idx=config.padding_idx,
+            pos_encoding=config.pos_encoding
         )
+
         self.layers = nn.ModuleList([
             RecursiveTransformerLayer(
-                d_model=d_model,
-                n_heads=n_heads,
-                dim_feedforward=dim_feedforward,
-                max_steps=max_steps,
-                dropout=dropout,
-                router_type=router_type
-            ) for _ in range(n_layers)
+                d_model=config.d_model,
+                n_heads=config.n_heads,
+                dim_feedforward=config.dim_feedforward,
+                max_steps=config.max_steps,
+                dropout=config.dropout,
+                router_type=config.router_type
+            ) for _ in range(config.n_layers)
         ])
-
-        self.final_norm = nn.LayerNorm(d_model)
-        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
-
-
-
+
+        self.final_norm = nn.LayerNorm(config.d_model)
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        # Initialize weights
+        self.post_init()

-    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: bool = True
+    ):
         batch_size, seq_len = input_ids.shape
+
+        # Create masks
         padding_mask = create_padding_mask(input_ids, self.padding_idx) if attention_mask is None else (attention_mask == 0)
         causal_mask = create_causal_mask(seq_len, input_ids.device)
         combined_mask = padding_mask.unsqueeze(1).expand(batch_size, seq_len, seq_len) | causal_mask.unsqueeze(0)
+
+        # Forward pass
         x = self.embeddings(input_ids)
         pos_encoding = self.embeddings.get_positional_encoding()
+
         total_computation_loss = torch.tensor(0.0, device=x.device)
         for layer in self.layers:
             x, comp_loss = layer(x, combined_mask, pos_encoding)
             total_computation_loss += comp_loss
+
         x = self.final_norm(x)
         logits = self.lm_head(x)
-        return logits, total_computation_loss
+
+        loss = None
+        if labels is not None:
+            # Shift logits and labels for language modeling
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
+            loss += 0.01 * total_computation_loss  # Add computation loss
+
+        if not return_dict:
+            output = (logits,)
+            return ((loss,) + output) if loss is not None else output
+
+        from transformers.modeling_outputs import CausalLMOutput
+        return CausalLMOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=None,
+            attentions=None,
+        )

     def generate_step(
         self,
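Note: forward() now follows the Hugging Face causal-LM contract: passing labels computes a shifted cross-entropy loss plus a 0.01-weighted computation loss, and the default return type is CausalLMOutput. A minimal training-step sketch against the new signature (the config values here are illustrative, not the module's defaults):

    import torch

    config = MixtureOfRecursionsConfig(vocab_size=8000, d_model=256, n_layers=4)
    model = MixtureOfRecursions(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 32))
    outputs = model(input_ids, labels=input_ids)  # labels trigger the shifted CE loss
    outputs.loss.backward()                       # loss includes 0.01 * computation loss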
@@ -470,11 +517,14 @@ class MixtureOfRecursions(nn.Module):
     ) -> torch.Tensor:
         self.eval()
         with torch.no_grad():
-            logits, _ = self.forward(input_ids)
+            outputs = self.forward(input_ids, return_dict=True)
+            logits = outputs.logits
             last_logits = logits[:, -1, :] / temperature
+
             if top_k is not None:
                 indices_to_remove = last_logits < torch.topk(last_logits, top_k)[0][..., -1, None]
                 last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
+
             if top_p is not None:
                 sorted_logits, sorted_indices = torch.sort(last_logits, descending=True)
                 cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
@@ -483,83 +533,12 @@ class MixtureOfRecursions(nn.Module):
             sorted_indices_to_remove[..., 0] = False
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             last_logits = last_logits.masked_fill(indices_to_remove, float('-inf'))
+
             probs = F.softmax(last_logits, dim=-1)
             return torch.multinomial(probs, num_samples=1)

-class
-
-
-    def __init__(self, model: nn.Module, tokenizer: 'Tokenizer', max_length: int = DEFAULT_MAX_SEQ_LEN, device: Optional[torch.device] = None):
-        self.model = model
-        self.tokenizer = tokenizer
-        self.max_length = max_length
-        self.device = device if device else next(model.parameters()).device
-        self.model.to(self.device)
-        self.eos_token_id = tokenizer.vocab.get('<|endoftext|>', -1)
-        self.assistant_token_id = tokenizer.vocab.get('<|assistant|>', -1)
-
-    def generate(
-        self,
-        prompt: str,
-        method: str = "nucleus",
-        temperature: float = 1.0,
-        top_k: Optional[int] = 50,
-        top_p: Optional[float] = 0.9,
-        max_new_tokens: Optional[int] = None
-    ) -> str:
-        max_new_tokens = max_new_tokens or self.max_length
-        input_text = f"<|user|> {prompt}"
-        input_ids = self.tokenizer.encode_ids(input_text, add_special_tokens=True)
-        input_tensor = torch.tensor([input_ids], device=self.device)
-        self.model.eval()
-        generated_ids = []
-        with torch.no_grad():
-            for _ in range(max_new_tokens):
-                if input_tensor.size(1) > self.max_length:
-                    input_tensor = input_tensor[:, -self.max_length:]
-                if method == "greedy":
-                    next_token = self._greedy_generate(input_tensor)
-                elif method == "sample":
-                    next_token = self._sample_generate(input_tensor, temperature)
-                elif method == "top_k":
-                    next_token = self._top_k_generate(input_tensor, temperature, top_k)
-                elif method == "nucleus" or method == "top_p":
-                    next_token = self._nucleus_generate(input_tensor, temperature, top_p)
-                else:
-                    raise ValueError(f"Unknown generation method: {method}")
-                next_token_id = next_token.item()
-                generated_ids.append(next_token_id)
-                input_tensor = torch.cat([input_tensor, next_token.unsqueeze(0)], dim=1)
-                if next_token_id == self.eos_token_id or (self.assistant_token_id != -1 and next_token_id == self.assistant_token_id):
-                    break
-        full_ids = input_ids + generated_ids
-        full_text = self.tokenizer.decode_ids(full_ids, skip_special_tokens=False)
-        if "<|assistant|>" in full_text:
-            response = full_text.split("<|assistant|>")[-1].split("<|endoftext|>")[0].strip()
-        else:
-            response = full_text.split("<|endoftext|>")[0].strip()
-        return response if response else "No response generated."
-
-    def _greedy_generate(self, input_tensor: torch.Tensor) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
-
-    def _sample_generate(self, input_tensor: torch.Tensor, temperature: float) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        logits = logits[:, -1, :] / temperature
-        probs = F.softmax(logits, dim=-1)
-        return torch.multinomial(probs, num_samples=1)
-
-    def _top_k_generate(self, input_tensor: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
-        logits, _ = self.model(input_tensor)
-        logits = logits[:, -1, :] / temperature
-        top_k_logits, top_k_indices = torch.topk(logits, top_k)
-        probs = F.softmax(top_k_logits, dim=-1)
-        next_token_idx = torch.multinomial(probs, num_samples=1)
-        return top_k_indices.gather(-1, next_token_idx)
-
-    def _nucleus_generate(self, input_tensor: torch.Tensor, temperature: float, top_p: float) -> torch.Tensor:
-        return self.model.generate_step(input_tensor, temperature, top_p=top_p)
+
+# Register the model for auto class
+MixtureOfRecursions.register_for_auto_class("AutoModelForCausalLM")

 def count_parameters(model: nn.Module) -> Tuple[int, int]:
     total_params = sum(p.numel() for p in model.parameters())
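Note: register_for_auto_class only takes effect for checkpoints saved together with this custom code, and loading them back through the auto classes requires opting in to custom code execution. A hedged round-trip sketch (it assumes the config class is likewise registered for AutoConfig, which this diff does not show; the path is hypothetical):

    from transformers import AutoModelForCausalLM

    model.save_pretrained("./mor-slm")  # writes config.json plus the weights
    reloaded = AutoModelForCausalLM.from_pretrained(
        "./mor-slm",
        trust_remote_code=True,  # required because the model code is custom
    )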
@@ -569,7 +548,7 @@ def count_parameters(model: nn.Module) -> Tuple[int, int]:
 def main():
     """Test the MixtureOfRecursions model and its components."""
     print("Initializing MixtureOfRecursions model...")
-    model = MixtureOfRecursions(
+    config = MixtureOfRecursionsConfig(
         vocab_size=DEFAULT_VOCAB_SIZE,
         d_model=DEFAULT_D_MODEL,
         n_layers=DEFAULT_N_LAYERS,
@@ -579,6 +558,7 @@ def main():
         dropout=DEFAULT_DROPOUT,
         router_type=DEFAULT_ROUTER_TYPE
     )
+    model = MixtureOfRecursions(config)

     total_params, trainable_params = count_parameters(model)
     print(f"Total parameters: {total_params:,}")
@@ -590,13 +570,13 @@ def main():
     attention_mask = torch.ones_like(input_ids)
     attention_mask[:, -10:] = 0

-    logits, comp_loss = model(input_ids, attention_mask)
+    outputs = model(input_ids, attention_mask, return_dict=True)
+    logits = outputs.logits

     assert logits.shape == (batch_size, seq_len, DEFAULT_VOCAB_SIZE), f"Unexpected logits shape: {logits.shape}"
     print(f"Input shape: {input_ids.shape}")
     print(f"Output logits shape: {logits.shape}")
     print(f"Expected logits shape: ({batch_size}, {seq_len}, {DEFAULT_VOCAB_SIZE})")
-    print(f"Computation loss: {comp_loss:.4f}")

     print("\nTesting generation step...")
     next_token = model.generate_step(input_ids[:1], temperature=0.8, top_p=0.9)
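Note: with the multi-method generator helper class removed in this commit, multi-token decoding now has to be built on generate_step, which samples one next token with temperature plus optional top-k/top-p filtering. A sketch of such a loop (eos_id is a hypothetical end-of-text id; adjust it to the tokenizer in use):

    eos_id = 2  # hypothetical end-of-text token id
    ids = input_ids[:1]
    for _ in range(50):
        next_token = model.generate_step(ids, temperature=0.8, top_p=0.9)  # shape (1, 1)
        ids = torch.cat([ids, next_token], dim=1)
        if next_token.item() == eos_id:
            break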