razmars committed
Commit c12ca9e · verified · 1 Parent(s): 1e0be9a

Update modeling_super_linear.py

Files changed (1)
  1. modeling_super_linear.py +1 -15
modeling_super_linear.py CHANGED
@@ -526,21 +526,7 @@ class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
 
         # backbone returns (B, pred_len, C)
         preds = self.backbone(x_enc)
-        print(F"preds shape: {preds.shape}")
-        #preds = preds[0]
-        print(F"preds shape: {preds.shape}")
-
-        # if we keep continuous values, treat them as logits directly
-        logits = (preds if self.vocab_size is None else self.lm_head(preds).transpose(1, 2))
-
-        loss = None
-        if labels is not None:
-            # shift for causal objective
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-        return CausalLMOutputWithCrossAttentions(loss=loss,logits=logits,past_key_values=None,hidden_states=None,attentions=None,)
+        return CausalLMOutputWithCrossAttentions(loss=None,logits=preds,past_key_values=None,hidden_states=None,attentions=None,)
 
 
     def prepare_inputs_for_generation(self, inputs_embeds, past_key_values=None, **kwargs):
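
For reference, a minimal, self-contained sketch of what the forward path returns after this change: the backbone's continuous (B, pred_len, C) predictions are wrapped directly in CausalLMOutputWithCrossAttentions as logits, and no loss is computed inside the model. The DummyBackbone, the sequence lengths, channel count, and batch size below are illustrative assumptions, not the repository's actual backbone or configuration.

# Minimal sketch of the post-commit forward behavior (assumed shapes, stand-in backbone).
import torch
import torch.nn as nn
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class DummyBackbone(nn.Module):
    """Stand-in for the real backbone: maps (B, seq_len, C) -> (B, pred_len, C)."""
    def __init__(self, seq_len: int, pred_len: int):
        super().__init__()
        self.proj = nn.Linear(seq_len, pred_len)

    def forward(self, x_enc: torch.Tensor) -> torch.Tensor:
        # project each channel along the time dimension
        return self.proj(x_enc.transpose(1, 2)).transpose(1, 2)

backbone = DummyBackbone(seq_len=96, pred_len=24)
x_enc = torch.randn(4, 96, 7)   # (B, seq_len, C), illustrative sizes
preds = backbone(x_enc)         # (B, pred_len, C)

# As in the new forward: continuous predictions are returned directly as
# "logits", with loss=None; any loss is left to the caller.
out = CausalLMOutputWithCrossAttentions(
    loss=None,
    logits=preds,
    past_key_values=None,
    hidden_states=None,
    attentions=None,
)
print(out.logits.shape)  # torch.Size([4, 24, 7])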