modeling
Browse files- modeling_super_linear.py +2 -3
modeling_super_linear.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
|
| 2 |
-
from typing import Optional, Tuple
|
| 3 |
import torch, torch.nn as nn, torch.nn.functional as F
|
| 4 |
|
| 5 |
from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
|
|
@@ -546,12 +546,11 @@ class superLinear(nn.Module):
|
|
| 546 |
|
| 547 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 548 |
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
| 549 |
-
config_class:
|
| 550 |
|
| 551 |
def __init__(self, config: Union[SuperLinearConfig, SuperLinearConfigFS]):
|
| 552 |
super().__init__(config)
|
| 553 |
|
| 554 |
-
|
| 555 |
# the backbone keeps its own Config dataclass, so build one on‑the‑fly:
|
| 556 |
backbone_cfg = type("Cfg", (), config.to_dict())()
|
| 557 |
self.args = backbone_cfg
|
|
|
|
| 1 |
|
| 2 |
+
from typing import Optional, Tuple, Union
|
| 3 |
import torch, torch.nn as nn, torch.nn.functional as F
|
| 4 |
|
| 5 |
from transformers import (PreTrainedModel,GenerationMixin,AutoConfig,AutoModelForCausalLM,)
|
|
|
|
| 546 |
|
| 547 |
"-------------------------------------------------------------------------------------------------------------------"
|
| 548 |
class SuperLinearForCausalLM(PreTrainedModel, GenerationMixin):
|
| 549 |
+
config_class: Union[SuperLinearConfig, SuperLinearConfigFS]
|
| 550 |
|
| 551 |
def __init__(self, config: Union[SuperLinearConfig, SuperLinearConfigFS]):
|
| 552 |
super().__init__(config)
|
| 553 |
|
|
|
|
| 554 |
# the backbone keeps its own Config dataclass, so build one on‑the‑fly:
|
| 555 |
backbone_cfg = type("Cfg", (), config.to_dict())()
|
| 556 |
self.args = backbone_cfg
|