| | |
"""
BitTransformerLM Conversation Test Script
=========================================

Load the trained breakthrough model and test its conversational capabilities!
"""
| |
|
| | import sys |
| | import os |
| | import torch |
| | import torch.nn.functional as F |
| | from pathlib import Path |
| |
|
| | |
| | sys.path.append('/data') |
| | sys.path.append('/data/BitTransformerLM') |
| |
|
| | from bit_transformer import BitTransformerLM, text_to_bits, bits_to_text |
| |
|
def load_breakthrough_model(
    checkpoint_path='/data/BitTransformerLM/checkpoints/checkpoint_latest.pt',
):
    """Build a BitTransformerLM and restore its weights from a checkpoint.

    Args:
        checkpoint_path: Path of the checkpoint file to read. The file must
            contain a ``'model_state_dict'`` entry; ``'epoch'``, ``'loss'``
            and ``'best_loss'`` are reported when present.

    Returns:
        The model, switched to eval mode, with the checkpoint weights loaded.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        Any error raised by ``torch.load`` or ``load_state_dict`` (missing
        file, architecture mismatch, ...) propagates unchanged.
    """
    print("π Loading breakthrough BitTransformerLM...")

    # These hyperparameters define the architecture the state dict is loaded
    # into; load_state_dict below fails if they do not match the checkpoint.
    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=True,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05,
    )

    print(f"Loading checkpoint: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # a trusted source (or consider passing weights_only=True).
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    # Inference only from here on.
    model.eval()

    print("β Model loaded successfully!")
    print("π Checkpoint info:")

    # Training metadata is informational: skip entries the checkpoint does not
    # carry instead of failing after the weights were already restored.
    metadata_fields = (
        ("Epoch", "epoch", ""),
        ("Loss", "loss", ".6f"),
        ("Best Loss", "best_loss", ".6f"),
    )
    for label, key, spec in metadata_fields:
        if key in checkpoint:
            print(f" - {label}: {checkpoint[key]:{spec}}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f" - Parameters: {total_params:,}")

    return model
| |
|
def _sample_next_bit(logits, temperature, top_p):
    """Pick one index from a 1-D logits vector using temperature and top-p."""
    if temperature <= 0:
        # Dividing by a non-positive temperature yields inf/nan logits;
        # treat it as greedy decoding (the limit of temperature -> 0).
        return int(torch.argmax(logits).item())

    # Division creates a new tensor, so the in-place masking below never
    # touches the model's own output.
    scaled = logits / temperature

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(scaled, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        remove = cumulative_probs > top_p
        # Shift the mask right by one so the first entry that crosses the
        # threshold is kept, and never drop the most likely entry.
        remove[1:] = remove[:-1].clone()
        remove[0] = False

        scaled[sorted_indices[remove]] = float('-inf')

    probs = F.softmax(scaled, dim=-1)
    return torch.multinomial(probs, 1).item()


def generate_text(model, prompt, max_length=100, temperature=0.8, top_p=0.9):
    """Generate a continuation of ``prompt`` one bit at a time.

    Args:
        model: Callable taking a ``(1, seq_len)`` long tensor of bits and
            returning ``(logits, telemetry)``; ``logits[0, -1, :]`` is used
            as the next-bit distribution.
        prompt: Text to continue.
        max_length: Maximum number of bits to generate.
        temperature: Softmax temperature; values <= 0 select the most likely
            bit deterministically.
        top_p: Nucleus-sampling threshold; values >= 1.0 disable filtering.

    Returns:
        ``prompt`` followed by the decoded generated text, or ``None`` when
        the generated bits could not be decoded.
    """
    print(f"\nπ€ Generating response to: '{prompt}'")

    input_bits = text_to_bits(prompt)
    print(f"π Input bits: {len(input_bits)} bits")

    # NOTE(review): this keeps the FIRST 200 bits, so for long prompts the
    # most recent text is what gets dropped — confirm that this is intended.
    if len(input_bits) > 200:
        input_bits = input_bits[:200]

    generated_bits = list(input_bits)
    prompt_bit_count = len(input_bits)

    # Stays None when no forward pass runs (max_length == 0).
    telemetry = None

    print("π Generating...")

    with torch.no_grad():
        for i in range(max_length):
            # Feed at most the 256 most recent bits to the model.
            current_tensor = torch.tensor(
                generated_bits[-256:], dtype=torch.long
            ).unsqueeze(0)

            logits, telemetry = model(current_tensor)

            next_bit = _sample_next_bit(logits[0, -1, :], temperature, top_p)
            generated_bits.append(next_bit)

            # Every 9 generated bits, try to decode and stop at the end of a
            # sentence. Presumably 9 bits encode one character (8 data bits
            # plus parity) — confirm against text_to_bits.
            if i % 9 == 8:
                try:
                    partial_text = bits_to_text(generated_bits[prompt_bit_count:])
                    if partial_text.endswith(('.', '!', '?', '\n')):
                        break
                except Exception:
                    # An undecodable partial sequence is expected mid-stream;
                    # keep generating. KeyboardInterrupt is not swallowed.
                    continue

    generated_only = generated_bits[prompt_bit_count:]

    try:
        generated_text = bits_to_text(generated_only)
        print(f"β¨ Generated text: '{generated_text}'")
        print(f"π Generated {len(generated_only)} bits -> {len(generated_text)} characters")

        if telemetry:
            print(f"π Final telemetry: K={telemetry.get('negentropy_logits', 0):.3f}, " +
                  f"C={telemetry.get('lz_complexity_logits', 0):.3f}, " +
                  f"S={telemetry.get('symbiosis_score', 0):.3f}")

        # The full, untruncated prompt is returned so callers can slice the
        # answer off with len(prompt).
        return prompt + generated_text

    except Exception as e:
        print(f"β Failed to decode generated bits: {e}")
        if len(generated_only) > 50:
            print(f"Raw bits: {generated_only[:50]}...")
        else:
            print(f"Raw bits: {generated_only}")
        return None
| |
|
def interactive_conversation(model):
    """Run a read-generate-print chat loop on stdin/stdout.

    Recognised commands are ``quit``/``exit``/``q``, ``help`` and ``clear``
    (case-insensitive). Any other non-empty input is appended to the running
    transcript as a ``Human:`` turn and the model is asked to continue it.

    The loop ends on a quit command, on Ctrl-C, or when stdin reaches
    end-of-file. Other errors are printed and the loop continues.

    Args:
        model: Model object passed through to ``generate_text``.
    """
    print("\nπ― BREAKTHROUGH BITRANSFORMERLM CONVERSATION TEST")
    print("=" * 60)
    print("Type 'quit' to exit, 'help' for commands")
    print()

    conversation_history = ""

    while True:
        try:
            user_input = input("You: ").strip()
            command = user_input.lower()

            if command in ['quit', 'exit', 'q']:
                print("π Goodbye!")
                break

            if command == 'help':
                print("Commands:")
                print(" quit/exit/q - Exit conversation")
                print(" help - Show this help")
                print(" clear - Clear conversation history")
                continue

            if command == 'clear':
                conversation_history = ""
                print("π§Ή Conversation history cleared")
                continue

            if not user_input:
                continue

            # Build the prompt separately so a failed generation does not
            # leave a dangling "Human: ...\nAI: " turn in the history.
            turn_prompt = conversation_history + f"Human: {user_input}\nAI: "

            response = generate_text(
                model,
                turn_prompt,
                max_length=150,
                temperature=0.8,
                top_p=0.9
            )

            if response:
                # generate_text returns prompt + continuation; keep only the
                # newly generated part.
                ai_response = response[len(turn_prompt):]
                print(f"AI: {ai_response}")

                conversation_history = turn_prompt + ai_response + "\n"

                # Bound the transcript: past 500 characters keep only the
                # last 10 lines.
                if len(conversation_history) > 500:
                    lines = conversation_history.split('\n')
                    conversation_history = '\n'.join(lines[-10:])
            else:
                print("AI: [Failed to generate response]")

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed. Without this the generic handler
            # below would catch it on every iteration and loop forever.
            print("\nπ Goodbye!")
            break
        except Exception as e:
            print(f"β Error: {e}")
| |
|
def main():
    """Load the model, run three canned prompts, then start the chat loop."""
    divider = "=" * 60

    print("π BITRANSFORMERLM BREAKTHROUGH CONVERSATION TEST")
    print(divider)

    model = load_breakthrough_model()

    print("\nπ§ͺ QUICK TESTS:")

    # (heading, prompt, bit budget) for each smoke test, run in order.
    quick_tests = (
        ("\n--- Test 1: Simple Greeting ---", "Hello", 50),
        ("\n--- Test 2: Question ---", "What is", 50),
        ("\n--- Test 3: Conversation ---", "Hi there! How are you?", 80),
    )
    for heading, prompt, bit_budget in quick_tests:
        print(heading)
        generate_text(model, prompt, max_length=bit_budget)

    print("\n" + divider)
    print("Ready for interactive conversation!")

    interactive_conversation(model)
| |
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()