# ai_python / app.py
# (Hugging Face Spaces page residue preserved as a comment:
#  "himu1780's picture — Update app.py — ebe1617 verified")
import os
import time
import threading
import torch
import gradio as gr
from datetime import datetime
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling
)
from datasets import load_dataset
# ============ CONFIGURATION ============
# We load "gpt2" as the base to fix the "Unrecognized model" error.
# We save to YOUR repo: "himu1780/ai-python-model"
BASE_MODEL = "gpt2"  # base checkpoint to fine-tune from
MODEL_REPO = "himu1780/ai-python-model"  # Hub repo that receives pushed checkpoints
DATASET_NAME = "iamtarun/python_code_instructions_18k_alpaca"  # training corpus
BATCH_SIZE = 4  # per-device train batch size
GRADIENT_ACCUMULATION = 4  # effective batch = BATCH_SIZE * GRADIENT_ACCUMULATION
LEARNING_RATE = 2e-4
MAX_STEPS_PER_SESSION = 1000  # optimizer steps per training session
LOGGING_STEPS = 10  # how often the trainer reports metrics
SAVE_STEPS = 500  # checkpoint (and hub push) frequency, in steps
WAIT_BETWEEN_SESSIONS = 10  # seconds to pause between successful sessions
CONTINUOUS_TRAINING = True  # keep launching new sessions until stopped

# Global Status Dictionary
# Shared mutable state: written by the training loop / trainer callbacks,
# read by the Gradio UI (get_status).
training_status = {
    "is_training": False,  # True while the training loop is running
    "message": "πŸ’€ Idle",  # human-readable status line for the UI
    "session_count": 0,  # completed sessions so far
    "current_step": 0,  # latest step reported by the trainer
    "total_loss": 0.0,  # latest loss reported by the trainer (0.0 = none yet)
    "last_save": "None",  # NOTE(review): never updated anywhere in this file
    "last_error": "",  # last exception text, if any
    "start_time": None  # datetime when the loop started (None until then)
}
# Cooperative stop flag polled by the training loop between sessions.
stop_requested = False
# ============ CUSTOM TRAINER (FIXED) ============
class StatusTrainer(Trainer):
    """Trainer subclass that mirrors logged metrics into the global
    ``training_status`` dict so the Gradio UI can show live progress."""

    def log(self, logs, start_time=None):  # start_time: newer transformers pass it
        """Update ``training_status`` from *logs*, then delegate to the parent.

        Args:
            logs: metrics dict produced by the Trainer (e.g. ``{"loss": ...}``).
            start_time: forwarded to ``Trainer.log`` (required by newer
                transformers versions; harmless default for the rest).
        """
        if "loss" in logs:
            training_status["total_loss"] = logs["loss"]
        # FIX: Trainer log dicts normally do NOT carry a "step" key, so the
        # UI's step counter never advanced.  Read the authoritative counter
        # from the trainer state; keep the original "step" lookup as fallback.
        if getattr(self, "state", None) is not None:
            training_status["current_step"] = self.state.global_step
        elif "step" in logs:
            training_status["current_step"] = logs["step"]
        # Call parent method properly with start_time
        super().log(logs, start_time=start_time)
# ============ MEMORY CLEANUP ============
def cleanup_memory():
    """Release cached GPU memory (when CUDA is available) and force a
    garbage-collection pass so deleted models are actually freed."""
    import gc

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
# ============ TRAINING SESSION ============
def run_training_session():
    """Execute one fine-tuning session and push the result to the Hub.

    Loads the base tokenizer/model, tokenizes the dataset, trains for
    MAX_STEPS_PER_SESSION steps, then pushes to MODEL_REPO.

    Returns:
        bool: True when the session completed and was pushed; False on
        user interrupt, any error, or an empty dataset.
    """
    global training_status
    print(f"[INFO] Starting Session {training_status['session_count'] + 1}")
    model = None
    trainer = None
    try:
        # 1. Load Tokenizer (Use BASE_MODEL to avoid 'Unrecognized model' error)
        print(f"[INFO] Loading tokenizer from {BASE_MODEL}...")
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        if tokenizer.pad_token is None:
            # GPT-2 ships without a pad token; reuse EOS so padding works.
            tokenizer.pad_token = tokenizer.eos_token
        # 2. Load Model (Use BASE_MODEL to avoid 'Unrecognized model' error)
        print(f"[INFO] Loading model from {BASE_MODEL}...")
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            # fp16 weights on GPU to halve memory; fp32 on CPU
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True
        )
        # 3. Load Dataset
        print(f"[INFO] Loading dataset {DATASET_NAME}...")
        dataset = load_dataset(DATASET_NAME, split="train")

        def tokenize_function(examples):
            # Simple text tokenization - adjust column name if needed
            # NOTE(review): an alpaca-style dataset may not expose a "text"
            # column; the fallback grabs the FIRST column, which may not be
            # the intended training text — verify against the dataset schema.
            text_column = "text" if "text" in examples else list(examples.keys())[0]
            return tokenizer(examples[text_column], truncation=True, padding="max_length", max_length=512)

        print("[INFO] Tokenizing dataset...")
        tokenized_datasets = dataset.map(tokenize_function, batched=True)
        train_dataset = tokenized_datasets
        if len(train_dataset) == 0:
            print("❌ Empty dataset!")
            return False
        # Causal-LM collation (mlm=False): labels are the input ids shifted.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm=False,
        )
        # 4. Configure Training
        training_args = TrainingArguments(
            output_dir="./temp_checkpoints",
            overwrite_output_dir=True,
            per_device_train_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION,
            learning_rate=LEARNING_RATE,
            warmup_steps=100,
            weight_decay=0.01,
            logging_steps=LOGGING_STEPS,
            save_steps=SAVE_STEPS,
            save_total_limit=1,  # keep only the most recent local checkpoint
            push_to_hub=True,  # This ensures we still PUSH to your repo
            hub_model_id=MODEL_REPO,  # Your repo: himu1780/ai-python-model
            hub_strategy="every_save",
            report_to="none",
            max_steps=MAX_STEPS_PER_SESSION,
            fp16=torch.cuda.is_available(),
            dataloader_num_workers=0,
            remove_unused_columns=True,
        )
        trainer = StatusTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
            processing_class=tokenizer,
        )
        training_status["message"] = "πŸƒ Training in progress..."
        print("[INFO] Starting training...")
        trainer.train()
        print("[INFO] Pushing to hub...")
        trainer.push_to_hub()
        training_status["session_count"] += 1
        training_status["message"] = f"βœ… Session {training_status['session_count']} completed!"
        return True
    except KeyboardInterrupt:
        training_status["message"] = "⏹️ Training stopped by user"
        return False
    except Exception as e:
        # Surface a truncated error in the UI; full traceback goes to logs.
        training_status["message"] = f"❌ Error: {str(e)[:100]}"
        training_status["last_error"] = str(e)
        print(f"[ERROR] Training failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        # Always drop model/trainer references and free GPU cache so that
        # repeated sessions do not accumulate memory.
        if model is not None:
            del model
        if trainer is not None:
            del trainer
        cleanup_memory()
# ============ MAIN TRAINING LOOP ============
def start_training():
    """Blocking training loop.

    Runs sessions back-to-back until a stop is requested, or exactly one
    session when CONTINUOUS_TRAINING is disabled.  Returns the final status
    message string.
    """
    global training_status, stop_requested

    if training_status["is_training"]:
        return "Training already in progress!"

    training_status["is_training"] = True
    training_status["start_time"] = datetime.now()
    stop_requested = False

    while not stop_requested:
        upcoming = training_status["session_count"] + 1
        training_status["message"] = f"πŸš€ Starting session {upcoming}..."
        session_ok = run_training_session()

        # Leave the loop on an explicit stop, or after one pass in
        # single-session mode.
        if stop_requested or not CONTINUOUS_TRAINING:
            break

        if session_ok:
            training_status["message"] = f"⏳ Waiting {WAIT_BETWEEN_SESSIONS}s before next session..."
            time.sleep(WAIT_BETWEEN_SESSIONS)
        else:
            training_status["message"] = "⚠️ Session failed, retrying in 60s..."
            time.sleep(60)

    training_status["is_training"] = False
    stop_requested = False
    training_status["message"] = f"βœ… Training finished! Total sessions: {training_status['session_count']}"
    return training_status["message"]
# ============ GRADIO INTERFACE ============
def get_status():
    """Render the current training status as a Markdown report for the UI."""
    # Human-readable elapsed time since the loop started (empty until then).
    elapsed = ""
    if training_status["start_time"]:
        delta = datetime.now() - training_status["start_time"]
        hours, remainder = divmod(int(delta.total_seconds()), 3600)
        minutes, seconds = divmod(remainder, 60)
        elapsed = f"{hours}h {minutes}m {seconds}s"
    # NOTE(review): truthiness means a loss of exactly 0.0 (the initial
    # sentinel) renders as "N/A"; real losses are never exactly zero.
    if training_status["total_loss"]:
        loss_str = f"{training_status['total_loss']:.4f}"
    else:
        loss_str = "N/A"
    state_str = "🟒 Training" if training_status["is_training"] else "πŸ”΄ Stopped"
    continuous_str = "βœ… Enabled" if CONTINUOUS_TRAINING else "❌ Disabled"
    elapsed_str = elapsed if elapsed else "N/A"
    effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION
    # Truncate the error so a long traceback does not break the table layout.
    error_str = training_status["last_error"][:100] if training_status["last_error"] else "None"
    return f"""
## πŸ€– AI Python Model Trainer
### Status
| Item | Value |
|------|-------|
| **State** | {state_str} |
| **Message** | {training_status["message"]} |
| **Sessions Completed** | {training_status["session_count"]} |
| **Last Error** | {error_str} |
### Progress
| Metric | Value |
|--------|-------|
| **Current Step** | {training_status["current_step"]:,} / {MAX_STEPS_PER_SESSION:,} |
| **Current Loss** | {loss_str} |
| **Last Checkpoint** | {training_status["last_save"]} |
| **Elapsed Time** | {elapsed_str} |
### Configuration
| Setting | Value |
|---------|-------|
| **Base Model** | {BASE_MODEL} |
| **Target Repo** | [{MODEL_REPO}](https://huggingface.co/{MODEL_REPO}) |
| **Continuous Mode** | {continuous_str} |
| **Batch Size** | {BATCH_SIZE} (effective: {effective_batch}) |
"""
def start_training_async():
    """Launch the blocking training loop on a daemon background thread so
    the Gradio request returns immediately."""
    if training_status["is_training"]:
        return "⚠️ Training already in progress!"
    worker = threading.Thread(target=start_training, daemon=True)
    worker.start()
    return "πŸš€ Training started in background!"
def stop_training():
    """Ask the training loop to halt; it stops cooperatively once it next
    checks the flag."""
    global stop_requested
    if not training_status["is_training"]:
        return "⚠️ No training in progress"
    training_status["message"] = "⏹️ Stopping after current step..."
    stop_requested = True
    return "⏹️ Stop requested"
# ============ AUTO-START ============
def auto_start():
    """Background watchdog: after a short grace period, keep (re)starting
    training whenever it is idle and no stop has been requested."""
    time.sleep(10)  # let the app finish booting before the first session
    while True:
        idle = not training_status["is_training"]
        if idle and not stop_requested:
            print("[INFO] Auto-starting training session...")
            start_training()
        time.sleep(WAIT_BETWEEN_SESSIONS)
# Daemon thread: will not keep the process alive on shutdown.
auto_thread = threading.Thread(target=auto_start, daemon=True)
auto_thread.start()
# ============ GRADIO APP ============
with gr.Blocks(title="AI Python Trainer") as demo:
    gr.Markdown("# 🐍 AI Python Code Model Trainer")
    gr.Markdown(f"**Continuous training** from `{BASE_MODEL}` -> `{MODEL_REPO}`")
    # Passing the callable (not its result) lets Gradio re-evaluate the
    # status on each page load.
    status_display = gr.Markdown(get_status)
    with gr.Row():
        start_btn = gr.Button("▢️ Start Training", variant="primary")
        stop_btn = gr.Button("⏹️ Stop Training", variant="stop")
        refresh_btn = gr.Button("πŸ”„ Refresh Status")
    output = gr.Textbox(label="Output", interactive=False)
    # Wire buttons: start/stop report to the textbox, refresh re-renders
    # the markdown status panel.
    start_btn.click(start_training_async, outputs=output)
    stop_btn.click(stop_training, outputs=output)
    refresh_btn.click(get_status, outputs=status_display)
demo.launch()