Girinath11 committed on
Commit 16896b6 · verified · 1 Parent(s): bea7101

Update train.py

Files changed (1)
  1. train.py +441 -181
train.py CHANGED
@@ -10,197 +10,389 @@ import argparse
10
  import time
11
  import math
12
  import glob
13
- from typing import Dict, List
14
  from tqdm import tqdm
15
  import numpy as np
16
  import gc
 
17
  from collections import defaultdict
18
  import multiprocessing
 
19
  # Import custom modules
20
  try:
21
  from model_slm import MixtureOfRecursions, count_parameters, TextGenerator
22
  from custom_tokenizer import TechnicalTokenizer
23
  except ImportError as e:
24
- print(f"Import error: {e}")
25
- exit(1)
26
  class FastTechnicalTextDataset(Dataset):
27
- """Ultra-fast dataset with aggressive optimizations for 4-5hr training"""
28
- def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 128, max_examples: int = 50000):
29
  self.tokenizer = tokenizer
30
  self.max_length = max_length
31
  self.pad_token_id = tokenizer.vocab.get('<pad>', 0)
32
- self.max_examples = max_examples
33
- print(f"FAST DATASET LOADING")
34
- print(f"Data file: {data_file}")
35
- print(f"Max sequence length: {max_length}")
36
- print(f"Max examples: {max_examples}")
37
- start_time = time.time()
38
  self.examples = []
39
- self._fast_load_data(data_file)
40
- load_time = time.time() - start_time
41
- print(f" Loaded {len(self.examples)} examples in {load_time:.1f}s")
 
42
  self._tensorize_data()
43
  gc.collect()
44
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
45
- def _fast_load_data(self, data_file: str):
46
- print("🔍 Fast reading file...")
 
47
  with open(data_file, 'r', encoding='utf-8') as f:
48
- lines = f.readlines()
49
- print(f"File has {len(lines)} lines")
 
50
  good_examples = []
51
- seen_hashes = set()
 
52
  for line in lines[:self.max_examples * 3]:
53
  line = line.strip()
54
- if (50 <= len(line) <= 400 and
 
55
  line.count(' ') >= 8 and
56
  not line.lower().startswith(('http', 'www', 'ftp')) and
57
- line.count('.') <= len(line) * 0.1):
 
58
  line_hash = hash(line[:100])
59
  if line_hash not in seen_hashes:
60
  seen_hashes.add(line_hash)
61
  good_examples.append(line)
62
  if len(good_examples) >= self.max_examples:
63
- break
64
- print(f"After fast filtering: {len(good_examples)} quality examples")
65
  batch_size = 1000
66
  for i in range(0, len(good_examples), batch_size):
67
- batch = good_examples[i:i+batch_size]
68
  for line in batch:
69
  try:
70
  if not line.endswith('<|endoftext|>'):
71
- line += ' <|endoftext|>'
72
  tokens = self.tokenizer.encode_ids(line, add_special_tokens=True)
73
  if 30 <= len(tokens) <= self.max_length:
74
  if len(tokens) < self.max_length:
75
- tokens = tokens + [self.pad_token_id] * (self.max_length - len(tokens))
76
  self.examples.append(tokens)
77
- except:
 
78
  continue
79
  if i % 5000 == 0:
80
- print(f"Processed {len(self.examples)} examples...")
81
- print(f"Final dataset: {len(self.examples)} examples")
82
- def _tensorize_data(self):
83
- print("Pre-tensorizing data for maximum speed...")
84
- seq_len = self.max_length - 1
85
  tensorized_examples = []
 
86
  for tokens in self.examples:
87
- if len(tokens) < self.max_length:
88
- continue
89
  input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
90
- targets = torch.tensor(tokens[1:], dtype=torch.long)
91
  original_len = next((i for i, x in enumerate(tokens) if x == self.pad_token_id), self.max_length)
92
  mask_len = min(original_len, seq_len)
93
  attention_mask = torch.zeros(seq_len, dtype=torch.long)
94
- attention_mask[:mask_len] = 1
95
  tensorized_examples.append({
96
  'input_ids': input_ids,
97
  'targets': targets,
98
  'attention_mask': attention_mask
99
  })
 
100
  self.examples = tensorized_examples
101
- print("All data pre-tensorized")
102
- def __len__(self):
103
- return len(self.examples)
104
- def __getitem__(self, idx):
105
  return self.examples[idx]
 
106
  class FastCosineScheduler:
107
- def __init__(self, optimizer, total_steps: int, warmup_ratio: float = 0.05):
108
  self.optimizer = optimizer
109
  self.total_steps = total_steps
110
  self.warmup_steps = int(total_steps * warmup_ratio)
111
  self.base_lr = optimizer.param_groups[0]['lr']
112
- self.step_count = 0
113
- def step(self):
114
  self.step_count += 1
115
  if self.step_count <= self.warmup_steps:
116
  lr = self.base_lr * self.step_count / self.warmup_steps
117
  else:
118
  progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
119
  lr = self.base_lr * 0.5 * (1 + math.cos(math.pi * progress))
 
120
  for param_group in self.optimizer.param_groups:
121
  param_group['lr'] = lr
122
  return lr
 
123
  class UltraFastTrainer:
124
- def __init__(self, model, tokenizer, train_dataset, val_dataset=None, config=None):
125
  self.model = model
126
  self.tokenizer = tokenizer
127
  self.train_dataset = train_dataset
128
  self.val_dataset = val_dataset
129
- self.config = config or {}
130
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
131
- self.model.to(self.device)
132
  self._fast_init_weights()
133
- self._setup_fast_optimizer()
134
- epochs = self.config.get('epochs', 15)
135
- batch_size = self.config.get('batch_size', 16)
 
136
  total_steps = len(train_dataset) // batch_size * epochs
137
  self.scheduler = FastCosineScheduler(self.optimizer, total_steps)
138
- self.scaler = GradScaler()
139
  self.global_step = 0
140
  self.best_loss = float('inf')
141
- self.grad_accum_steps = self.config.get('gradient_accumulation_steps', 1)
142
- self.eval_every = self.config.get('eval_every', 500)
143
- def _fast_init_weights(self):
144
- def fast_init(module):
145
  if isinstance(module, nn.Linear):
146
  nn.init.normal_(module.weight, std=0.02)
147
  if module.bias is not None:
148
  nn.init.zeros_(module.bias)
149
  elif isinstance(module, nn.Embedding):
150
  nn.init.normal_(module.weight, std=0.02)
151
- self.model.apply(fast_init)
152
- def _setup_fast_optimizer(self):
153
- lr = self.config.get('learning_rate', 5e-4)
154
  params = [p for p in self.model.parameters() if p.requires_grad]
155
- self.optimizer = optim.AdamW(params, lr=lr, betas=(0.9, 0.99), weight_decay=0.01, eps=1e-6)
156
- def compute_fast_loss(self, logits, targets, mask):
157
  logits_flat = logits.view(-1, logits.size(-1))
158
  targets_flat = targets.view(-1)
159
  mask_flat = mask.view(-1).bool()
 
160
  if not mask_flat.any():
161
  return torch.tensor(0.0, device=logits.device, requires_grad=True)
162
- loss = F.cross_entropy(logits_flat[mask_flat], targets_flat[mask_flat])
163
- return loss
 
164
  def train_epoch_fast(self, epoch: int, dataloader: DataLoader) -> Dict[str, float]:
165
  self.model.train()
166
  total_loss = 0
167
  num_batches = 0
168
- start_time = time.time()
 
169
  progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False, miniters=50)
170
  for batch_idx, batch in enumerate(progress_bar):
171
  input_ids = batch['input_ids'].to(self.device, non_blocking=True)
172
  targets = batch['targets'].to(self.device, non_blocking=True)
173
- mask = batch['attention_mask'].to(self.device, non_blocking=True)
174
- with autocast():
 
175
  logits, comp_loss = self.model(input_ids, mask)
176
  lm_loss = self.compute_fast_loss(logits, targets, mask)
177
  total_loss_step = lm_loss + 0.0001 * comp_loss
178
  if self.grad_accum_steps > 1:
179
- total_loss_step = total_loss_step / self.grad_accum_steps
180
- self.scaler.scale(total_loss_step).backward()
181
- if (batch_idx + 1) % self.grad_accum_steps == 0:
182
- self.scaler.unscale_(self.optimizer)
183
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
184
- self.scaler.step(self.optimizer)
185
- self.scaler.update()
186
- self.optimizer.zero_grad(set_to_none=True)
187
- self.scheduler.step()
188
- self.global_step += 1
189
  total_loss += lm_loss.item()
190
  num_batches += 1
 
191
  if batch_idx % 100 == 0:
192
  current_loss = total_loss / num_batches
193
  progress_bar.set_postfix({'loss': f"{current_loss:.3f}", 'ppl': f"{math.exp(min(current_loss, 10)):.1f}"})
194
- if batch_idx % 200 == 0 and batch_idx > 0:
195
- torch.cuda.empty_cache()
196
- epoch_time = time.time() - start_time
 
197
  avg_loss = total_loss / max(num_batches, 1)
198
- return {'loss': avg_loss, 'perplexity': math.exp(min(avg_loss, 10)), 'epoch_time_min': epoch_time / 60}
199
  def validate_fast(self, dataloader: DataLoader) -> Dict[str, float]:
200
  self.model.eval()
201
  total_loss = 0
202
  num_batches = 0
203
- max_val_batches = min(100, len(dataloader))
 
204
  with torch.no_grad():
205
  for batch_idx, batch in enumerate(dataloader):
206
  if batch_idx >= max_val_batches:
@@ -208,16 +400,32 @@ class UltraFastTrainer:
208
  input_ids = batch['input_ids'].to(self.device, non_blocking=True)
209
  targets = batch['targets'].to(self.device, non_blocking=True)
210
  mask = batch['attention_mask'].to(self.device, non_blocking=True)
211
- with autocast():
 
212
  logits, _ = self.model(input_ids, mask)
213
  loss = self.compute_fast_loss(logits, targets, mask)
 
214
  total_loss += loss.item()
215
- num_batches += 1
 
216
  avg_loss = total_loss / max(num_batches, 1)
217
- return {'loss': avg_loss, 'perplexity': math.exp(min(avg_loss, 10))}
218
- def save_checkpoint_fast(self, epoch: int, metrics: Dict, save_dir: str = "checkpoints"):
219
  os.makedirs(save_dir, exist_ok=True)
220
  val_loss = metrics.get('val_loss', metrics.get('loss', float('inf')))
 
221
  if val_loss < self.best_loss:
222
  self.best_loss = val_loss
223
  checkpoint = {
@@ -225,159 +433,211 @@ class UltraFastTrainer:
225
  'model_state_dict': self.model.state_dict(),
226
  'optimizer_state_dict': self.optimizer.state_dict(),
227
  'metrics': metrics,
228
- 'scaler_state_dict': self.scaler.state_dict()
229
  }
230
  best_path = os.path.join(save_dir, "best_model.pt")
231
  torch.save(checkpoint, best_path)
232
- print(f"New best! Loss: {val_loss:.4f}")
233
  return best_path
234
- return None
235
- def train_ultra_fast(self, num_epochs: int = 15, batch_size: int = 16):
236
- print(f"\n ULTRA-FAST TRAINING")
237
- print(f" Target: Loss < 2.0, PPL < 12")
238
- print(f" Time target: 4-5 hours")
239
- print(f" Epochs: {num_epochs}")
240
- print(f" Batch size: {batch_size}")
241
- print("-" * 60)
242
  train_loader = DataLoader(
243
  self.train_dataset,
244
  batch_size=batch_size,
245
  shuffle=True,
246
- num_workers=4,
247
- pin_memory=True,
248
  persistent_workers=True,
249
  drop_last=True
250
- )
 
251
  val_loader = None
252
  if self.val_dataset:
253
  val_loader = DataLoader(
254
  self.val_dataset,
255
  batch_size=batch_size * 2,
256
  shuffle=False,
257
- num_workers=2,
258
- pin_memory=True
259
- )
 
260
  total_start_time = time.time()
261
- history = []
 
262
  for epoch in range(1, num_epochs + 1):
263
- epoch_start = time.time()
264
- print(f"\n EPOCH {epoch}/{num_epochs}")
265
- train_metrics = self.train_epoch_fast(epoch, train_loader)
266
  val_metrics = {}
267
  if val_loader and (epoch % 2 == 0 or epoch == num_epochs):
268
- val_metrics = self.validate_fast(val_loader)
269
- epoch_time = time.time() - epoch_start
 
270
  epoch_info = {
271
  'epoch': epoch,
272
  'train_loss': train_metrics['loss'],
273
  'train_ppl': train_metrics['perplexity'],
274
- 'epoch_time_min': epoch_time / 60
275
  }
276
  if val_metrics:
277
  epoch_info.update({'val_loss': val_metrics['loss'], 'val_ppl': val_metrics['perplexity']})
278
- history.append(epoch_info)
279
  elapsed_hours = (time.time() - total_start_time) / 3600
280
- remaining_hours = elapsed_hours * (num_epochs - epoch) / epoch
281
- print(f"\n EPOCH {epoch} RESULTS:")
282
- print(f" Epoch time: {epoch_time/60:.1f} min")
283
- print(f" Total elapsed: {elapsed_hours:.1f}h")
284
- print(f" Est. remaining: {remaining_hours:.1f}h")
285
- print(f" Train Loss: {train_metrics['loss']:.4f}")
286
- print(f" Train PPL: {train_metrics['perplexity']:.1f}")
 
287
  if val_metrics:
288
- print(f" Val Loss: {val_metrics['loss']:.4f}")
289
- print(f" Val PPL: {val_metrics['perplexity']:.1f}")
 
290
  current_loss = val_metrics.get('loss', train_metrics['loss'])
291
  current_ppl = val_metrics.get('perplexity', train_metrics['perplexity'])
292
  if current_loss < 2.0 and current_ppl < 12:
293
- print(f" TARGETS ACHIEVED!")
294
- print(f" Loss: {current_loss:.4f} < 2.0")
295
- print(f" PPL: {current_ppl:.1f} < 12")
296
  combined_metrics = {**train_metrics}
297
  if val_metrics:
298
  combined_metrics.update({f"val_{k}": v for k, v in val_metrics.items()})
299
- self.save_checkpoint_fast(epoch, combined_metrics)
300
- torch.cuda.empty_cache()
301
- gc.collect()
302
  if current_loss < 1.8 and current_ppl < 10:
303
- print(f"EARLY STOPPING - Excellent performance achieved!")
304
- break
305
- total_time = time.time() - total_start_time
306
- print(f"\n TRAINING COMPLETED!")
307
- print(f"Total time: {total_time/3600:.1f} hours")
308
- print(f" Best loss: {self.best_loss:.4f}")
309
  return history
310
- def run_ultra_fast_training():
311
- parser = argparse.ArgumentParser(description="Ultra-Fast Training for 4-5 Hours")
312
- parser.add_argument("--train_file", default=None)
313
- parser.add_argument("--val_file", default=None)
314
- parser.add_argument("--tokenizer_dir", default="tokenizer")
315
- parser.add_argument("--max_examples", type=int, default=50000)
316
- parser.add_argument("--d_model", type=int, default=384)
317
- parser.add_argument("--n_layers", type=int, default=6)
318
- parser.add_argument("--n_heads", type=int, default=6)
319
- parser.add_argument("--max_seq_len", type=int, default=128)
320
- parser.add_argument("--epochs", type=int, default=15)
321
- parser.add_argument("--batch_size", type=int, default=16)
322
- parser.add_argument("--learning_rate", type=float, default=5e-4)
323
- parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
324
- parser.add_argument("--eval_every", type=int, default=500)
325
- args = parser.parse_args()
326
  torch.manual_seed(42)
327
- np.random.seed(42)
328
- print("Training My Model")
329
- print("-" * 50)
 
330
  if args.train_file is None:
331
  patterns = ["*train*.txt", "*_train.txt"]
332
  files = []
333
  for pattern in patterns:
334
  files.extend(glob.glob(pattern))
335
- files.extend(glob.glob(f"split_data/{pattern}"))
336
- files.extend(glob.glob(f"data/{pattern}"))
337
  if files:
338
  args.train_file = files[0]
339
- print(f"Found: {args.train_file}")
340
  else:
341
- print(" No training files found!")
342
- return 1
343
- tokenizer = TechnicalTokenizer()
344
  try:
 
345
  tokenizer.load(args.tokenizer_dir)
346
- print(f"Tokenizer loaded. Vocab size: {tokenizer.get_vocab_size()}")
347
  except Exception as e:
348
- print(f" Tokenizer error: {e}")
349
- return 1
350
- print(" Creating ultra-fast dataset...")
351
- train_dataset = FastTechnicalTextDataset(
352
- args.train_file, tokenizer, args.max_seq_len, args.max_examples
353
- )
354
  val_dataset = None
355
  if args.val_file and os.path.exists(args.val_file):
356
- val_dataset = FastTechnicalTextDataset(
357
- args.val_file, tokenizer, args.max_seq_len, max_examples=5000
358
- )
359
- model = MixtureOfRecursions(
360
- vocab_size=tokenizer.get_vocab_size(),
361
- d_model=args.d_model,
362
- n_layers=args.n_layers,
363
- n_heads=args.n_heads,
364
- max_seq_len=args.max_seq_len - 1, # Pass the actual sequence length to the model
365
- padding_idx=tokenizer.vocab.get('<pad>', 0)
366
- )
367
  config = {
368
  'learning_rate': args.learning_rate,
369
  'gradient_accumulation_steps': args.gradient_accumulation_steps,
370
  'eval_every': args.eval_every,
371
  'batch_size': args.batch_size,
372
  'epochs': args.epochs
373
- }
374
- trainer = UltraFastTrainer(model, tokenizer, train_dataset, val_dataset, config)
375
- print(f"\n START TRAINING")
376
- results = trainer.train_ultra_fast(args.epochs, args.batch_size)
377
- with open('ultra_fast_results.json', 'w') as f:
378
- json.dump(results, f, indent=2)
379
- print("\n Training Completed!")
380
- print(" Results saved to: ultra_fast_results.json")
381
- return 0
382
  if __name__ == "__main__":
383
  exit(run_ultra_fast_training())
 
10
  import time
11
  import math
12
  import glob
13
+ from typing import Dict, List, Optional
14
  from tqdm import tqdm
15
  import numpy as np
16
  import gc
17
+ import logging
18
  from collections import defaultdict
19
  import multiprocessing
20
+
21
  # Import custom modules
22
  try:
23
  from model_slm import MixtureOfRecursions, count_parameters, TextGenerator
24
  from custom_tokenizer import TechnicalTokenizer
25
  except ImportError as e:
26
+ raise ImportError(f"Failed to import custom modules: {e}")
27
+
28
+ # Constants for configuration
29
+ DEFAULT_MAX_LENGTH = 128
30
+ DEFAULT_MAX_EXAMPLES = 50000
31
+ DEFAULT_D_MODEL = 384
32
+ DEFAULT_N_LAYERS = 6
33
+ DEFAULT_N_HEADS = 6
34
+ DEFAULT_EPOCHS = 15
35
+ DEFAULT_BATCH_SIZE = 16
36
+ DEFAULT_LEARNING_RATE = 5e-4
37
+ DEFAULT_GRAD_ACCUM_STEPS = 1
38
+ DEFAULT_EVAL_EVERY = 500
39
+ DEFAULT_WARMUP_RATIO = 0.05
40
+ DEFAULT_CHECKPOINT_DIR = "checkpoints"
41
+ DEFAULT_LOG_LEVEL = "INFO"
42
+
43
+ # Set up logging
44
+ logging.basicConfig(
45
+ level=DEFAULT_LOG_LEVEL,
46
+ format="%(asctime)s [%(levelname)s] %(message)s",
47
+ handlers=[
48
+ logging.StreamHandler(),
49
+ logging.FileHandler("training.log")
50
+ ]
51
+ )
52
+ logger = logging.getLogger(__name__)
53
+
54
  class FastTechnicalTextDataset(Dataset):
55
+ """Optimized dataset for fast loading and processing of technical text."""
56
+
57
+ def __init__(
58
+ self,
59
+ data_file: str,
60
+ tokenizer: TechnicalTokenizer,
61
+ max_length: int = DEFAULT_MAX_LENGTH,
62
+ max_examples: int = DEFAULT_MAX_EXAMPLES
63
+ ):
64
+ """
65
+ Initialize the dataset with optimized loading.
66
+
67
+ Args:
68
+ data_file (str): Path to the training data file.
69
+ tokenizer (TechnicalTokenizer): Tokenizer for encoding text.
70
+ max_length (int): Maximum sequence length.
71
+ max_examples (int): Maximum number of examples to load.
72
+
73
+ Raises:
74
+ FileNotFoundError: If the data file does not exist.
75
+ ValueError: If max_length or max_examples is invalid.
76
+ """
77
+ if not os.path.exists(data_file):
78
+ raise FileNotFoundError(f"Data file not found: {data_file}")
79
+ if max_length <= 0 or max_examples <= 0:
80
+ raise ValueError("max_length and max_examples must be positive")
81
+
82
  self.tokenizer = tokenizer
83
  self.max_length = max_length
84
  self.pad_token_id = tokenizer.vocab.get('<pad>', 0)
85
+ self.max_examples = max_examples
86
  self.examples = []
87
+
88
+ logger.info(f"Loading dataset from {data_file} with max_length={max_length}, max_examples={max_examples}")
89
+ start_time = time.time()
90
+ self._fast_load_data(data_file)
91
  self._tensorize_data()
92
+ logger.info(f"Loaded {len(self.examples)} examples in {time.time() - start_time:.1f}s")
93
+
94
+ if torch.cuda.is_available():
95
+ torch.cuda.empty_cache()
96
  gc.collect()
97
+
98
+ def _fast_load_data(self, data_file: str) -> None:
99
+ """Load and filter data efficiently."""
100
+ logger.info("Reading and filtering data...")
101
  with open(data_file, 'r', encoding='utf-8') as f:
102
+ lines = f.readlines()
103
+
104
+ logger.info(f"File contains {len(lines)} lines")
105
  good_examples = []
106
+ seen_hashes = set()
107
+
108
  for line in lines[:self.max_examples * 3]:
109
  line = line.strip()
110
+ if (
111
+ 50 <= len(line) <= 400 and
112
  line.count(' ') >= 8 and
113
  not line.lower().startswith(('http', 'www', 'ftp')) and
114
+ line.count('.') <= len(line) * 0.1
115
+ ):
116
  line_hash = hash(line[:100])
117
  if line_hash not in seen_hashes:
118
  seen_hashes.add(line_hash)
119
  good_examples.append(line)
120
  if len(good_examples) >= self.max_examples:
121
+ break
122
+
123
+ logger.info(f"Filtered to {len(good_examples)} quality examples")
124
+
125
  batch_size = 1000
126
  for i in range(0, len(good_examples), batch_size):
127
+ batch = good_examples[i:i + batch_size]
128
  for line in batch:
129
  try:
130
  if not line.endswith('<|endoftext|>'):
131
+ line += ' <|endoftext|>'
132
  tokens = self.tokenizer.encode_ids(line, add_special_tokens=True)
133
  if 30 <= len(tokens) <= self.max_length:
134
  if len(tokens) < self.max_length:
135
+ tokens.extend([self.pad_token_id] * (self.max_length - len(tokens)))
136
  self.examples.append(tokens)
137
+ except Exception as e:
138
+ logger.warning(f"Failed to process line: {e}")
139
  continue
140
  if i % 5000 == 0:
141
+ logger.info(f"Processed {len(self.examples)} examples...")
142
+
143
+ logger.info(f"Final dataset size: {len(self.examples)} examples")
144
+
145
+ def _tensorize_data(self) -> None:
146
+ """Pre-tensorize data for faster training."""
147
+ logger.info("Pre-tensorizing data...")
148
+ seq_len = self.max_length - 1
149
  tensorized_examples = []
150
+
151
  for tokens in self.examples:
152
+ if len(tokens) != self.max_length:
153
+ continue
154
  input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
155
+ targets = torch.tensor(tokens[1:], dtype=torch.long)
156
  original_len = next((i for i, x in enumerate(tokens) if x == self.pad_token_id), self.max_length)
157
  mask_len = min(original_len, seq_len)
158
  attention_mask = torch.zeros(seq_len, dtype=torch.long)
159
+ attention_mask[:mask_len] = 1
160
  tensorized_examples.append({
161
  'input_ids': input_ids,
162
  'targets': targets,
163
  'attention_mask': attention_mask
164
  })
165
+
166
  self.examples = tensorized_examples
167
+ logger.info("Data pre-tensorized successfully")
168
+
169
+ def __len__(self) -> int:
170
+ """Return the number of examples in the dataset."""
171
+ return len(self.examples)
172
+
173
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
174
+ """Return a single example from the dataset."""
175
  return self.examples[idx]
176
+
177
  class FastCosineScheduler:
178
+ """Cosine learning rate scheduler with warmup."""
179
+
180
+ def __init__(self, optimizer: optim.Optimizer, total_steps: int, warmup_ratio: float = DEFAULT_WARMUP_RATIO):
181
+ """
182
+ Initialize the cosine scheduler.
183
+
184
+ Args:
185
+ optimizer (optim.Optimizer): Optimizer to schedule.
186
+ total_steps (int): Total training steps.
187
+ warmup_ratio (float): Ratio of steps for warmup phase.
188
+
189
+ Raises:
190
+ ValueError: If total_steps or warmup_ratio is invalid.
191
+ """
192
+ if total_steps <= 0 or not 0 <= warmup_ratio <= 1:
193
+ raise ValueError("total_steps must be positive and warmup_ratio must be in [0, 1]")
194
+
195
  self.optimizer = optimizer
196
  self.total_steps = total_steps
197
  self.warmup_steps = int(total_steps * warmup_ratio)
198
  self.base_lr = optimizer.param_groups[0]['lr']
199
+ self.step_count = 0
200
+
201
+ def step(self) -> float:
202
+ """
203
+ Update the learning rate.
204
+
205
+ Returns:
206
+ float: Current learning rate.
207
+ """
208
  self.step_count += 1
209
  if self.step_count <= self.warmup_steps:
210
  lr = self.base_lr * self.step_count / self.warmup_steps
211
  else:
212
  progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
213
  lr = self.base_lr * 0.5 * (1 + math.cos(math.pi * progress))
214
+
215
  for param_group in self.optimizer.param_groups:
216
  param_group['lr'] = lr
217
  return lr
218
+
219
  class UltraFastTrainer:
220
+ """Trainer optimized for fast training of transformer models."""
221
+
222
+ def __init__(
223
+ self,
224
+ model: nn.Module,
225
+ tokenizer: TechnicalTokenizer,
226
+ train_dataset: FastTechnicalTextDataset,
227
+ val_dataset: Optional[FastTechnicalTextDataset] = None,
228
+ config: Optional[Dict] = None
229
+ ):
230
+ """
231
+ Initialize the trainer.
232
+
233
+ Args:
234
+ model (nn.Module): The transformer model to train.
235
+ tokenizer (TechnicalTokenizer): Tokenizer for encoding/decoding.
236
+ train_dataset (FastTechnicalTextDataset): Training dataset.
237
+ val_dataset (Optional[FastTechnicalTextDataset]): Validation dataset.
238
+ config (Optional[Dict]): Training configuration.
239
+
240
+ Raises:
241
+ ValueError: If config contains invalid parameters.
242
+ """
243
  self.model = model
244
  self.tokenizer = tokenizer
245
  self.train_dataset = train_dataset
246
  self.val_dataset = val_dataset
247
+ self.config = config or {}
248
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
249
+ self.model.to(self.device)
250
+
251
+ self._validate_config()
252
  self._fast_init_weights()
253
+ self._setup_fast_optimizer()
254
+
255
+ epochs = self.config.get('epochs', DEFAULT_EPOCHS)
256
+ batch_size = self.config.get('batch_size', DEFAULT_BATCH_SIZE)
257
  total_steps = len(train_dataset) // batch_size * epochs
258
  self.scheduler = FastCosineScheduler(self.optimizer, total_steps)
259
+ self.scaler = GradScaler() if self.device.type == 'cuda' else None
260
  self.global_step = 0
261
  self.best_loss = float('inf')
262
+ self.grad_accum_steps = self.config.get('gradient_accumulation_steps', DEFAULT_GRAD_ACCUM_STEPS)
263
+ self.eval_every = self.config.get('eval_every', DEFAULT_EVAL_EVERY)
264
+
265
+ def _validate_config(self) -> None:
266
+ """Validate training configuration."""
267
+ if self.config.get('batch_size', DEFAULT_BATCH_SIZE) <= 0:
268
+ raise ValueError("batch_size must be positive")
269
+ if self.config.get('epochs', DEFAULT_EPOCHS) <= 0:
270
+ raise ValueError("epochs must be positive")
271
+ if self.config.get('learning_rate', DEFAULT_LEARNING_RATE) <= 0:
272
+ raise ValueError("learning_rate must be positive")
273
+ if self.config.get('gradient_accumulation_steps', DEFAULT_GRAD_ACCUM_STEPS) <= 0:
274
+ raise ValueError("gradient_accumulation_steps must be positive")
275
+
276
+ def _fast_init_weights(self) -> None:
277
+ """Initialize model weights."""
278
+ def fast_init(module: nn.Module) -> None:
279
  if isinstance(module, nn.Linear):
280
  nn.init.normal_(module.weight, std=0.02)
281
  if module.bias is not None:
282
  nn.init.zeros_(module.bias)
283
  elif isinstance(module, nn.Embedding):
284
  nn.init.normal_(module.weight, std=0.02)
285
+ self.model.apply(fast_init)
286
+ logger.info("Model weights initialized")
287
+
288
+ def _setup_fast_optimizer(self) -> None:
289
+ """Set up AdamW optimizer."""
290
+ lr = self.config.get('learning_rate', DEFAULT_LEARNING_RATE)
291
  params = [p for p in self.model.parameters() if p.requires_grad]
292
+ self.optimizer = optim.AdamW(params, lr=lr, betas=(0.9, 0.99), weight_decay=0.01, eps=1e-6)
293
+ logger.info(f"Optimizer initialized with learning rate: {lr}")
294
+
295
+ def compute_fast_loss(self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
296
+ """
297
+ Compute masked cross-entropy loss.
298
+
299
+ Args:
300
+ logits (torch.Tensor): Model output logits of shape (batch_size, seq_len, vocab_size).
301
+ targets (torch.Tensor): Target token IDs of shape (batch_size, seq_len).
302
+ mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).
303
+
304
+ Returns:
305
+ torch.Tensor: Computed loss.
306
+ """
307
  logits_flat = logits.view(-1, logits.size(-1))
308
  targets_flat = targets.view(-1)
309
  mask_flat = mask.view(-1).bool()
310
+
311
  if not mask_flat.any():
312
  return torch.tensor(0.0, device=logits.device, requires_grad=True)
313
+
314
+ return F.cross_entropy(logits_flat[mask_flat], targets_flat[mask_flat])
315
+
316
  def train_epoch_fast(self, epoch: int, dataloader: DataLoader) -> Dict[str, float]:
317
+ """
318
+ Train for one epoch.
319
+
320
+ Args:
321
+ epoch (int): Current epoch number.
322
+ dataloader (DataLoader): Training data loader.
323
+
324
+ Returns:
325
+ Dict[str, float]: Training metrics (loss, perplexity, epoch_time_min).
326
+ """
327
  self.model.train()
328
  total_loss = 0
329
  num_batches = 0
330
+ start_time = time.time()
331
+
332
  progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False, miniters=50)
333
  for batch_idx, batch in enumerate(progress_bar):
334
  input_ids = batch['input_ids'].to(self.device, non_blocking=True)
335
  targets = batch['targets'].to(self.device, non_blocking=True)
336
+ mask = batch['attention_mask'].to(self.device, non_blocking=True)
337
+
338
+ with autocast(enabled=self.device.type == 'cuda'):
339
  logits, comp_loss = self.model(input_ids, mask)
340
  lm_loss = self.compute_fast_loss(logits, targets, mask)
341
  total_loss_step = lm_loss + 0.0001 * comp_loss
342
  if self.grad_accum_steps > 1:
343
+ total_loss_step = total_loss_step / self.grad_accum_steps
344
+
345
+ if self.scaler:
346
+ self.scaler.scale(total_loss_step).backward()
347
+ if (batch_idx + 1) % self.grad_accum_steps == 0:
348
+ self.scaler.unscale_(self.optimizer)
349
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
350
+ self.scaler.step(self.optimizer)
351
+ self.scaler.update()
352
+ self.optimizer.zero_grad(set_to_none=True)
353
+ self.scheduler.step()
354
+ self.global_step += 1
355
+ else:
356
+ total_loss_step.backward()
357
+ if (batch_idx + 1) % self.grad_accum_steps == 0:
358
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
359
+ self.optimizer.step()
360
+ self.optimizer.zero_grad(set_to_none=True)
361
+ self.scheduler.step()
362
+ self.global_step += 1
363
+
364
  total_loss += lm_loss.item()
365
  num_batches += 1
366
+
367
  if batch_idx % 100 == 0:
368
  current_loss = total_loss / num_batches
369
  progress_bar.set_postfix({'loss': f"{current_loss:.3f}", 'ppl': f"{math.exp(min(current_loss, 10)):.1f}"})
370
+
371
+ if batch_idx % 200 == 0 and batch_idx > 0 and self.device.type == 'cuda':
372
+ torch.cuda.empty_cache()
373
+
374
  avg_loss = total_loss / max(num_batches, 1)
375
+ return {
376
+ 'loss': avg_loss,
377
+ 'perplexity': math.exp(min(avg_loss, 10)),
378
+ 'epoch_time_min': (time.time() - start_time) / 60
379
+ }
380
+
381
  def validate_fast(self, dataloader: DataLoader) -> Dict[str, float]:
382
+ """
383
+ Validate the model on the validation dataset.
384
+
385
+ Args:
386
+ dataloader (DataLoader): Validation data loader.
387
+
388
+ Returns:
389
+ Dict[str, float]: Validation metrics (loss, perplexity).
390
+ """
391
  self.model.eval()
392
  total_loss = 0
393
  num_batches = 0
394
+ max_val_batches = min(100, len(dataloader))
395
+
396
  with torch.no_grad():
397
  for batch_idx, batch in enumerate(dataloader):
398
  if batch_idx >= max_val_batches:
 
400
  input_ids = batch['input_ids'].to(self.device, non_blocking=True)
401
  targets = batch['targets'].to(self.device, non_blocking=True)
402
  mask = batch['attention_mask'].to(self.device, non_blocking=True)
403
+
404
+ with autocast(enabled=self.device.type == 'cuda'):
405
  logits, _ = self.model(input_ids, mask)
406
  loss = self.compute_fast_loss(logits, targets, mask)
407
+
408
  total_loss += loss.item()
409
+ num_batches += 1
410
+
411
  avg_loss = total_loss / max(num_batches, 1)
412
+ return {'loss': avg_loss, 'perplexity': math.exp(min(avg_loss, 10))}
413
+
414
+ def save_checkpoint_fast(self, epoch: int, metrics: Dict, save_dir: str = DEFAULT_CHECKPOINT_DIR) -> Optional[str]:
415
+ """
416
+ Save a checkpoint if the loss improves.
417
+
418
+ Args:
419
+ epoch (int): Current epoch number.
420
+ metrics (Dict): Training and validation metrics.
421
+ save_dir (str): Directory to save checkpoints.
422
+
423
+ Returns:
424
+ Optional[str]: Path to the saved checkpoint or None.
425
+ """
426
  os.makedirs(save_dir, exist_ok=True)
427
  val_loss = metrics.get('val_loss', metrics.get('loss', float('inf')))
428
+
429
  if val_loss < self.best_loss:
430
  self.best_loss = val_loss
431
  checkpoint = {
 
433
  'model_state_dict': self.model.state_dict(),
434
  'optimizer_state_dict': self.optimizer.state_dict(),
435
  'metrics': metrics,
436
+ 'scaler_state_dict': self.scaler.state_dict() if self.scaler else None
437
  }
438
  best_path = os.path.join(save_dir, "best_model.pt")
439
  torch.save(checkpoint, best_path)
440
+ logger.info(f"New best checkpoint saved: {best_path}, Loss: {val_loss:.4f}")
441
  return best_path
442
+ return None
443
+
444
+ def train_ultra_fast(self, num_epochs: int = DEFAULT_EPOCHS, batch_size: int = DEFAULT_BATCH_SIZE) -> List[Dict]:
445
+ """
446
+ Train the model with optimized settings.
447
+
448
+ Args:
449
+ num_epochs (int): Number of training epochs.
450
+ batch_size (int): Batch size for training.
451
+
452
+ Returns:
453
+ List[Dict]: Training history with metrics for each epoch.
454
+ """
455
+ logger.info(f"Starting ultra-fast training: {num_epochs} epochs, batch_size={batch_size}")
456
+ logger.info("Target: Loss < 2.0, PPL < 12, Time: 4-5 hours")
457
+
458
  train_loader = DataLoader(
459
  self.train_dataset,
460
  batch_size=batch_size,
461
  shuffle=True,
462
+ num_workers=min(multiprocessing.cpu_count(), 4),
463
+ pin_memory=self.device.type == 'cuda',
464
  persistent_workers=True,
465
  drop_last=True
466
+ )
467
+
468
  val_loader = None
469
  if self.val_dataset:
470
  val_loader = DataLoader(
471
  self.val_dataset,
472
  batch_size=batch_size * 2,
473
  shuffle=False,
474
+ num_workers=min(multiprocessing.cpu_count() // 2, 2),
475
+ pin_memory=self.device.type == 'cuda'
476
+ )
477
+
478
  total_start_time = time.time()
479
+ history = []
480
+
481
  for epoch in range(1, num_epochs + 1):
482
+ logger.info(f"Starting epoch {epoch}/{num_epochs}")
483
+ train_metrics = self.train_epoch_fast(epoch, train_loader)
484
+
485
  val_metrics = {}
486
  if val_loader and (epoch % 2 == 0 or epoch == num_epochs):
487
+ val_metrics = self.validate_fast(val_loader)
488
+
489
+ epoch_time = train_metrics['epoch_time_min'] * 60
490
  epoch_info = {
491
  'epoch': epoch,
492
  'train_loss': train_metrics['loss'],
493
  'train_ppl': train_metrics['perplexity'],
494
+ 'epoch_time_min': train_metrics['epoch_time_min']
495
  }
496
  if val_metrics:
497
  epoch_info.update({'val_loss': val_metrics['loss'], 'val_ppl': val_metrics['perplexity']})
498
+
499
+ history.append(epoch_info)
500
+
501
  elapsed_hours = (time.time() - total_start_time) / 3600
502
+ remaining_hours = elapsed_hours * (num_epochs - epoch) / max(epoch, 1)
503
+
504
+ logger.info(f"Epoch {epoch} results:")
505
+ logger.info(f" Epoch time: {epoch_time/60:.1f} min")
506
+ logger.info(f" Total elapsed: {elapsed_hours:.1f}h")
507
+ logger.info(f" Est. remaining: {remaining_hours:.1f}h")
508
+ logger.info(f" Train Loss: {train_metrics['loss']:.4f}")
509
+ logger.info(f" Train PPL: {train_metrics['perplexity']:.1f}")
510
  if val_metrics:
511
+ logger.info(f" Val Loss: {val_metrics['loss']:.4f}")
512
+ logger.info(f" Val PPL: {val_metrics['perplexity']:.1f}")
513
+
514
  current_loss = val_metrics.get('loss', train_metrics['loss'])
515
  current_ppl = val_metrics.get('perplexity', train_metrics['perplexity'])
516
  if current_loss < 2.0 and current_ppl < 12:
517
+ logger.info(f"Targets achieved: Loss={current_loss:.4f} < 2.0, PPL={current_ppl:.1f} < 12")
518
+
 
519
  combined_metrics = {**train_metrics}
520
  if val_metrics:
521
  combined_metrics.update({f"val_{k}": v for k, v in val_metrics.items()})
522
+ self.save_checkpoint_fast(epoch, combined_metrics)
523
+
524
+ if self.device.type == 'cuda':
525
+ torch.cuda.empty_cache()
526
+ gc.collect()
527
+
528
  if current_loss < 1.8 and current_ppl < 10:
529
+ logger.info("Early stopping: Excellent performance achieved!")
530
+ break
531
+
532
+ total_time = (time.time() - total_start_time) / 3600
533
+ logger.info(f"Training completed in {total_time:.1f} hours")
534
+ logger.info(f"Best loss: {self.best_loss:.4f}")
535
  return history
536
+
537
+ def run_ultra_fast_training() -> int:
538
+ """
539
+ Run the ultra-fast training pipeline.
540
+
541
+ Returns:
542
+ int: Exit code (0 for success, 1 for failure).
543
+ """
544
+ parser = argparse.ArgumentParser(description="Ultra-Fast Training for MixtureOfRecursions Model")
545
+ parser.add_argument("--train_file", default=None, help="Path to training data file")
546
+ parser.add_argument("--val_file", default=None, help="Path to validation data file")
547
+ parser.add_argument("--tokenizer_dir", default="tokenizer", help="Directory for tokenizer files")
548
+ parser.add_argument("--max_examples", type=int, default=DEFAULT_MAX_EXAMPLES, help="Maximum number of training examples")
549
+ parser.add_argument("--d_model", type=int, default=DEFAULT_D_MODEL, help="Model embedding dimension")
550
+ parser.add_argument("--n_layers", type=int, default=DEFAULT_N_LAYERS, help="Number of transformer layers")
551
+ parser.add_argument("--n_heads", type=int, default=DEFAULT_N_HEADS, help="Number of attention heads")
552
+ parser.add_argument("--max_seq_len", type=int, default=DEFAULT_MAX_LENGTH, help="Maximum sequence length")
553
+ parser.add_argument("--epochs", type=int, default=DEFAULT_EPOCHS, help="Number of training epochs")
554
+ parser.add_argument("--batch_size", type=int, default=DEFAULT_BATCH_SIZE, help="Batch size for training")
555
+ parser.add_argument("--learning_rate", type=float, default=DEFAULT_LEARNING_RATE, help="Learning rate")
556
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=DEFAULT_GRAD_ACCUM_STEPS, help="Gradient accumulation steps")
557
+ parser.add_argument("--eval_every", type=int, default=DEFAULT_EVAL_EVERY, help="Evaluate every N steps")
558
+
559
+ args = parser.parse_args()
560
+
561
  torch.manual_seed(42)
562
+ np.random.seed(42)
563
+
564
+ logger.info("Starting ultra-fast training pipeline")
565
+
566
  if args.train_file is None:
567
  patterns = ["*train*.txt", "*_train.txt"]
568
  files = []
569
  for pattern in patterns:
570
  files.extend(glob.glob(pattern))
571
+ files.extend(glob.glob(os.path.join("split_data", pattern)))
572
+ files.extend(glob.glob(os.path.join("data", pattern)))
573
  if files:
574
  args.train_file = files[0]
575
+ logger.info(f"Found training file: {args.train_file}")
576
  else:
577
+ logger.error("No training files found!")
578
+ return 1
579
+
580
  try:
581
+ tokenizer = TechnicalTokenizer()
582
  tokenizer.load(args.tokenizer_dir)
583
+ logger.info(f"Tokenizer loaded with vocab size: {tokenizer.get_vocab_size()}")
584
  except Exception as e:
585
+ logger.error(f"Failed to load tokenizer: {e}")
586
+ return 1
587
+
588
+ logger.info("Creating training dataset...")
589
+ try:
590
+ train_dataset = FastTechnicalTextDataset(
591
+ args.train_file, tokenizer, args.max_seq_len, args.max_examples
592
+ )
593
+ except Exception as e:
594
+ logger.error(f"Failed to create training dataset: {e}")
595
+ return 1
596
+
597
  val_dataset = None
598
  if args.val_file and os.path.exists(args.val_file):
599
+ try:
600
+ val_dataset = FastTechnicalTextDataset(
601
+ args.val_file, tokenizer, args.max_seq_len, max_examples=5000
602
+ )
603
+ logger.info("Validation dataset created")
604
+ except Exception as e:
605
+ logger.warning(f"Failed to create validation dataset: {e}")
606
+
607
+ try:
608
+ model = MixtureOfRecursions(
609
+ vocab_size=tokenizer.get_vocab_size(),
610
+ d_model=args.d_model,
611
+ n_layers=args.n_layers,
612
+ n_heads=args.n_heads,
613
+ max_seq_len=args.max_seq_len - 1,
614
+ padding_idx=tokenizer.vocab.get('<pad>', 0)
615
+ )
616
+ logger.info("Model initialized")
617
+ except Exception as e:
618
+ logger.error(f"Failed to initialize model: {e}")
619
+ return 1
620
+
621
  config = {
622
  'learning_rate': args.learning_rate,
623
  'gradient_accumulation_steps': args.gradient_accumulation_steps,
624
  'eval_every': args.eval_every,
625
  'batch_size': args.batch_size,
626
  'epochs': args.epochs
627
+ }
628
+
629
+ try:
630
+ trainer = UltraFastTrainer(model, tokenizer, train_dataset, val_dataset, config)
631
+ results = trainer.train_ultra_fast(args.epochs, args.batch_size)
632
+
633
+ with open('ultra_fast_results.json', 'w') as f:
634
+ json.dump(results, f, indent=2)
635
+ logger.info("Training results saved to ultra_fast_results.json")
636
+
637
+ return 0
638
+ except Exception as e:
639
+ logger.error(f"Training failed: {e}")
640
+ return 1
641
+
642
  if __name__ == "__main__":
643
  exit(run_ultra_fast_training())
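
For reference, the sketch below shows how the classes added in this commit compose programmatically, as an alternative to the command-line entry point (python train.py --train_file ... --epochs 15). It is a minimal sketch and not part of the commit: the importable module name train, the tokenizer directory, and the corpus path are assumed placeholders, and the hyperparameters simply mirror the script's defaults.

# Minimal usage sketch; paths and the module name "train" are assumptions, not part of the commit.
from custom_tokenizer import TechnicalTokenizer
from model_slm import MixtureOfRecursions
from train import FastTechnicalTextDataset, UltraFastTrainer

tokenizer = TechnicalTokenizer()
tokenizer.load("tokenizer")                      # hypothetical tokenizer directory

train_dataset = FastTechnicalTextDataset(
    "data/corpus_train.txt",                     # hypothetical corpus file
    tokenizer,
    max_length=128,
    max_examples=50000,
)

model = MixtureOfRecursions(
    vocab_size=tokenizer.get_vocab_size(),
    d_model=384,
    n_layers=6,
    n_heads=6,
    max_seq_len=127,                             # max_seq_len - 1, as in run_ultra_fast_training()
    padding_idx=tokenizer.vocab.get('<pad>', 0),
)

config = {
    'learning_rate': 5e-4,
    'gradient_accumulation_steps': 1,
    'eval_every': 500,
    'batch_size': 16,
    'epochs': 15,
}
trainer = UltraFastTrainer(model, tokenizer, train_dataset, config=config)
history = trainer.train_ultra_fast(num_epochs=15, batch_size=16)  # best checkpoint is written to checkpoints/best_model.pt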