Spaces:
Running
Running
# /// script
# dependencies = [
#     "torch",
#     "numpy",
#     "kernels",
# ]
# ///
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from kernels import get_kernel, get_local_kernel | |
| from utils import to_dtype, tensor_stats, set_seed, bench_context | |
| from config import ( | |
| NUM_EXPERTS, HIDDEN_SIZE, TOP_K, | |
| BATCH_SIZE, SEQ_LEN, DTYPE, DEVICE, | |
| WEIGHT_SEED, EXPERT_SEED, INPUT_SEED, GENERAL_SEED | |
| ) | |
| from pathlib import Path | |
| from collections import namedtuple | |
| import os | |
# Locate the upstream artifact directory: the UVNOTE_INPUT_SAVE_DATA
# environment variable wins, otherwise fall back to the current directory.
data_dir = os.environ.get('UVNOTE_INPUT_SAVE_DATA', '.')
print(f"Loading weights from: {data_dir}")

# Load the six shared tensors, one file each, in a fixed order. The generator
# is consumed left to right, so files are read in the order listed.
(
    router_weight,
    router_bias,
    gate_up_proj,
    gate_up_proj_bias,
    down_proj,
    down_proj_bias,
) = (
    torch.load(Path(data_dir) / artifact_name)
    for artifact_name in (
        'router_weight.pt',
        'router_bias.pt',
        'gate_up_proj.pt',
        'gate_up_proj_bias.pt',
        'down_proj.pt',
        'down_proj_bias.pt',
    )
)

# Print simple checksums so runs of different implementations can be compared.
print("Loaded shared weights from artifacts")
print(f"Router weight sum: {router_weight.sum().item():.6f}")
print(f"Gate/up sum: {gate_up_proj.sum().item():.6f}")
print(f"Down sum: {down_proj.sum().item():.6f}")
def build_megablocks_model(device: torch.device, dtype: torch.dtype):
    """Build a MegaBlocks MoE MLP populated with the shared, pre-loaded weights.

    The router and expert tensors come from the module-level globals loaded
    earlier in this script (``router_weight``, ``router_bias``,
    ``gate_up_proj``, ``gate_up_proj_bias``, ``down_proj``,
    ``down_proj_bias``).

    Args:
        device: Device on which the router and expert parameters are created.
        dtype: Dtype the router and expert parameters are cast to.

    Returns:
        The ``MegaBlocksMoeMLP`` instance with ``router`` and ``experts``
        attributes attached.
    """
    # Fetch the optimized kernels from the Hugging Face hub.
    # (To try a local build instead, swap in
    # get_local_kernel(<build dir>, "megablocks").)
    kernel_pkg = get_kernel("kernels-community/megablocks")
    model = kernel_pkg.layers.MegaBlocksMoeMLP()

    # Attribute container for the expert weights.
    # NOTE: this stores the class object returned by namedtuple(), not an
    # instance of it; everything below is set as an attribute on that class.
    model.experts = namedtuple(
        "Experts",
        ["gate_up_proj", "gate_up_proj_bias", "down_proj", "down_proj_bias", "hidden_size"],
    )

    # Router layer, overwritten with the loaded weights so every
    # implementation routes identically.
    model.router = torch.nn.Linear(HIDDEN_SIZE, NUM_EXPERTS, device=device, dtype=dtype)
    with torch.no_grad():
        model.router.weight.copy_(router_weight.to(dtype))
        model.router.bias.copy_(router_bias.to(dtype))

    experts = model.experts

    # Scalar settings stored alongside the weights.
    experts.alpha = 1.702
    experts.capacity_factor = 4
    experts.hidden_size = HIDDEN_SIZE

    # Each loaded tensor is cloned, moved/cast, and wrapped as a Parameter
    # before being attached to the experts container.
    loaded_tensors = (
        ("gate_up_proj", gate_up_proj),
        ("gate_up_proj_bias", gate_up_proj_bias),
        ("down_proj", down_proj),
        ("down_proj_bias", down_proj_bias),
    )
    for attr_name, tensor in loaded_tensors:
        setattr(
            experts,
            attr_name,
            torch.nn.Parameter(tensor.clone().to(device, dtype=dtype)),
        )

    # Log weight statistics for comparison with the other implementations.
    print(f"[MegaBlocks] Router weight sum: {model.router.weight.sum().item():.6f}")
    print(f"[MegaBlocks] Gate/up projection shape: {tuple(experts.gate_up_proj.shape)}, sum: {experts.gate_up_proj.sum().item():.6f}")
    print(f"[MegaBlocks] Down projection shape: {tuple(experts.down_proj.shape)}, sum: {experts.down_proj.sum().item():.6f}")
    return model
# Thin adapter so the MegaBlocks model exposes the same call interface as the
# other implementations.
class MegaBlocksMoEWrapper(nn.Module):
    """Wrap a MegaBlocks model and forward calls straight through to it."""

    def __init__(self, megablocks_model):
        """Store the wrapped model as ``self.model``."""
        super().__init__()
        self.model = megablocks_model

    def forward(self, hidden_states):
        """Run the wrapped model and return its two results as a tuple.

        Args:
            hidden_states: Input passed unchanged to the wrapped model.
                Presumably shaped (batch, seq_len, hidden_dim), per the
                original author's note; confirm against the kernel.

        Returns:
            A ``(output, routing_weights)`` tuple: exactly the two values the
            wrapped model returns, in the same order.
        """
        # The wrapped model must return exactly two values; the second is
        # passed through so callers get the same tuple shape as elsewhere.
        moe_output, routing_weights = self.model(hidden_states)
        return moe_output, routing_weights
# Run the model: seed first, then resolve the target device and dtype from the
# shared config.
set_seed(GENERAL_SEED)
device = torch.device(DEVICE)
dtype = to_dtype(DTYPE)

print("\n=== MegaBlocks Implementation ===")

# Build the MegaBlocks model from the loaded weights and wrap it so it matches
# the interface of the other implementations.
megablocks_model = build_megablocks_model(device, dtype)
model = MegaBlocksMoEWrapper(megablocks_model).to(device=device, dtype=dtype)

# Benchmark the model using different input tensors on each iteration.
input_shape = (BATCH_SIZE, SEQ_LEN, HIDDEN_SIZE)
tokens = BATCH_SIZE * SEQ_LEN
with bench_context(
    warmup=10,
    iters=50,
    device=device,
    dtype=dtype,
    tokens=tokens,
    save_json="megablocks_results.json",
    input_shape=input_shape,
    input_seed_base=INPUT_SEED,
) as bench:
    output, stats = bench(model)
    # NOTE(review): the original indentation was lost; this print is assumed
    # to belong inside the `with` body — confirm against the upstream script.
    # `output` is indexed with [0] before summing, so it is presumably the
    # wrapper's (output, routing_weights) tuple — verify against bench_context.
    print(f"\nOutput sum: {output[0].sum().item():.6f}")