"""
Basic BitTransformerLM Training Script
======================================

A simple working training script that follows the actual BitTransformerLM
model implementation exactly as it exists in the codebase.
"""

import sys
import os
import logging

import torch
import torch.nn.functional as F

# Make the BitTransformerLM package importable from its install location.
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')

from bit_transformer import BitTransformerLM, text_to_bits
from BTLM_Extensions import configure_adafactor_optimizer

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def create_simple_dataset():
    """Create a simple bit dataset for testing."""
    logger.info("Creating simple bit dataset...")

    texts = [
        "Hello world! This is a test.",
        "BitTransformerLM processes bits natively.",
        "Training on binary sequences is interesting.",
        "Each character becomes 9 bits with parity.",
        "The model learns bit patterns directly.",
    ]

    # Convert each text into its bit-level representation.
    bit_sequences = []
    for text in texts:
        bits = text_to_bits(text)
        bit_sequences.append(bits)

    # Cap the sequence length at 64 bits (matching the model's max_seq_len).
    max_len = min(64, max(len(bits) for bits in bit_sequences))

    training_data = []
    for bits in bit_sequences:
        if len(bits) >= max_len:
            # Slide a half-overlapping window over the sequence so every
            # training example is exactly max_len bits long.
            for i in range(0, len(bits) - max_len + 1, max_len // 2):
                chunk = bits[i:i + max_len]
                if len(chunk) == max_len:
                    training_data.append(chunk)

    data_tensor = torch.tensor(training_data, dtype=torch.long)
    logger.info(f"Created dataset: {data_tensor.shape}")

    return data_tensor
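

# Quick shape sanity check (illustrative; assumes text_to_bits returns a flat
# sequence of 0/1 integers, e.g. 9 bits per character as the sample text
# above suggests, so every string exceeds the 64-bit window):
#   data = create_simple_dataset()
#   assert data.dim() == 2 and data.size(1) == 64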


def create_model():
    """Create a small BitTransformerLM model for testing."""
    logger.info("Creating BitTransformerLM model...")

    model = BitTransformerLM(
        d_model=128,
        nhead=8,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=64,
        lambda_K=0.1,           # weights for the K/C/S telemetry terms
        lambda_C=0.1,
        lambda_S=0.1,
        use_checkpoint=False,   # no gradient checkpointing for this small model
        use_autocast=False,     # full precision, no autocast
        use_act=False           # adaptive computation time disabled
    )

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")

    return model
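

# Illustrative smoke test (assumption: forward takes a [batch, seq] tensor of
# 0/1 ints and returns a (logits, telemetry) pair with logits shaped
# [batch, seq, 2], which is what the training loop below relies on):
#   model = create_model()
#   logits, telemetry = model(torch.randint(0, 2, (1, 64)))
#   assert logits.shape == (1, 64, 2)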


def train_basic():
    """Basic training loop following the example_training_step pattern."""
    logger.info("Starting basic BitTransformerLM training...")

    model = create_model()
    data = create_simple_dataset()

    # Training hyperparameters.
    batch_size = 2
    epochs = 5
    total_steps = (len(data) // batch_size) * epochs

    logger.info("Configuring fixed-LR Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,
        weight_decay=0.01,
        total_steps=total_steps
    )

    logger.info("Starting training loop...")

    model.train()

    for epoch in range(epochs):
        epoch_losses = []
        telemetry = {}  # keep the last batch's telemetry for epoch-end logging

        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            if len(batch) < batch_size:
                continue  # skip the ragged final batch

            optimizer.zero_grad()

            # Forward pass returns logits plus a telemetry dict.
            logits, telemetry = model(batch)

            # Next-bit prediction: the logit at position t predicts the bit at
            # t + 1, so shift targets left by one and flatten for cross-entropy.
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = batch[:, 1:].reshape(-1)
            loss = F.cross_entropy(pred, target)

            loss.backward()

            # Clip gradients to keep updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            if scheduler:
                scheduler.step()

            epoch_losses.append(loss.item())

        avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('inf')
        logger.info(f"Epoch {epoch + 1}/{epochs}: Average Loss = {avg_loss:.6f}")

        # Report telemetry from the last batch of the epoch.
        if telemetry:
            for key, value in telemetry.items():
                if torch.is_tensor(value):
                    logger.info(f"  {key}: {value.mean().item():.4f}")

    logger.info("Basic training completed successfully!")
    return model
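

# Greedy generation sketch (illustrative; uses only the forward signature
# exercised above: bits in, (logits, telemetry) out). Given the trained model
# and dataset tensor from train_basic(), seed with a training sequence, take
# the argmax over the 2-way logits at the last position, append, and repeat:
#   model.eval()
#   seq = data[:1].clone()
#   with torch.no_grad():
#       for _ in range(16):
#           logits, _ = model(seq[:, -64:])
#           next_bit = logits[:, -1, :].argmax(-1, keepdim=True)
#           seq = torch.cat([seq, next_bit], dim=1)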


def main():
    """Main function."""
    logger.info("🚀 Starting basic BitTransformerLM training test")

    try:
        trained_model = train_basic()
        logger.info("✅ Basic training test PASSED!")

        # Persist the trained weights for later inspection.
        torch.save(trained_model.state_dict(), '/data/BitTransformerLM/basic_model.pt')
        logger.info("Model saved to basic_model.pt")

    except Exception as e:
        logger.error(f"❌ Training failed: {e}")
        raise


if __name__ == "__main__":
    main()
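
# Reloading the trained weights later (illustrative; requires the same
# hyperparameters as create_model()):
#   model = create_model()
#   state = torch.load('/data/BitTransformerLM/basic_model.pt', map_location='cpu')
#   model.load_state_dict(state)
#   model.eval()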