import gc
import threading
import time
import traceback
from datetime import datetime

import gradio as gr
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# ============ CONFIGURATION ============
# We load "gpt2" as the base to avoid the "Unrecognized model" error.
# We save to YOUR repo: "himu1780/ai-python-model"
BASE_MODEL = "gpt2"
MODEL_REPO = "himu1780/ai-python-model"
DATASET_NAME = "iamtarun/python_code_instructions_18k_alpaca"

BATCH_SIZE = 4
GRADIENT_ACCUMULATION = 4
LEARNING_RATE = 2e-4
MAX_STEPS_PER_SESSION = 1000
LOGGING_STEPS = 10
SAVE_STEPS = 500
WAIT_BETWEEN_SESSIONS = 10  # seconds
CONTINUOUS_TRAINING = True

# Global status dictionary, read by the Gradio UI
training_status = {
    "is_training": False,
    "message": "💤 Idle",
    "session_count": 0,
    "current_step": 0,
    "total_loss": 0.0,
    "last_save": "None",
    "last_error": "",
    "start_time": None,
}
stop_requested = False

# Guards the check-and-set of "is_training" so that the auto-start thread
# and the Start button cannot launch two training loops at once.
start_lock = threading.Lock()


# ============ CUSTOM TRAINER ============
class StatusTrainer(Trainer):
    def log(self, logs, start_time=None):
        """Override log to mirror the latest loss and step into the global status."""
        if "loss" in logs:
            training_status["total_loss"] = logs["loss"]
        # The logs dict carries no "step" key; the step lives on the trainer state.
        training_status["current_step"] = self.state.global_step
        # Call the parent method with start_time passed through
        super().log(logs, start_time=start_time)


class StatusCallback(TrainerCallback):
    """Tracks progress and honours stop requests from the UI."""

    def on_step_end(self, args, state, control, **kwargs):
        training_status["current_step"] = state.global_step
        if stop_requested:
            # Ask the Trainer to leave its loop after the current step
            control.should_training_stop = True
        return control

    def on_save(self, args, state, control, **kwargs):
        training_status["last_save"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return control


# ============ MEMORY CLEANUP ============
def cleanup_memory():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


# ============ TRAINING SESSION ============
def run_training_session():
    print(f"[INFO] Starting Session {training_status['session_count'] + 1}")
    model = None
    trainer = None
    try:
        # 1. Load tokenizer (use BASE_MODEL to avoid the 'Unrecognized model' error)
        print(f"[INFO] Loading tokenizer from {BASE_MODEL}...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Load model (use BASE_MODEL to avoid the 'Unrecognized model' error).
        # The weights stay in float32: with fp16=True below, the Trainer applies
        # mixed precision itself, and float16 weights would make the gradient
        # scaler fail with "Attempting to unscale FP16 gradients".
        print(f"[INFO] Loading model from {BASE_MODEL}...")
        model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)

        # 3. Load dataset
        print(f"[INFO] Loading dataset {DATASET_NAME}...")
        dataset = load_dataset(DATASET_NAME, split="train")

        def tokenize_function(examples):
            # Prefer a "text" column, then "prompt" (assumption: this dataset
            # exposes a combined prompt column), else fall back to the first
            # column. Adjust the column name if needed.
            if "text" in examples:
                text_column = "text"
            elif "prompt" in examples:
                text_column = "prompt"
            else:
                text_column = list(examples.keys())[0]
            return tokenizer(
                examples[text_column],
                truncation=True,
                padding="max_length",
                max_length=512,
            )

        print("[INFO] Tokenizing dataset...")
        train_dataset = dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=dataset.column_names,  # keep only the tokenized features
        )

        if len(train_dataset) == 0:
            print("❌ Empty dataset!")
            return False

        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        )

        # 4. Configure training
        training_args = TrainingArguments(
            output_dir="./temp_checkpoints",
            overwrite_output_dir=True,
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION,
            learning_rate=LEARNING_RATE,
            warmup_steps=100,
            weight_decay=0.01,
            logging_steps=LOGGING_STEPS,
            save_steps=SAVE_STEPS,
            save_total_limit=1,
            push_to_hub=True,  # This ensures we still PUSH to your repo
            hub_model_id=MODEL_REPO,  # Your repo: himu1780/ai-python-model
            hub_strategy="every_save",
            report_to="none",
            max_steps=MAX_STEPS_PER_SESSION,
            fp16=torch.cuda.is_available(),
            dataloader_num_workers=0,
            remove_unused_columns=True,
        )

        trainer = StatusTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
            processing_class=tokenizer,
            callbacks=[StatusCallback()],
        )

        training_status["message"] = "🏃 Training in progress..."
        print("[INFO] Starting training...")
        trainer.train()

        print("[INFO] Pushing to hub...")
        trainer.push_to_hub()

        training_status["session_count"] += 1
        training_status["message"] = f"✅ Session {training_status['session_count']} completed!"
        return True

    except KeyboardInterrupt:
        training_status["message"] = "⏹️ Training stopped by user"
        return False
    except Exception as e:
        training_status["message"] = f"❌ Error: {str(e)[:100]}"
        training_status["last_error"] = str(e)
        print(f"[ERROR] Training failed: {e}")
        traceback.print_exc()
        return False
    finally:
        if model is not None:
            del model
        if trainer is not None:
            del trainer
        cleanup_memory()


# ============ MAIN TRAINING LOOP ============
def start_training():
    global stop_requested

    with start_lock:
        if training_status["is_training"]:
            return "Training already in progress!"
        training_status["is_training"] = True
        training_status["start_time"] = datetime.now()
        stop_requested = False

    try:
        while not stop_requested:
            training_status["message"] = f"🚀 Starting session {training_status['session_count'] + 1}..."
            success = run_training_session()

            if stop_requested:
                break
            if not CONTINUOUS_TRAINING:
                break

            if success:
                training_status["message"] = f"⏳ Waiting {WAIT_BETWEEN_SESSIONS}s before next session..."
                time.sleep(WAIT_BETWEEN_SESSIONS)
            else:
                training_status["message"] = "⚠️ Session failed, retrying in 60s..."
                time.sleep(60)
    finally:
        training_status["is_training"] = False

    # stop_requested is deliberately NOT reset here. Resetting it would let the
    # auto-start loop relaunch training right after a user-requested stop.
    # It is cleared only when a new run begins.
    training_status["message"] = f"✅ Training finished! Total sessions: {training_status['session_count']}"
    return training_status["message"]


# ============ GRADIO INTERFACE ============
def get_status():
    elapsed = ""
    if training_status["start_time"]:
        delta = datetime.now() - training_status["start_time"]
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, seconds = divmod(remainder, 60)
        elapsed = f"{hours}h {minutes}m {seconds}s"

    if training_status["total_loss"]:
        loss_str = f"{training_status['total_loss']:.4f}"
    else:
        loss_str = "N/A"

    state_str = "🟢 Training" if training_status["is_training"] else "🔴 Stopped"
    continuous_str = "✅ Enabled" if CONTINUOUS_TRAINING else "❌ Disabled"
    elapsed_str = elapsed if elapsed else "N/A"
    effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION
    error_str = training_status["last_error"][:100] if training_status["last_error"] else "None"

    return f"""
## 🤖 AI Python Model Trainer

### Status
| Item | Value |
|------|-------|
| **State** | {state_str} |
| **Message** | {training_status["message"]} |
| **Sessions Completed** | {training_status["session_count"]} |
| **Last Error** | {error_str} |

### Progress
| Metric | Value |
|--------|-------|
| **Current Step** | {training_status["current_step"]:,} / {MAX_STEPS_PER_SESSION:,} |
| **Current Loss** | {loss_str} |
| **Last Checkpoint** | {training_status["last_save"]} |
| **Elapsed Time** | {elapsed_str} |

### Configuration
| Setting | Value |
|---------|-------|
| **Base Model** | {BASE_MODEL} |
| **Target Repo** | [{MODEL_REPO}](https://huggingface.co/{MODEL_REPO}) |
| **Continuous Mode** | {continuous_str} |
| **Batch Size** | {BATCH_SIZE} (effective: {effective_batch}) |
"""


def start_training_async():
    if training_status["is_training"]:
        return "⚠️ Training already in progress!"
    thread = threading.Thread(target=start_training, daemon=True)
    thread.start()
    return "🚀 Training started in background!"


def stop_training():
    global stop_requested
    if not training_status["is_training"]:
        return "⚠️ No training in progress"
    stop_requested = True
    training_status["message"] = "⏹️ Stopping after current step..."
    return "⏹️ Stop requested"


# ============ AUTO-START ============
def auto_start():
    # Give the Gradio app a moment to come up before the first session
    time.sleep(10)
    while True:
        if not training_status["is_training"] and not stop_requested:
            print("[INFO] Auto-starting training session...")
            start_training()
        time.sleep(WAIT_BETWEEN_SESSIONS)


auto_thread = threading.Thread(target=auto_start, daemon=True)
auto_thread.start()

# ============ GRADIO APP ============
with gr.Blocks(title="AI Python Trainer") as demo:
    gr.Markdown("# 🐍 AI Python Code Model Trainer")
    gr.Markdown(f"**Continuous training** from `{BASE_MODEL}` -> `{MODEL_REPO}`")

    status_display = gr.Markdown(get_status)

    with gr.Row():
        start_btn = gr.Button("▶️ Start Training", variant="primary")
        stop_btn = gr.Button("⏹️ Stop Training", variant="stop")
        refresh_btn = gr.Button("🔄 Refresh Status")

    output = gr.Textbox(label="Output", interactive=False)

    start_btn.click(start_training_async, outputs=output)
    stop_btn.click(stop_training, outputs=output)
    refresh_btn.click(get_status, outputs=status_display)

demo.launch()