modeling update
Browse files- modeling_super_linear.py +4 -2
modeling_super_linear.py
CHANGED
|
@@ -5,6 +5,8 @@ import torch, torch.nn as nn, torch.nn.functional as F
|
|
| 5 |
from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
|
| 6 |
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
| 7 |
from .configuration_super_linear import SuperLinearConfig
|
|
|
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
import math
|
|
@@ -544,9 +546,9 @@ class superLinear(nn.Module):
|
|
| 544 |
|
| 545 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 546 |
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
| 547 |
-
config_class = SuperLinearConfig
|
| 548 |
|
| 549 |
-
def __init__(self, config: SuperLinearConfig):
|
| 550 |
super().__init__(config)
|
| 551 |
|
| 552 |
|
|
|
|
| 5 |
from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
|
| 6 |
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
|
| 7 |
from .configuration_super_linear import SuperLinearConfig
|
| 8 |
+
from .configuration_super_linear_fs import SuperLinearConfigFS
|
| 9 |
+
from typing import Tuple, Union
|
| 10 |
|
| 11 |
|
| 12 |
import math
|
|
|
|
| 546 |
|
| 547 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 548 |
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
| 549 |
+
config_class: Tuple[type] = (SuperLinearConfig, SuperLinearConfigFS)
|
| 550 |
|
| 551 |
+
def __init__(self, config: Union[SuperLinearConfig, SuperLinearConfigFS]):
|
| 552 |
super().__init__(config)
|
| 553 |
|
| 554 |
|