import argparse
from dataclasses import dataclass


def get_args():
    parser = argparse.ArgumentParser(description='SmolKimi - DeepSeek V3 Inspired Model Training')

    # Model Architecture
    parser.add_argument('--block_size', type=int, default=128, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=256, help='Training batch size')
    parser.add_argument('--embeddings_dims', type=int, default=384, help='Model embedding dimensions')
    parser.add_argument('--no_of_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--no_of_decoder_layers', type=int, default=6, help='Number of decoder layers')
    parser.add_argument('--latent_dim', type=int, default=64, help='Latent dimension for attention')

    # MoE Configuration
    parser.add_argument('--experts', type=int, default=8, help='Number of MoE experts')
    parser.add_argument('--top_experts', type=int, default=2, help='Number of experts to route to (top-k)')
    parser.add_argument('--use_shared_expert', action='store_true', default=True, help='Enable shared expert in MoE')
    parser.add_argument('--noisy_topk', action='store_true', default=False, help='Use noisy top-k routing')
    parser.add_argument('--useauxFreeLoadBalancingLoss', action='store_true', default=True, help='Use auxiliary-free load balancing loss')
    parser.add_argument('--aux_free_bias_update_rate', type=float, default=0.001, help='Bias update rate for load balancing')
    parser.add_argument('--loss_scale', type=float, default=0.3, help='Loss scaling factor')

    # Training Hyperparameters
    parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
    parser.add_argument('--max_lr', type=float, default=6e-4, help='Maximum learning rate')
    parser.add_argument('--weight_decay_optim', type=float, default=0.1, help='Weight decay for optimizer')
    parser.add_argument('--beta_1', type=float, default=0.9, help='Beta1 for optimizer')
    parser.add_argument('--beta_2', type=float, default=0.95, help='Beta2 for optimizer')
    parser.add_argument('--eps', type=float, default=1e-8, help='Epsilon for optimizer')
    parser.add_argument('--clip', type=float, default=1.0, help='Gradient clipping value')

    # Regularization
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--attn_dropout', type=float, default=0.1, help='Attention dropout rate')

    # System Configuration
    parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
    parser.add_argument('--use_checkpointing', action='store_true', default=False, help='Use gradient checkpointing')
    parser.add_argument('--use_liger', action='store_true', default=True, help='Use Liger kernels for optimization')
    parser.add_argument('--ignore_pad_token_in_loss', action='store_true', default=True, help='Ignore padding tokens in loss calculation')
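    # Note: the flags above that combine action='store_true' with default=True
    # (use_shared_expert, useauxFreeLoadBalancingLoss, use_liger, ignore_pad_token_in_loss)
    # are always on as written: passing the flag has no extra effect, and there is no
    # corresponding switch to disable them from the command line.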

    # Data Configuration
    parser.add_argument('--vocab_size', type=int, default=32000 + 1, help='Vocabulary size (updated based on tokenizer)')
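    # The +1 over the 32000 base vocabulary is presumably an added padding token
    # (cf. --ignore_pad_token_in_loss); per the help text, the value is updated from the
    # actual tokenizer at runtime, so treat this default as a placeholder.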
    parser.add_argument('--base_freq', type=int, default=100000, help='Base frequency for positional encoding')
    parser.add_argument('--hf_token', type=str, default=None, help='Hugging Face token for accessing gated models like Llama-2')

    # Dataset Selection
    parser.add_argument('--dataset', type=str, default='tinystories', choices=['tinystories', 'fineweb', 'tinyshakespeare'], help='Dataset to use for training')

    # Generation Parameters
    parser.add_argument('--generation_max_length', type=int, default=50, help='Maximum length for text generation')
    parser.add_argument('--generation_top_k', type=int, default=50, help='Top-k value for sampling during generation')
    parser.add_argument('--generation_temperature', type=float, default=1.0, help='Temperature for sampling during generation')

    # Logging and Checkpointing
    parser.add_argument('--log_interval', type=int, default=100, help='Steps between logging')
    parser.add_argument('--save_interval', type=int, default=2000, help='Steps between saving checkpoints')
    parser.add_argument('--eval_interval', type=int, default=400, help='Steps between evaluations')
    parser.add_argument('--eval_iters', type=int, default=400, help='Number of iterations for evaluation')
    parser.add_argument('--warmup_iters', type=int, default=400, help='Number of warmup iterations')
    parser.add_argument('--total_iters', type=int, default=10000, help='Total training iterations')
    parser.add_argument('--lr_decay_iters', type=int, default=10000, help='Learning rate decay iterations')
    parser.add_argument('--wandb_project', type=str, default='smolkimi', help='Wandb project name')
    parser.add_argument('--wandb_run_name', type=str, default=None, help='Wandb run name')

    # Batch Size Configuration
    parser.add_argument('--total_batch_size', type=int, default=524288, help='Total batch size in tokens for gradient accumulation')
    parser.add_argument('--micro_batch_size', type=int, default=None, help='Micro batch size (defaults to batch_size)')

    # Distributed Training
    parser.add_argument('--use_ddp', action='store_true', default=False, help='Use distributed data parallel')

    return parser.parse_args()
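

# Example invocation, overriding a few of the defaults above (the entry-point name
# train.py is illustrative; substitute whichever script imports this config):
#   python train.py --dataset fineweb --batch_size 64 --max_lr 3e-4 --use_ddp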


class ModelArgs:
    def __init__(self, args=None):
        if args is None:
            args = get_args()

        # Model Architecture
        self.block_size = args.block_size
        self.batch_size = args.batch_size
        self.embeddings_dims = args.embeddings_dims
        self.no_of_heads = args.no_of_heads
        self.no_of_decoder_layers = args.no_of_decoder_layers
        self.latent_dim = args.latent_dim

        # MoE Configuration
        self.experts = args.experts
        self.top_experts = args.top_experts
        self.use_shared_expert = args.use_shared_expert
        self.noisy_topk = args.noisy_topk
        self.useauxFreeLoadBalancingLoss = args.useauxFreeLoadBalancingLoss
        self.aux_free_bias_update_rate = args.aux_free_bias_update_rate
        self.loss_scale = args.loss_scale

        # Training Hyperparameters
        self.epochs = args.epochs
        self.max_lr = args.max_lr
        self.weight_decay_optim = args.weight_decay_optim
        self.beta_1 = args.beta_1
        self.beta_2 = args.beta_2
        self.eps = args.eps
        self.clip = args.clip

        # Regularization
        self.dropout = args.dropout
        self.attn_dropout = args.attn_dropout

        # System Configuration
        self.device = args.device
        self.use_checkpointing = args.use_checkpointing
        self.use_liger = args.use_liger
        self.ignore_pad_token_in_loss = args.ignore_pad_token_in_loss

        # Data Configuration
        self.vocab_size = args.vocab_size
        self.base_freq = args.base_freq
        self.hf_token = args.hf_token
        self.dataset = args.dataset

        # Generation Parameters
        self.generation_max_length = args.generation_max_length
        self.generation_top_k = args.generation_top_k
        self.generation_temperature = args.generation_temperature

        # Logging and Checkpointing
        self.log_interval = args.log_interval
        self.save_interval = args.save_interval
        self.eval_interval = args.eval_interval
        self.eval_iters = args.eval_iters
        self.warmup_iters = args.warmup_iters
        self.total_iters = args.total_iters
        self.lr_decay_iters = args.lr_decay_iters
        self.wandb_project = args.wandb_project
        self.wandb_run_name = args.wandb_run_name

        # Batch Size Configuration
        self.total_batch_size = args.total_batch_size
        self.micro_batch_size = args.micro_batch_size if args.micro_batch_size else args.batch_size
        self.gradient_accumulation_steps = self.total_batch_size // (self.micro_batch_size * self.block_size)
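        # With the defaults (total_batch_size=524288 tokens, micro_batch_size=256,
        # block_size=128) this comes to 524288 // (256 * 128) = 16 accumulation steps.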

        # Calculated parameters
        self.min_lr = 0.1 * self.max_lr
        self.save_checkpoint_iter = self.save_interval
        self.eval_check = self.eval_interval

        # Distributed Training
        self.use_ddp = args.use_ddp
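

# Minimal usage sketch (illustrative only; a real training script would feed these fields
# into the model, optimizer, and data loaders):
if __name__ == "__main__":
    config = ModelArgs()  # parses command-line flags via get_args()
    print(f"vocab_size={config.vocab_size}, grad_accum_steps={config.gradient_accumulation_steps}")
    print(f"max_lr={config.max_lr}, min_lr={config.min_lr}")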