import argparse

def get_args():
    """Parse command-line arguments for SmolKimi training."""
    parser = argparse.ArgumentParser(description='SmolKimi - DeepSeek V3 Inspired Model Training')
    
    # Model Architecture
    parser.add_argument('--block_size', type=int, default=128, help='Maximum sequence length')
    parser.add_argument('--batch_size', type=int, default=256, help='Training batch size')
    parser.add_argument('--embeddings_dims', type=int, default=384, help='Model embedding dimensions')
    parser.add_argument('--no_of_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--no_of_decoder_layers', type=int, default=6, help='Number of decoder layers')
    parser.add_argument('--latent_dim', type=int, default=64, help='Latent dimension for attention')
    
    # MoE Configuration
    parser.add_argument('--experts', type=int, default=8, help='Number of MoE experts')
    parser.add_argument('--top_experts', type=int, default=2, help='Number of experts to route to (top-k)')
    parser.add_argument('--use_shared_expert', action=argparse.BooleanOptionalAction, default=True, help='Enable shared expert in MoE (use --no-use_shared_expert to disable)')
    parser.add_argument('--noisy_topk', action='store_true', default=False, help='Use noisy top-k routing')
    parser.add_argument('--useauxFreeLoadBalancingLoss', action=argparse.BooleanOptionalAction, default=True, help='Use auxiliary-loss-free load balancing (use --no-useauxFreeLoadBalancingLoss to disable)')
    parser.add_argument('--aux_free_bias_update_rate', type=float, default=0.001, help='Bias update rate for load balancing')
    parser.add_argument('--loss_scale', type=float, default=0.3, help='Loss scaling factor')
    
    # Training Hyperparameters
    parser.add_argument('--epochs', type=int, default=1, help='Number of training epochs')
    parser.add_argument('--max_lr', type=float, default=6e-4, help='Maximum learning rate')
    parser.add_argument('--weight_decay_optim', type=float, default=0.1, help='Weight decay for optimizer')
    parser.add_argument('--beta_1', type=float, default=0.9, help='Beta1 for optimizer')
    parser.add_argument('--beta_2', type=float, default=0.95, help='Beta2 for optimizer')
    parser.add_argument('--eps', type=float, default=1e-8, help='Epsilon for optimizer')
    parser.add_argument('--clip', type=float, default=1.0, help='Gradient clipping value')
    
    # Regularization
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate')
    parser.add_argument('--attn_dropout', type=float, default=0.1, help='Attention dropout rate')
    
    # System Configuration
    parser.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
    parser.add_argument('--use_checkpointing', action='store_true', default=False, help='Use gradient checkpointing')
    parser.add_argument('--use_liger', action=argparse.BooleanOptionalAction, default=True, help='Use Liger kernels for optimization (use --no-use_liger to disable)')
    parser.add_argument('--ignore_pad_token_in_loss', action=argparse.BooleanOptionalAction, default=True, help='Ignore padding tokens in loss calculation (use --no-ignore_pad_token_in_loss to disable)')
    
    # Data Configuration
    parser.add_argument('--vocab_size', type=int, default=32000 + 1, help='Vocabulary size (updated based on tokenizer)')
    parser.add_argument('--base_freq', type=int, default=100000, help='Base frequency for positional encoding')
    parser.add_argument('--hf_token', type=str, default=None, help='Hugging Face token for accessing gated models like Llama-2')
    
    # Dataset Selection
    parser.add_argument('--dataset', type=str, default='tinystories', choices=['tinystories', 'fineweb', 'tinyshakespeare'], help='Dataset to use for training')
    
    # Generation Parameters
    parser.add_argument('--generation_max_length', type=int, default=50, help='Maximum length for text generation')
    parser.add_argument('--generation_top_k', type=int, default=50, help='Top-k value for sampling during generation')
    parser.add_argument('--generation_temperature', type=float, default=1.0, help='Temperature for sampling during generation')
    
    # Logging and Checkpointing
    parser.add_argument('--log_interval', type=int, default=100, help='Steps between logging')
    parser.add_argument('--save_interval', type=int, default=2000, help='Steps between saving checkpoints')
    parser.add_argument('--eval_interval', type=int, default=400, help='Steps between evaluation')
    parser.add_argument('--eval_iters', type=int, default=400, help='Number of iterations for evaluation')
    parser.add_argument('--warmup_iters', type=int, default=400, help='Number of warmup iterations')
    parser.add_argument('--total_iters', type=int, default=10000, help='Total training iterations')
    parser.add_argument('--lr_decay_iters', type=int, default=10000, help='Learning rate decay iterations')
    parser.add_argument('--wandb_project', type=str, default='smolkimi', help='Wandb project name')
    parser.add_argument('--wandb_run_name', type=str, default=None, help='Wandb run name')
    
    # Batch Size Configuration
    parser.add_argument('--total_batch_size', type=int, default=524288, help='Total batch size in tokens per optimizer step (used to derive gradient accumulation steps)')
    parser.add_argument('--micro_batch_size', type=int, default=None, help='Micro batch size per forward pass (defaults to batch_size)')
    
    # Distributed Training
    parser.add_argument('--use_ddp', action='store_true', default=False, help='Use distributed data parallel')
    
    return parser.parse_args()
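
# Example invocation (values are illustrative, not a tuned recipe; the script
# name train.py is an assumption about how this module is consumed):
#
#   python train.py --dataset tinystories --batch_size 64 --total_iters 5000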

class ModelArgs:
    """Configuration container populated from parsed command-line arguments."""

    def __init__(self, args=None):
        if args is None:
            args = get_args()
        
        # Model Architecture
        self.block_size = args.block_size
        self.batch_size = args.batch_size
        self.embeddings_dims = args.embeddings_dims
        self.no_of_heads = args.no_of_heads
        self.no_of_decoder_layers = args.no_of_decoder_layers
        self.latent_dim = args.latent_dim
        
        # MoE Configuration
        self.experts = args.experts
        self.top_experts = args.top_experts
        self.use_shared_expert = args.use_shared_expert
        self.noisy_topk = args.noisy_topk
        self.useauxFreeLoadBalancingLoss = args.useauxFreeLoadBalancingLoss
        self.aux_free_bias_update_rate = args.aux_free_bias_update_rate
        self.loss_scale = args.loss_scale
        
        # Training Hyperparameters
        self.epochs = args.epochs
        self.max_lr = args.max_lr
        self.weight_decay_optim = args.weight_decay_optim
        self.beta_1 = args.beta_1
        self.beta_2 = args.beta_2
        self.eps = args.eps
        self.clip = args.clip
        
        # Regularization
        self.dropout = args.dropout
        self.attn_dropout = args.attn_dropout
        
        # System Configuration
        self.device = args.device
        self.use_checkpointing = args.use_checkpointing
        self.use_liger = args.use_liger
        self.ignore_pad_token_in_loss = args.ignore_pad_token_in_loss
        
        # Data Configuration
        self.vocab_size = args.vocab_size
        self.base_freq = args.base_freq
        self.hf_token = args.hf_token
        self.dataset = args.dataset
        
        # Generation Parameters
        self.generation_max_length = args.generation_max_length
        self.generation_top_k = args.generation_top_k
        self.generation_temperature = args.generation_temperature
        
        # Logging and Checkpointing
        self.log_interval = args.log_interval
        self.save_interval = args.save_interval
        self.eval_interval = args.eval_interval
        self.eval_iters = args.eval_iters
        self.warmup_iters = args.warmup_iters
        self.total_iters = args.total_iters
        self.lr_decay_iters = args.lr_decay_iters
        self.wandb_project = args.wandb_project
        self.wandb_run_name = args.wandb_run_name
        
        # Batch Size Configuration
        self.total_batch_size = args.total_batch_size
        self.micro_batch_size = args.micro_batch_size if args.micro_batch_size is not None else args.batch_size
        # Each optimizer step should process ~total_batch_size tokens:
        # micro_batch_size sequences * block_size tokens per micro-step.
        self.gradient_accumulation_steps = self.total_batch_size // (self.micro_batch_size * self.block_size)
        
        # Calculated parameters
        self.min_lr = 0.1 * self.max_lr
        self.save_checkpoint_iter = self.save_interval
        self.eval_check = self.eval_interval
        
        # Distributed Training
        self.use_ddp = args.use_ddp
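
if __name__ == '__main__':
    # Minimal sanity-check sketch: running this module directly just parses the
    # CLI flags and prints the resulting configuration (assumption: training
    # itself lives in a separate entry point that imports ModelArgs).
    config = ModelArgs()
    for name, value in sorted(vars(config).items()):
        print(f'{name}: {value}')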