razmars commited on
Commit
148f740
·
1 Parent(s): 16704e7

modeling update

Browse files
Files changed (1) hide show
  1. modeling_super_linear.py +4 -2
modeling_super_linear.py CHANGED
@@ -5,6 +5,8 @@ import torch, torch.nn as nn, torch.nn.functional as F
5
  from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
6
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
7
  from .configuration_super_linear import SuperLinearConfig
 
 
8
 
9
 
10
  import math
@@ -544,9 +546,9 @@ class superLinear(nn.Module):
544
 
545
  "-------------------------------------------------------------------------------------------------------------------"
546
  class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
547
- config_class = SuperLinearConfig
548
 
549
- def __init__(self, config: SuperLinearConfig):
550
  super().__init__(config)
551
 
552
 
 
5
  from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
6
  from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
7
  from .configuration_super_linear import SuperLinearConfig
8
+ from .configuration_super_linear_fs import SuperLinearConfigFS
9
+ from typing import Tuple, Union
10
 
11
 
12
  import math
 
546
 
547
  "-------------------------------------------------------------------------------------------------------------------"
548
  class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
549
+ config_class: Tuple[type] = (SuperLinearConfig, SuperLinearConfigFS)
550
 
551
+ def __init__(self, config: Union[SuperLinearConfig, SuperLinearConfigFS]):
552
  super().__init__(config)
553
 
554