"""
Final Breakthrough BitTransformerLM Training Script
===================================================

The complete training script using the ACTUAL BitTransformerLM model
with the breakthrough Fixed RL Adafactor configuration and full
HuggingFace dataset support with checkpoint resumption.
"""
import sys
import os
import json
import logging
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, Any, Tuple

import torch
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import login

# These path entries must be registered before the project imports below;
# presumably bit_transformer / BTLM_Extensions live in this checkout.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits
from BTLM_Extensions import configure_adafactor_optimizer
# Root logging setup: every record goes both to a persistent log file and to
# the console stream, using one shared timestamped format.
logging.basicConfig(
    handlers=[
        logging.FileHandler('/data/BitTransformerLM/breakthrough_training.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the trainer and by main().
logger = logging.getLogger(__name__)
class BreakthroughTrainer:
    """Production-grade BitTransformerLM trainer with breakthrough configuration.

    The trainer owns the model, optimizer, scheduler and the in-memory bit
    dataset, and persists/resumes its state through checkpoints stored in
    ``config['checkpoint_dir']``.

    Config keys read by this class: ``checkpoint_dir``, ``sequence_length``,
    ``batch_size``, ``num_epochs``, ``learning_rate``, ``weight_decay``,
    ``max_grad_norm`` and, optionally, ``hf_token`` and ``log_interval``.
    """

    def __init__(self, config: Dict[str, Any]):
        """Store the configuration and create the checkpoint directory."""
        self.config = config
        self.device = torch.device('cpu')
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.dataset = None
        self.checkpoint_dir = Path(config['checkpoint_dir'])
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Mutable training state; overwritten by load_checkpoint() on resume.
        self.current_epoch = 0  # number of fully completed epochs
        self.total_steps = 0  # optimizer steps taken so far
        self.best_loss = float('inf')
        self.training_history = []  # one {'epoch', 'avg_loss'} dict per epoch

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_text(sample):
        """Return the sample's text ('original_text' preferred, then 'text').

        Returns None when neither field holds non-blank text.
        """
        text = sample.get('original_text') or sample.get('text')
        if text and text.strip():
            return text
        return None

    @staticmethod
    def _chunk_bits(bits, seq_len):
        """Cut ``bits`` into windows of exactly ``seq_len`` with 50% overlap.

        Sequences shorter than ``seq_len`` yield no windows.

        Raises:
            ValueError: if ``seq_len`` is smaller than 1.
        """
        if seq_len < 1:
            raise ValueError(f"sequence_length must be >= 1, got {seq_len}")
        # A stride of 0 (seq_len == 1) would make range() raise, so clamp it.
        stride = max(1, seq_len // 2)
        return [
            bits[start:start + seq_len]
            for start in range(0, len(bits) - seq_len + 1, stride)
        ]

    def _persistable_config(self) -> Dict[str, Any]:
        """Return a copy of the config that is safe to write to disk.

        The Hugging Face token is a secret, so it is blanked out instead of
        being stored in checkpoints or in training_config.json.
        """
        safe_config = dict(self.config)
        if 'hf_token' in safe_config:
            safe_config['hf_token'] = None
        return safe_config

    @staticmethod
    def _atomic_save(payload, path):
        """Write ``payload`` with torch.save via a temp file and a rename.

        An interrupted write therefore never leaves a truncated file under
        the final name.
        """
        tmp_path = path.with_name(path.name + '.tmp')
        torch.save(payload, tmp_path)
        os.replace(tmp_path, path)

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------

    def load_and_prepare_dataset(self):
        """Load HF dataset and convert to proper bit tensors.

        Returns:
            The LongTensor of training windows, one row per window of
            ``config['sequence_length']`` bits (also stored in self.dataset).

        Raises:
            ValueError: if no sample is long enough to produce a window.
        """
        logger.info("Loading WCNegentropy/BitTransformerLM dataset...")

        # login(token=None) prompts interactively, which stalls unattended
        # runs. Only authenticate explicitly when a token was configured;
        # otherwise leave credential discovery to the hub library.
        token = self.config.get('hf_token')
        if token:
            login(token=token)
        else:
            logger.info("No hf_token configured - skipping explicit Hugging Face login")

        dataset = load_dataset("WCNegentropy/BitTransformerLM")
        train_data = dataset['train']
        num_samples = len(train_data)
        logger.info(f"Dataset loaded: {num_samples} samples")

        seq_len = self.config['sequence_length']
        training_sequences = []
        valid_sequences = 0

        for i, sample in enumerate(train_data):
            if i % 1000 == 0:
                logger.info(f"Processing sample {i}/{num_samples}")

            text = self._extract_text(sample)
            if text is None:
                continue

            bits = text_to_bits(text)
            if len(bits) < seq_len:
                continue

            valid_sequences += 1
            # Window each sequence right away so the full-length bit lists
            # do not all have to stay in memory at once.
            training_sequences.extend(self._chunk_bits(bits, seq_len))

        logger.info(f"Processed {valid_sequences} valid bit sequences")

        if not training_sequences:
            raise ValueError(
                "No training sequences could be built: no sample produced at "
                f"least {seq_len} bits"
            )

        self.dataset = torch.tensor(training_sequences, dtype=torch.long)
        logger.info(f"Created training dataset: {self.dataset.shape}")

        return self.dataset

    def create_breakthrough_model(self):
        """Create the EXACT breakthrough 16M parameter BitTransformerLM.

        Returns:
            The model instance, moved to self.device.
        """
        logger.info("Creating breakthrough 16M parameter BitTransformerLM...")

        self.model = BitTransformerLM(
            d_model=512,
            nhead=16,
            num_layers=8,
            dim_feedforward=1024,
            max_seq_len=self.config['sequence_length'],
            lambda_K=0.05,
            lambda_C=0.05,
            lambda_S=0.05,
            reversible=True,
            use_checkpoint=True,
            use_autocast=True,
            use_act=True,
            act_threshold=0.9
        ).to(self.device)

        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        logger.info(f"Model created: {total_params:,} total parameters ({trainable_params:,} trainable)")
        # The two outcomes must be distinguishable in the log; the previous
        # markers were mis-encoded into the same character.
        in_target_range = 15_000_000 <= total_params <= 17_000_000
        logger.info(f"Target: ~16M parameters - {'OK' if in_target_range else 'OUT OF RANGE'}")

        return self.model

    def setup_optimizer(self):
        """Setup Fixed RL Adafactor optimizer (the breakthrough secret sauce).

        Must be called after the dataset and the model exist.

        Returns:
            The (optimizer, scheduler) pair from configure_adafactor_optimizer.

        Raises:
            RuntimeError: if the dataset or the model has not been created.
        """
        if self.dataset is None or self.model is None:
            raise RuntimeError(
                "load_and_prepare_dataset() and create_breakthrough_model() "
                "must run before setup_optimizer()"
            )

        logger.info("Setting up Fixed RL Adafactor optimizer...")

        # Integer division matches train_epoch(), which drops a trailing
        # partial batch.
        steps_per_epoch = len(self.dataset) // self.config['batch_size']
        total_steps = steps_per_epoch * self.config['num_epochs']

        self.optimizer, self.scheduler = configure_adafactor_optimizer(
            self.model,
            lr=self.config['learning_rate'],
            weight_decay=self.config['weight_decay'],
            total_steps=total_steps
        )

        logger.info(f"Fixed RL Adafactor configured with LR={self.config['learning_rate']}")
        logger.info(f"Total training steps: {total_steps}")

        return self.optimizer, self.scheduler

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def save_checkpoint(self, epoch: int, loss: float, is_best: bool = False,
                        *, snapshot: bool = True) -> None:
        """Save complete model checkpoint with all training state.

        Args:
            epoch: Number of completed epochs to record in the checkpoint.
            loss: Loss value to record.
            is_best: Also write checkpoint_best.pt.
            snapshot: Also write the per-epoch file checkpoint_epoch_NNNN.pt.
                Pass False for mid-epoch emergency saves so the clean
                end-of-epoch snapshot is not overwritten.
        """
        persisted_config = self._persistable_config()
        checkpoint_data = {
            'epoch': epoch,
            'total_steps': self.total_steps,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'loss': loss,
            'best_loss': self.best_loss,
            'config': persisted_config,
            'training_history': self.training_history,
            'timestamp': datetime.now().isoformat(),
            'model_config': self.model._current_params()
        }

        latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
        self._atomic_save(checkpoint_data, latest_path)
        logger.info(f"Saved checkpoint: {latest_path}")

        if snapshot:
            epoch_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch:04d}.pt'
            self._atomic_save(checkpoint_data, epoch_path)

        if is_best:
            best_path = self.checkpoint_dir / 'checkpoint_best.pt'
            self._atomic_save(checkpoint_data, best_path)
            logger.info(f"NEW BEST MODEL! Loss: {loss:.6f} -> {best_path}")

        config_path = self.checkpoint_dir / 'training_config.json'
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(persisted_config, f, indent=2)

    def load_checkpoint(self, checkpoint_path: Optional[str] = None) -> bool:
        """Load checkpoint if available and resume training.

        Args:
            checkpoint_path: File to load; defaults to checkpoint_latest.pt
                inside the checkpoint directory.

        Returns:
            True when the state was restored, False when no checkpoint exists
            or loading failed (the failure is logged with its traceback).
        """
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_dir / 'checkpoint_latest.pt'

        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            logger.info("No checkpoint found - starting fresh training")
            return False

        logger.info(f"Loading checkpoint: {checkpoint_path}")
        try:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

            if self.scheduler and checkpoint.get('scheduler_state_dict'):
                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

            self.current_epoch = checkpoint['epoch']
            self.total_steps = checkpoint['total_steps']
            self.best_loss = checkpoint['best_loss']
            self.training_history = checkpoint.get('training_history', [])

            logger.info(f"Resumed from epoch {self.current_epoch}, best loss: {self.best_loss:.6f}")
            logger.info(f"Total steps completed: {self.total_steps}")
            return True

        except Exception as e:
            # Best effort: report and let the caller fall back to a fresh
            # run. NOTE(review): a failure after the model weights were
            # loaded leaves model and optimizer state out of sync.
            logger.exception(f"Failed to load checkpoint: {e}")
            return False

    def _save_emergency_checkpoint(self) -> None:
        """Best-effort save of the current (possibly mid-epoch) state."""
        try:
            self.save_checkpoint(self.current_epoch, float('inf'), False, snapshot=False)
        except Exception:
            # The original error/interrupt matters more than this failure,
            # but it must not vanish silently.
            logger.exception("Could not save emergency checkpoint")

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def _run_batch(self, batch: torch.Tensor) -> Tuple[float, Any]:
        """Run one optimization step and return (loss value, telemetry)."""
        batch = batch.to(self.device)

        self.optimizer.zero_grad()

        logits, telemetry = self.model(batch)

        # Next-bit prediction: position t predicts bit t + 1, so the last
        # logit and the first target are dropped.
        pred = logits[:, :-1, :].reshape(-1, 2)
        target = batch[:, 1:].reshape(-1)
        loss = F.cross_entropy(pred, target)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config['max_grad_norm'])

        self.optimizer.step()
        if self.scheduler:
            self.scheduler.step()

        self.total_steps += 1
        return loss.item(), telemetry

    def training_step(self, batch: torch.Tensor) -> Dict[str, float]:
        """Single training step following the ACTUAL model pattern.

        Returns:
            A metrics dict with 'loss' plus one entry per telemetry key
            (tensor values are reduced to their mean, others passed through).
        """
        loss_value, telemetry = self._run_batch(batch)

        metrics = {'loss': loss_value}
        if telemetry:
            for key, value in telemetry.items():
                if torch.is_tensor(value):
                    metrics[key] = value.mean().item()
                else:
                    metrics[key] = value

        return metrics

    def train_epoch(self) -> Dict[str, float]:
        """Train for one complete epoch.

        Batches are taken in dataset order; a trailing partial batch is
        skipped.

        Returns:
            {'epoch': 1-based epoch number, 'avg_loss': mean batch loss}
            (avg_loss is inf when no full batch was available).
        """
        logger.info(f"Starting epoch {self.current_epoch + 1}")

        self.model.train()
        epoch_losses = []

        batch_size = self.config['batch_size']
        log_interval = self.config.get('log_interval')

        for start in range(0, len(self.dataset), batch_size):
            batch = self.dataset[start:start + batch_size]
            if len(batch) < batch_size:
                continue

            loss_value, _ = self._run_batch(batch)
            epoch_losses.append(loss_value)

            if log_interval and len(epoch_losses) % log_interval == 0:
                logger.info(
                    f"Epoch {self.current_epoch + 1} step {len(epoch_losses)}: "
                    f"loss={loss_value:.6f}"
                )

        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')

        epoch_summary = {
            'epoch': self.current_epoch + 1,
            'avg_loss': avg_loss
        }
        self.training_history.append(epoch_summary)

        logger.info(
            f"Epoch {self.current_epoch + 1} completed: "
            f"Avg Loss={avg_loss:.6f}"
        )

        return epoch_summary

    def train(self):
        """Main training loop.

        Runs from self.current_epoch up to config['num_epochs'], saving a
        checkpoint after every epoch. On Ctrl-C an emergency checkpoint is
        written and the loop stops; on any other error an emergency
        checkpoint is written and the error is re-raised.
        """
        logger.info("STARTING BREAKTHROUGH BITRANSFORMERLM TRAINING!")
        logger.info("Configuration: Fixed RL Adafactor + 16M parameters + CPU training")

        for _ in range(self.current_epoch, self.config['num_epochs']):
            try:
                epoch_metrics = self.train_epoch()
                avg_loss = epoch_metrics['avg_loss']

                is_best = avg_loss < self.best_loss
                if is_best:
                    self.best_loss = avg_loss

                # The checkpoint records the count of completed epochs, so a
                # resumed run continues with the next one.
                self.save_checkpoint(self.current_epoch + 1, avg_loss, is_best)

                self.current_epoch += 1

                logger.info(f"=== EPOCH {self.current_epoch} COMPLETE ===")
                logger.info(f"Loss: {avg_loss:.6f} (best: {self.best_loss:.6f})")

                # NOTE(review): for a 2-class cross-entropy, chance level is
                # about 0.693, so this threshold looks trivially satisfied -
                # confirm the intended value.
                if avg_loss < 3.0:
                    logger.info("BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")

            except KeyboardInterrupt:
                logger.info("Training interrupted by user")
                self._save_emergency_checkpoint()
                break
            except Exception as e:
                logger.exception(f"Error in epoch {self.current_epoch + 1}: {e}")
                self._save_emergency_checkpoint()
                raise
def main():
    """Main function to run breakthrough training.

    Builds the configuration, prepares dataset, model and optimizer, resumes
    from the latest checkpoint when one exists, and then runs training.
    """
    config = {
        # Length, in bits, of every training window.
        'sequence_length': 512,

        # Optimization hyperparameters.
        'learning_rate': 1e-3,
        'weight_decay': 0.01,
        'batch_size': 4,
        'num_epochs': 50,
        'max_grad_norm': 1.0,

        # Keep this None in source: the trainer writes the config dict to
        # training_config.json and into checkpoints, so a token hard-coded
        # here would end up on disk.
        'hf_token': None,

        # Logging and checkpoint locations.
        'log_interval': 100,
        'checkpoint_dir': '/data/BitTransformerLM/checkpoints',
    }

    trainer = BreakthroughTrainer(config)

    # Order matters: the optimizer needs the dataset size and the model, and
    # a checkpoint can only be restored into an existing model/optimizer.
    logger.info("Setting up training components...")
    trainer.load_and_prepare_dataset()
    trainer.create_breakthrough_model()
    trainer.setup_optimizer()

    trainer.load_checkpoint()

    trainer.train()

    logger.info("BREAKTHROUGH TRAINING COMPLETED!")
    logger.info(f"Best loss achieved: {trainer.best_loss:.6f}")
    logger.info(f"Checkpoints saved to: {trainer.checkpoint_dir}")


if __name__ == "__main__":
    main()