Girinath11 committed on
Commit c2b8ae0 · verified · 1 Parent(s): b28dff6

Create train.py

Files changed (1)
  1. train.py +383 -0
train.py ADDED
@@ -0,0 +1,383 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import os
import json
import argparse
import time
import math
import glob
from typing import Dict, List
from tqdm import tqdm
import numpy as np
import gc
from collections import defaultdict
import multiprocessing

# Import custom modules
try:
    from model_slm import MixtureOfRecursions, count_parameters, TextGenerator
    from custom_tokenizer import TechnicalTokenizer
except ImportError as e:
    print(f"Import error: {e}")
    exit(1)


class FastTechnicalTextDataset(Dataset):
    """Ultra-fast dataset with aggressive optimizations for 4-5 hr training."""

    def __init__(self, data_file: str, tokenizer: TechnicalTokenizer, max_length: int = 128, max_examples: int = 50000):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.pad_token_id = tokenizer.vocab.get('<pad>', 0)
        self.max_examples = max_examples

        print("FAST DATASET LOADING")
        print(f"Data file: {data_file}")
        print(f"Max sequence length: {max_length}")
        print(f"Max examples: {max_examples}")

        start_time = time.time()
        self.examples = []
        self._fast_load_data(data_file)
        load_time = time.time() - start_time
        print(f"Loaded {len(self.examples)} examples in {load_time:.1f}s")

        self._tensorize_data()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _fast_load_data(self, data_file: str):
        print("🔍 Fast reading file...")
        with open(data_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        print(f"File has {len(lines)} lines")

        # Keep reasonably long, prose-like lines and drop near-duplicates.
        good_examples = []
        seen_hashes = set()
        for line in lines[:self.max_examples * 3]:
            line = line.strip()
            if (50 <= len(line) <= 400 and
                    line.count(' ') >= 8 and
                    not line.lower().startswith(('http', 'www', 'ftp')) and
                    line.count('.') <= len(line) * 0.1):
                line_hash = hash(line[:100])
                if line_hash not in seen_hashes:
                    seen_hashes.add(line_hash)
                    good_examples.append(line)
                    if len(good_examples) >= self.max_examples:
                        break
        print(f"After fast filtering: {len(good_examples)} quality examples")

        # Tokenize in batches and pad to max_length.
        batch_size = 1000
        for i in range(0, len(good_examples), batch_size):
            batch = good_examples[i:i + batch_size]
            for line in batch:
                try:
                    if not line.endswith('<|endoftext|>'):
                        line += ' <|endoftext|>'
                    tokens = self.tokenizer.encode_ids(line, add_special_tokens=True)
                    if 30 <= len(tokens) <= self.max_length:
                        if len(tokens) < self.max_length:
                            tokens = tokens + [self.pad_token_id] * (self.max_length - len(tokens))
                        self.examples.append(tokens)
                except Exception:
                    continue
            if i % 5000 == 0:
                print(f"Processed {len(self.examples)} examples...")
        print(f"Final dataset: {len(self.examples)} examples")

    def _tensorize_data(self):
        """Pre-build input/target/mask tensors so __getitem__ is a plain lookup."""
        print("Pre-tensorizing data for maximum speed...")
        seq_len = self.max_length - 1
        tensorized_examples = []
        for tokens in self.examples:
            if len(tokens) < self.max_length:
                continue
            # Next-token prediction: inputs are tokens[:-1], targets are tokens[1:].
            input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
            targets = torch.tensor(tokens[1:], dtype=torch.long)
            original_len = next((i for i, x in enumerate(tokens) if x == self.pad_token_id), self.max_length)
            mask_len = min(original_len, seq_len)
            attention_mask = torch.zeros(seq_len, dtype=torch.long)
            attention_mask[:mask_len] = 1
            tensorized_examples.append({
                'input_ids': input_ids,
                'targets': targets,
                'attention_mask': attention_mask
            })
        self.examples = tensorized_examples
        print("All data pre-tensorized")

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]


class FastCosineScheduler:
    """Linear warmup followed by cosine decay, stepped once per optimizer update."""

    def __init__(self, optimizer, total_steps: int, warmup_ratio: float = 0.05):
        self.optimizer = optimizer
        self.total_steps = total_steps
        # Guard against a zero-length warmup on very small datasets.
        self.warmup_steps = max(1, int(total_steps * warmup_ratio))
        self.base_lr = optimizer.param_groups[0]['lr']
        self.step_count = 0

    def step(self):
        self.step_count += 1
        if self.step_count <= self.warmup_steps:
            lr = self.base_lr * self.step_count / self.warmup_steps
        else:
            progress = (self.step_count - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            lr = self.base_lr * 0.5 * (1 + math.cos(math.pi * progress))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr


class UltraFastTrainer:
    def __init__(self, model, tokenizer, train_dataset, val_dataset=None, config=None):
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.config = config or {}

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self._fast_init_weights()
        self._setup_fast_optimizer()

        epochs = self.config.get('epochs', 15)
        batch_size = self.config.get('batch_size', 16)
        total_steps = len(train_dataset) // batch_size * epochs
        self.scheduler = FastCosineScheduler(self.optimizer, total_steps)
        self.scaler = GradScaler()

        self.global_step = 0
        self.best_loss = float('inf')
        self.grad_accum_steps = self.config.get('gradient_accumulation_steps', 1)
        self.eval_every = self.config.get('eval_every', 500)

    def _fast_init_weights(self):
        def fast_init(module):
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, std=0.02)
        self.model.apply(fast_init)

    def _setup_fast_optimizer(self):
        lr = self.config.get('learning_rate', 5e-4)
        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = optim.AdamW(params, lr=lr, betas=(0.9, 0.99), weight_decay=0.01, eps=1e-6)

    def compute_fast_loss(self, logits, targets, mask):
        # Cross-entropy over non-padded positions only.
        logits_flat = logits.view(-1, logits.size(-1))
        targets_flat = targets.view(-1)
        mask_flat = mask.view(-1).bool()
        if not mask_flat.any():
            return torch.tensor(0.0, device=logits.device, requires_grad=True)
        loss = F.cross_entropy(logits_flat[mask_flat], targets_flat[mask_flat])
        return loss

    def train_epoch_fast(self, epoch: int, dataloader: DataLoader) -> Dict[str, float]:
        self.model.train()
        total_loss = 0
        num_batches = 0
        start_time = time.time()
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False, miniters=50)

        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(self.device, non_blocking=True)
            targets = batch['targets'].to(self.device, non_blocking=True)
            mask = batch['attention_mask'].to(self.device, non_blocking=True)

            # Mixed-precision forward pass; the model returns logits plus an auxiliary loss.
            with autocast():
                logits, comp_loss = self.model(input_ids, mask)
                lm_loss = self.compute_fast_loss(logits, targets, mask)
                total_loss_step = lm_loss + 0.0001 * comp_loss
            if self.grad_accum_steps > 1:
                total_loss_step = total_loss_step / self.grad_accum_steps

            self.scaler.scale(total_loss_step).backward()

            # Step the optimizer only every grad_accum_steps mini-batches.
            if (batch_idx + 1) % self.grad_accum_steps == 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)
                self.scheduler.step()
                self.global_step += 1

            total_loss += lm_loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                current_loss = total_loss / num_batches
                progress_bar.set_postfix({'loss': f"{current_loss:.3f}", 'ppl': f"{math.exp(min(current_loss, 10)):.1f}"})
            if batch_idx % 200 == 0 and batch_idx > 0:
                torch.cuda.empty_cache()

        epoch_time = time.time() - start_time
        avg_loss = total_loss / max(num_batches, 1)
        return {'loss': avg_loss, 'perplexity': math.exp(min(avg_loss, 10)), 'epoch_time_min': epoch_time / 60}

    def validate_fast(self, dataloader: DataLoader) -> Dict[str, float]:
        self.model.eval()
        total_loss = 0
        num_batches = 0
        max_val_batches = min(100, len(dataloader))
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if batch_idx >= max_val_batches:
                    break
                input_ids = batch['input_ids'].to(self.device, non_blocking=True)
                targets = batch['targets'].to(self.device, non_blocking=True)
                mask = batch['attention_mask'].to(self.device, non_blocking=True)
                with autocast():
                    logits, _ = self.model(input_ids, mask)
                    loss = self.compute_fast_loss(logits, targets, mask)
                total_loss += loss.item()
                num_batches += 1
        avg_loss = total_loss / max(num_batches, 1)
        return {'loss': avg_loss, 'perplexity': math.exp(min(avg_loss, 10))}

    def save_checkpoint_fast(self, epoch: int, metrics: Dict, save_dir: str = "checkpoints"):
        os.makedirs(save_dir, exist_ok=True)
        val_loss = metrics.get('val_loss', metrics.get('loss', float('inf')))
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'metrics': metrics,
                'scaler_state_dict': self.scaler.state_dict()
            }
            best_path = os.path.join(save_dir, "best_model.pt")
            torch.save(checkpoint, best_path)
            print(f"New best! Loss: {val_loss:.4f}")
            return best_path
        return None

    def train_ultra_fast(self, num_epochs: int = 15, batch_size: int = 16):
        print("\nULTRA-FAST TRAINING")
        print("Target: Loss < 2.0, PPL < 12")
        print("Time target: 4-5 hours")
        print(f"Epochs: {num_epochs}")
        print(f"Batch size: {batch_size}")
        print("-" * 60)

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            persistent_workers=True,
            drop_last=True
        )
        val_loader = None
        if self.val_dataset:
            val_loader = DataLoader(
                self.val_dataset,
                batch_size=batch_size * 2,
                shuffle=False,
                num_workers=2,
                pin_memory=True
            )

        total_start_time = time.time()
        history = []

        for epoch in range(1, num_epochs + 1):
            epoch_start = time.time()
            print(f"\nEPOCH {epoch}/{num_epochs}")

            train_metrics = self.train_epoch_fast(epoch, train_loader)

            # Validate every other epoch (and on the final epoch) to save time.
            val_metrics = {}
            if val_loader and (epoch % 2 == 0 or epoch == num_epochs):
                val_metrics = self.validate_fast(val_loader)

            epoch_time = time.time() - epoch_start
            epoch_info = {
                'epoch': epoch,
                'train_loss': train_metrics['loss'],
                'train_ppl': train_metrics['perplexity'],
                'epoch_time_min': epoch_time / 60
            }
            if val_metrics:
                epoch_info.update({'val_loss': val_metrics['loss'], 'val_ppl': val_metrics['perplexity']})
            history.append(epoch_info)

            elapsed_hours = (time.time() - total_start_time) / 3600
            remaining_hours = elapsed_hours * (num_epochs - epoch) / epoch
            print(f"\nEPOCH {epoch} RESULTS:")
            print(f"  Epoch time: {epoch_time / 60:.1f} min")
            print(f"  Total elapsed: {elapsed_hours:.1f}h")
            print(f"  Est. remaining: {remaining_hours:.1f}h")
            print(f"  Train Loss: {train_metrics['loss']:.4f}")
            print(f"  Train PPL: {train_metrics['perplexity']:.1f}")
            if val_metrics:
                print(f"  Val Loss: {val_metrics['loss']:.4f}")
                print(f"  Val PPL: {val_metrics['perplexity']:.1f}")

            current_loss = val_metrics.get('loss', train_metrics['loss'])
            current_ppl = val_metrics.get('perplexity', train_metrics['perplexity'])
            if current_loss < 2.0 and current_ppl < 12:
                print("TARGETS ACHIEVED!")
                print(f"  Loss: {current_loss:.4f} < 2.0")
                print(f"  PPL: {current_ppl:.1f} < 12")

            combined_metrics = {**train_metrics}
            if val_metrics:
                combined_metrics.update({f"val_{k}": v for k, v in val_metrics.items()})
            self.save_checkpoint_fast(epoch, combined_metrics)

            torch.cuda.empty_cache()
            gc.collect()

            # Stop early once the stretch targets are comfortably beaten.
            if current_loss < 1.8 and current_ppl < 10:
                print("EARLY STOPPING - Excellent performance achieved!")
                break

        total_time = time.time() - total_start_time
        print("\nTRAINING COMPLETED!")
        print(f"Total time: {total_time / 3600:.1f} hours")
        print(f"Best loss: {self.best_loss:.4f}")
        return history


def run_ultra_fast_training():
    parser = argparse.ArgumentParser(description="Ultra-Fast Training for 4-5 Hours")
    parser.add_argument("--train_file", default=None)
    parser.add_argument("--val_file", default=None)
    parser.add_argument("--tokenizer_dir", default="tokenizer")
    parser.add_argument("--max_examples", type=int, default=50000)
    parser.add_argument("--d_model", type=int, default=384)
    parser.add_argument("--n_layers", type=int, default=6)
    parser.add_argument("--n_heads", type=int, default=6)
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=15)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--eval_every", type=int, default=500)
    args = parser.parse_args()

    torch.manual_seed(42)
    np.random.seed(42)

    print("Training My Model")
    print("-" * 50)

    # Auto-discover a training file if none was given.
    if args.train_file is None:
        patterns = ["*train*.txt", "*_train.txt"]
        files = []
        for pattern in patterns:
            files.extend(glob.glob(pattern))
            files.extend(glob.glob(f"split_data/{pattern}"))
            files.extend(glob.glob(f"data/{pattern}"))
        if files:
            args.train_file = files[0]
            print(f"Found: {args.train_file}")
        else:
            print("No training files found!")
            return 1

    tokenizer = TechnicalTokenizer()
    try:
        tokenizer.load(args.tokenizer_dir)
        print(f"Tokenizer loaded. Vocab size: {tokenizer.get_vocab_size()}")
    except Exception as e:
        print(f"Tokenizer error: {e}")
        return 1

    print("Creating ultra-fast dataset...")
    train_dataset = FastTechnicalTextDataset(
        args.train_file, tokenizer, args.max_seq_len, args.max_examples
    )
    val_dataset = None
    if args.val_file and os.path.exists(args.val_file):
        val_dataset = FastTechnicalTextDataset(
            args.val_file, tokenizer, args.max_seq_len, max_examples=5000
        )

    model = MixtureOfRecursions(
        vocab_size=tokenizer.get_vocab_size(),
        d_model=args.d_model,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        max_seq_len=args.max_seq_len - 1,  # the dataset yields sequences of length max_seq_len - 1
        padding_idx=tokenizer.vocab.get('<pad>', 0)
    )

    config = {
        'learning_rate': args.learning_rate,
        'gradient_accumulation_steps': args.gradient_accumulation_steps,
        'eval_every': args.eval_every,
        'batch_size': args.batch_size,
        'epochs': args.epochs
    }
    trainer = UltraFastTrainer(model, tokenizer, train_dataset, val_dataset, config)

    print("\nSTART TRAINING")
    results = trainer.train_ultra_fast(args.epochs, args.batch_size)

    with open('ultra_fast_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print("\nTraining Completed!")
    print("Results saved to: ultra_fast_results.json")
    return 0


if __name__ == "__main__":
    exit(run_ultra_fast_training())
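
For reference, a typical invocation of the script above might look like the following (the data paths here are hypothetical examples; if --train_file is omitted, the script searches the working directory, split_data/, and data/ for *train*.txt files, and the tokenizer is expected in the directory given by --tokenizer_dir):

    python train.py --train_file data/corpus_train.txt --val_file data/corpus_val.txt --tokenizer_dir tokenizer --epochs 15 --batch_size 16 --max_seq_len 128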