# app.py
import os

import gradio as gr
import torch
from transformers import AutoImageProcessor, AutoTokenizer

from custom_vlm import CustomScratchVLM
from train_vlm import train_vlm_stage
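
# train_vlm and custom_vlm are local modules that ship alongside app.py in
# this Space; they are not pip-installable packages.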
CHECKPOINT_ROOT = "./checkpoints"
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)

# Module-level state shared between the background trainer and the chat handler.
current_stage = 0
model = None
image_processor = None
tokenizer = None

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Running on device: {device}")
if device == "cuda":
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
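
# Assumption: train_vlm_stage writes a self-contained HF-style checkpoint
# (config.json, weights, tokenizer/processor files) to the directory it is
# given, e.g. ./checkpoints/stage_1; the config.json probe below relies on
# that layout.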
def load_model_for_stage(stage):
    """Point the global model/processor/tokenizer at the checkpoint for `stage`."""
    global model, image_processor, tokenizer, current_stage
    current_stage = stage
    ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
    if os.path.exists(os.path.join(ckpt_path, "config.json")):
        print(f"✅ Loading FROM-SCRATCH checkpoint: Stage {stage}")
        # Release the previous model so its memory can be reclaimed; rebinding
        # to None (rather than `del`) keeps the global defined if loading fails.
        model = None
        if device == "cuda":
            torch.cuda.empty_cache()
        model = CustomScratchVLM.from_pretrained(ckpt_path).to(device).eval()
        image_processor = AutoImageProcessor.from_pretrained(ckpt_path)
        tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    else:
        print(f"⚠️ No checkpoint for Stage {stage} — model is not loaded.")
        model, image_processor, tokenizer = None, None, None
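
# Usage sketch (assumes a trained Stage 1 checkpoint already exists on disk):
#   load_model_for_stage(1)   # rebinds the module-level globals
#   assert model is not None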
def chat_with_image(image, text, chat_history):
    """Gradio handler: answer `text` about `image`, appending to `chat_history`."""
    if not all([model, image_processor, tokenizer]):
        return "", chat_history + [{"role": "assistant", "content": "Model is not loaded or is currently training. Please wait."}]
    if image is None:
        return "", chat_history + [{"role": "user", "content": text}, {"role": "assistant", "content": "Please upload an image."}]
    try:
        pixel_values = image_processor(image, return_tensors="pt").pixel_values.to(device)
        # For inference, the <IMAGE> token is not included in the text prompt;
        # the model receives the image through `pixel_values` instead.
        prompt = f"USER: \nQuestion: {text}\nASSISTANT:"
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        output_ids = model.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, skipping the prompt.
        response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
        chat_history.append({"role": "user", "content": text})
        chat_history.append({"role": "assistant", "content": response or "[No response generated]"})
        return "", chat_history
    except Exception as e:
        return "", chat_history + [{"role": "user", "content": text}, {"role": "assistant", "content": f"⚠️ Error: {e}"}]
def run_autonomous_training_and_update_ui():
    """Train stages 1-3 in order, yielding status strings for the UI."""
    yield "🚀 Initializing From-Scratch Trainer..."
    for stage in [1, 2, 3]:
        ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
        if os.path.exists(os.path.join(ckpt_path, "config.json")):
            yield f"⏭️ Stage {stage} already trained — loading..."
            load_model_for_stage(stage)
            continue
        yield f"▶️ AUTO-TRAINING FROM SCRATCH: Stage {stage}"
        try:
            train_vlm_stage(stage, ckpt_path)
            yield f"✅ Stage {stage} completed! Loading new model..."
            load_model_for_stage(stage)
        except Exception as e:
            yield f"❌ Stage {stage} failed: {e}"
            raise  # Stop execution on failure.
    yield "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
with gr.Blocks(title="🥥 COCONUT-VLM From Scratch") as demo:
    gr.Markdown("# 🥥 COCONUT-VLM: Autonomous Trainer (From Scratch)")
    gr.Markdown("The model is training itself **from random initialization**. You can interact with the latest trained model.")
    with gr.Row():
        with gr.Column(scale=1):
            status = gr.Textbox(label="Training Status", value="Waiting to start...", interactive=False, lines=10)
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image")
            # The handlers build {"role": ..., "content": ...} dicts, so the
            # chatbot must use the "messages" format rather than tuple pairs.
            chatbot = gr.Chatbot(label="Chat with the VLM", height=400, type="messages")
            msg = gr.Textbox(label="Ask a question")
            clear = gr.Button("Clear Chat")

    msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
    clear.click(lambda: (None, "", []), None, [image_input, msg, chatbot])
    demo.load(fn=run_autonomous_training_and_update_ui, inputs=None, outputs=status)

demo.queue().launch(debug=True)