TinyWay-1.1.0 / modeling_tinyway.py
Initial release: TinyWay 1.1.0 (83.17M params)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput
# =========================
# Config
# =========================
class TinyWayConfig(PretrainedConfig):
model_type = "tinyway"
def __init__(
self,
vocab_size=50257,
n_positions=256,
n_embd=512,
n_layer=10,
n_head=8,
dropout=0.1,
**kwargs
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.dropout = dropout
# 🔥 HuggingFace-required aliases
self.hidden_size = n_embd
self.num_hidden_layers = n_layer
self.num_attention_heads = n_head
self.max_position_embeddings = n_positions
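# A minimal usage sketch (illustrative, not part of the release): with the
# aliases above the config behaves like any GPT-2-style PretrainedConfig, e.g.
#   cfg = TinyWayConfig()             # defaults: 10 layers, 8 heads, 512-dim, 256 positions
#   cfg.save_pretrained("./tinyway")  # writes config.json with model_type="tinyway"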
# =========================
# Causal Self-Attention
# =========================
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
self.n_head = config.n_head
self.head_dim = config.n_embd // config.n_head
self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
self.proj = nn.Linear(config.n_embd, config.n_embd)
self.attn_dropout = nn.Dropout(config.dropout)
self.proj_dropout = nn.Dropout(config.dropout)
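        # lower-triangular (causal) mask; registered as a buffer so it follows the
        # module across devices/dtypes without being treated as a trainable parameter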
self.register_buffer(
"mask",
torch.tril(
torch.ones(
config.n_positions,
config.n_positions,
dtype=torch.bool
)
)
)
self.last_attn = None
def forward(self, x):
B, T, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
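        # q, k, v now have shape (B, n_head, T, head_dim); scores are scaled by
        # 1/sqrt(head_dim) before the causal mask and softmax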
att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
att = att.masked_fill(
~self.mask[:T, :T],
torch.finfo(att.dtype).min
)
att = F.softmax(att, dim=-1)
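        # keep a detached copy of the attention weights for later inspection;
        # this does not affect the forward computation or gradients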
self.last_attn = att.detach()
att = self.attn_dropout(att)
out = att @ v
out = out.transpose(1, 2).contiguous().view(B, T, C)
out = self.proj(out)
out = self.proj_dropout(out)
return out
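# Note (an equivalence sketch, not in the original training code): the explicit
# mask -> softmax -> dropout path above is functionally interchangeable with
# F.scaled_dot_product_attention(q, k, v, dropout_p=..., is_causal=True); the
# manual version is kept so last_attn can expose the per-head attention weights.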
# =========================
# Transformer Block
# =========================
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.ln2 = nn.LayerNorm(config.n_embd)
# 🔥 FFN EXACTLY MATCHES TRAINING
self.ffn = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.dropout),
)
def forward(self, x):
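        # pre-LayerNorm residual connections (GPT-2 style): x + sublayer(LN(x))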
x = x + self.attn(self.ln1(x))
x = x + self.ffn(self.ln2(x))
return x
# =========================
# TinyWay Language Model
# =========================
class TinyWayForCausalLM(PreTrainedModel):
config_class = TinyWayConfig
def __init__(self, config):
super().__init__(config)
self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Embedding(config.n_positions, config.n_embd)
self.blocks = nn.ModuleList([
Block(config) for _ in range(config.n_layer)
])
self.ln = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(
config.n_embd,
config.vocab_size,
bias=False
)
# weight tying
self.head.weight = self.token_emb.weight
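        # weight tying: the output projection reuses the token embedding matrix,
        # avoiding a second vocab_size x n_embd weight (~25.7M parameters at the default sizes)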
self.dropout = nn.Dropout(config.dropout)
self.post_init()
def forward(
self,
input_ids,
labels=None,
        attention_mask=None,  # accepted for API compatibility; padding masks are not applied in this minimal implementation
**kwargs # 🔥 accept return_dict, use_cache, etc.
):
B, T = input_ids.shape
pos = torch.arange(T, device=input_ids.device)
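        # learned absolute position embeddings; pos_emb(pos) is (T, C) and
        # broadcasts over the batch when added to the token embeddings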
x = self.token_emb(input_ids) + self.pos_emb(pos)
x = self.dropout(x)
for block in self.blocks:
x = block(x)
x = self.ln(x)
logits = self.head(x)
loss = None
if labels is not None:
            # standard causal-LM objective: predict token t+1 from the prefix up to t,
            # so shift logits/labels by one position; -100 labels are ignored (HF padding convention)
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
return CausalLMOutput(
loss=loss,
logits=logits
)
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # no KV cache, so the full sequence is re-fed each step; crop to the last
        # n_positions tokens so generation never exceeds the learned position range
        return {"input_ids": input_ids[:, -self.config.n_positions:]}
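# =========================
# Smoke test (illustrative sketch)
# =========================
# A hedged example, not shipped with the checkpoint: it builds a randomly
# initialized TinyWay model from the default config, runs one forward pass on
# dummy token ids, and greedily decodes a few continuation tokens by hand
# (avoiding any dependence on GenerationMixin). The Auto* registration is
# optional and only needed if AutoModelForCausalLM should resolve
# model_type="tinyway".
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    AutoConfig.register("tinyway", TinyWayConfig)
    AutoModelForCausalLM.register(TinyWayConfig, TinyWayForCausalLM)

    config = TinyWayConfig()
    model = TinyWayForCausalLM(config).eval()
    print(f"unique parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M")

    # dummy batch: 2 sequences of 16 random token ids
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_ids, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", round(float(out.loss), 4))

    # manual greedy decode of 8 tokens from the first sequence
    tokens = input_ids[:1]
    for _ in range(8):
        with torch.no_grad():
            logits = model(tokens[:, -config.n_positions:]).logits
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_id], dim=-1)
    print("greedy continuation ids:", tokens[0, 16:].tolist())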