# app.py — FIXED: Gradio 4.x compatible, no deprecation warnings, auto-trains stages
import os
import threading
import time

import gradio as gr
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

from train_vlm import train_vlm_stage

# --- Config ---
MODEL_NAME = "bczhou/TinyLLaVA-3.1B"  # or "llava-hf/llava-1.5-7b-hf"
CHECKPOINT_ROOT = "./checkpoints"
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)

# --- Global state ---
current_stage = 1
model = None
processor = None
device = "cuda" if torch.cuda.is_available() else "cpu"
training_status = "🚀 Initializing COCONUT-VLM Autonomous Trainer..."

print(f"🖥️ Running on device: {device}")
if device == "cuda":
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")


def load_model_for_stage(stage):
    """Load the trained checkpoint for `stage` if it exists, else the base model."""
    global model, processor
    ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
    if os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
        print(f"✅ Loading checkpoint: Stage {stage}")
        model = LlavaForConditionalGeneration.from_pretrained(
            ckpt_path, torch_dtype=torch.float16
        ).to(device)
        processor = AutoProcessor.from_pretrained(ckpt_path)
    else:
        print(f"⚠️ No checkpoint for Stage {stage} — loading base model")
        model = LlavaForConditionalGeneration.from_pretrained(
            MODEL_NAME, torch_dtype=torch.float16
        ).to(device)
        processor = AutoProcessor.from_pretrained(MODEL_NAME)


def chat_with_image(image, text, chat_history):
    if model is None or processor is None:
        load_model_for_stage(current_stage)
    try:
        # Format input for the model — the <image> placeholder marks where the
        # processor splices in the image features
        conversation = [
            {"role": "user", "content": f"<image>\n{text}"},
        ]
        prompt = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        # Cast floating-point inputs (pixel_values) to fp16 to match the model dtype
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(
            device, torch.float16
        )
        output = model.generate(
            **inputs, max_new_tokens=256, do_sample=True, temperature=0.7
        )
        # Decode only the newly generated tokens, skipping the echoed prompt
        response = processor.decode(
            output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
        )
        # Append OpenAI-style message dicts (matches Chatbot(type="messages") and
        # fixes the tuples deprecation warning)
        chat_history.append({"role": "user", "content": text})
        chat_history.append({"role": "assistant", "content": response})
        return "", chat_history
    except Exception as e:
        chat_history.append({"role": "user", "content": text})
        chat_history.append({"role": "assistant", "content": f"⚠️ Error: {e}"})
        return "", chat_history


# --- Autonomous training pipeline ---
def auto_train_pipeline():
    global current_stage, training_status
    for stage in [1, 2, 3]:
        current_stage = stage
        training_status = f"▶️ AUTO-TRAINING STARTED: Stage {stage}"
        print(training_status)
        ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"

        # Skip if already trained
        if os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
            training_status = f"⏭️ Stage {stage} already trained — loading..."
            print(training_status)
            load_model_for_stage(stage)
            time.sleep(3)
            continue

        try:
            train_vlm_stage(stage, MODEL_NAME, ckpt_path)
            training_status = f"✅ Stage {stage} completed! Loading model..."
            print(training_status)
            load_model_for_stage(stage)
            if stage < 3:
                training_status = f"⏳ Advancing to Stage {stage + 1} in 5 seconds..."
                print(training_status)
                time.sleep(5)
        except Exception as e:
            training_status = f"❌ Stage {stage} failed: {e}"
            print(training_status)
            break
    else:
        # for/else: runs only when no stage failed, so a failure status
        # is never overwritten by the completion message
        training_status = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
        print(training_status)


# --- Launch training once on app start ---
_trainer_started = False


def initialize_autonomous_trainer():
    # Guard so a browser refresh doesn't spawn a second training thread
    global _trainer_started
    if not _trainer_started:
        _trainer_started = True
        threading.Thread(target=auto_train_pipeline, daemon=True).start()
    return training_status


# --- Status polling function ---
def update_status():
    # Re-run by Gradio on a timer (see demo.load(..., every=1.5) below)
    return training_status


# --- Gradio UI (chat-only) ---
with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
    gr.Markdown("# 🥥 COCONUT-VLM: Autonomous Vision-Language Trainer")
    gr.Markdown(
        "The model trains itself in 3 stages automatically. "
        "**You can only chat.** Training is backend-only."
    )

    with gr.Row():
        with gr.Column(scale=1):
            status = gr.Textbox(
                label="Training Status",
                value="Initializing...",
                interactive=False,
                show_label=False,
            )
            gr.Markdown("💡 _Training runs automatically. No buttons. No switching._")

        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image")
            # ✅ FIXED: type="messages" (OpenAI-style dicts) avoids the deprecation warning
            chatbot = gr.Chatbot(height=400, type="messages")
            msg = gr.Textbox(label="Ask a question about the image")
            clear = gr.Button("Clear Chat")

    # Chat logic
    msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
    clear.click(lambda: [], inputs=None, outputs=chatbot, queue=False)

    # ✅ FIXED: kick off training once on page load, then poll the status string.
    # `every=1.5` re-runs update_status every 1.5 seconds (Gradio 4.x polling;
    # it requires the queue, which is enabled below).
    demo.load(fn=initialize_autonomous_trainer, inputs=None, outputs=status)
    demo.load(fn=update_status, inputs=None, outputs=status, every=1.5)

demo.queue(max_size=20).launch()
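# NOTE (assumption): train_vlm.py is not shown here. From the call
# train_vlm_stage(stage, MODEL_NAME, ckpt_path) above, it is expected to expose
# roughly this interface:
#
#     def train_vlm_stage(stage: int, model_name: str, output_dir: str) -> None:
#         """Run COCONUT curriculum stage `stage` on `model_name` and save
#         adapter weights to {output_dir}/adapter_model.safetensors."""
#
# Both load_model_for_stage() and the skip check key off adapter_model.safetensors,
# so each stage must leave a loadable checkpoint (adapter plus processor files)
# in output_dir for the reload to succeed.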