File size: 4,870 Bytes
b798184
f746dfb
 
31126b4
b798184
6fedc6b
c91df11
f746dfb
 
 
 
 
1414591
f746dfb
6fedc6b
 
f746dfb
 
51d26e2
c91df11
51d26e2
f746dfb
6fedc6b
1414591
f746dfb
6fedc6b
 
 
 
 
 
 
 
 
f746dfb
6fedc6b
c91df11
f746dfb
 
c91df11
 
6fedc6b
 
 
f746dfb
31126b4
6fedc6b
 
c91df11
 
 
 
 
6fedc6b
 
 
c91df11
 
6fedc6b
 
c91df11
 
6fedc6b
 
c91df11
 
31126b4
51d26e2
c91df11
31126b4
 
c91df11
31126b4
1414591
c91df11
31126b4
51d26e2
6fedc6b
1414591
 
51d26e2
 
 
c91df11
1414591
31126b4
6fedc6b
1414591
 
31126b4
 
7be6eb8
c91df11
 
4e5a9fb
c91df11
 
 
f746dfb
31126b4
c91df11
31126b4
f746dfb
6fedc6b
c91df11
f746dfb
4f09d68
6fedc6b
 
f746dfb
6fedc6b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# app.py
import gradio as gr
import os
import time
from train_vlm import train_vlm_stage
from transformers import AutoImageProcessor, AutoTokenizer
from custom_vlm import CustomScratchVLM
import torch

# Each training stage writes its checkpoint to CHECKPOINT_ROOT/stage_<n>.
CHECKPOINT_ROOT = "./checkpoints"
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)

# Mutable app state: written by load_model_for_stage, read by the chat handler.
current_stage = 0
model = image_processor = tokenizer = None

# Prefer the GPU when PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

print(f"🖥️ Running on device: {device}")
if device == "cuda":
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")

def load_model_for_stage(stage):
    """Make the checkpoint for *stage* the active model used by the chat UI.

    Looks in ``CHECKPOINT_ROOT/stage_<stage>``. If it contains ``config.json``,
    the model, image processor and tokenizer are loaded from that directory
    into the module-level ``model`` / ``image_processor`` / ``tokenizer``
    globals. Otherwise all three are set to None. ``current_stage`` is
    updated in both cases.

    Args:
        stage: Stage number used to build the checkpoint directory name.

    Raises:
        Propagates any exception raised while loading. The three globals are
        then left as None: never half-updated and never unbound.
    """
    global model, image_processor, tokenizer, current_stage
    current_stage = stage
    ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"

    if not os.path.exists(os.path.join(ckpt_path, "config.json")):
        print(f"⚠️ No checkpoint for Stage {stage} — model is not loaded.")
        model, image_processor, tokenizer = None, None, None
        return

    print(f"✅ Loading FROM-SCRATCH checkpoint: Stage {stage}")
    # Drop our reference to the previous stage before loading the next, so its
    # weights can be freed first. Rebinding to None (instead of `del model`,
    # which unbinds the global) means a failed load below leaves `model`
    # defined as None rather than causing a NameError in chat_with_image.
    model, image_processor, tokenizer = None, None, None
    if device == "cuda":
        torch.cuda.empty_cache()

    loaded_model = CustomScratchVLM.from_pretrained(ckpt_path).to(device).eval()
    loaded_processor = AutoImageProcessor.from_pretrained(ckpt_path)
    loaded_tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    # Assign only after all three loaded, so a failure part-way through never
    # pairs a new model with a stale or missing processor/tokenizer.
    model, image_processor, tokenizer = loaded_model, loaded_processor, loaded_tokenizer

def chat_with_image(image, text, chat_history):
    """Answer *text* about *image* with the currently loaded VLM.

    Handler for the question box's ``submit`` event.

    Args:
        image: Uploaded image from the ``gr.Image(type="pil")`` widget, or None
            when nothing has been uploaded.
        text: The user's question.
        chat_history: Current conversation as a list of
            ``{"role": ..., "content": ...}`` dicts; None is treated as empty.

    Returns:
        ``("", new_history)``. The empty string clears the question box.
        ``new_history`` is a fresh list: the prior turns plus the user's
        question and the assistant's reply (or an explanatory/error message).
        The caller's list is never mutated.
    """
    history = list(chat_history or [])
    user_turn = {"role": "user", "content": text}

    # Read the globals once: the background trainer may rebind them (via
    # load_model_for_stage) while this request is running.
    vlm, processor, tok = model, image_processor, tokenizer

    # Explicit None checks: a tokenizer's truthiness goes through __len__
    # and is not a reliable "is loaded" test.
    if vlm is None or processor is None or tok is None:
        return "", history + [user_turn, {"role": "assistant", "content": "Model is not loaded or is currently training. Please wait."}]

    if image is None:
        return "", history + [user_turn, {"role": "assistant", "content": "Please upload an image."}]

    try:
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

        # For inference, we do not include the <IMAGE> token in the text prompt
        prompt = f"USER: \nQuestion: {text}\nASSISTANT:"
        inputs = tok(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)

        # Inference only. Disable autograd in case the custom generate() does
        # not do so itself.
        with torch.no_grad():
            output_ids = vlm.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tok.eos_token_id,
            )

        # Decode only the newly generated tokens.
        # NOTE(review): assumes generate() returns the prompt ids followed by
        # the continuation; confirm against CustomScratchVLM.generate.
        response = tok.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing the
        # handler, but keep the full traceback in the server log.
        import traceback
        traceback.print_exc()
        return "", history + [user_turn, {"role": "assistant", "content": f"⚠️ Error: {e}"}]

    return "", history + [user_turn, {"role": "assistant", "content": response or "[No response generated]"}]

def run_autonomous_training_and_update_ui():
    """Train or reload stages 1, 2 and 3 in order, streaming status strings.

    Generator wired to the status textbox via ``demo.load``. Each yielded
    string is sent to that textbox as a status update.
    - A stage whose ``CHECKPOINT_ROOT/stage_<n>/config.json`` already exists
      is loaded instead of retrained.
    - A stage that is trained here is loaded once ``train_vlm_stage`` returns.
    - If training or loading a freshly trained stage raises, a failure message
      is yielded and the exception is re-raised, ending the run.
    """
    yield "🚀 Initializing From-Scratch Trainer..."
    for stage_idx in (1, 2, 3):
        stage_dir = f"{CHECKPOINT_ROOT}/stage_{stage_idx}"
        already_trained = os.path.exists(os.path.join(stage_dir, "config.json"))

        # Resume support: skip stages that already have a saved checkpoint.
        if already_trained:
            yield f"⏭️ Stage {stage_idx} already trained — loading..."
            load_model_for_stage(stage_idx)
            continue

        yield f"▶️ AUTO-TRAINING FROM SCRATCH: Stage {stage_idx}"
        try:
            train_vlm_stage(stage_idx, stage_dir)
            yield f"✅ Stage {stage_idx} completed! Loading new model..."
            load_model_for_stage(stage_idx)
        except Exception as err:
            # Show the failure in the UI first, then stop the whole run.
            yield f"❌ Stage {stage_idx} failed: {err}"
            raise
    yield "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"

with gr.Blocks(title="🥥 COCONUT-VLM From Scratch") as demo:
    gr.Markdown("# 🥥 COCONUT-VLM: Autonomous Trainer (From Scratch)")
    gr.Markdown("Model is training itself **from random initialization**. You can interact with the latest trained model.")

    with gr.Row():
        # Left column: progress messages streamed by the trainer generator.
        with gr.Column(scale=1):
            status = gr.Textbox(
                label="Training Status",
                value="Waiting to start...",
                interactive=False,
                lines=10,
            )
        # Right column: image upload and chat with whichever stage is loaded.
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image")
            # NOTE(review): the chat handler produces {"role", "content"} dicts.
            # On Gradio versions where Chatbot defaults to the tuples format,
            # this needs type="messages"; confirm against the installed version.
            chatbot = gr.Chatbot(label="Chat with the VLM", height=400)
            msg = gr.Textbox(label="Ask a question")
            clear = gr.Button("Clear Chat")

    def _reset_chat():
        """Blank the image, the question box and the conversation."""
        return None, None, []

    msg.submit(fn=chat_with_image, inputs=[image_input, msg, chatbot], outputs=[msg, chatbot])
    clear.click(fn=_reset_chat, inputs=None, outputs=[image_input, msg, chatbot])
    # Start the autonomous training run as soon as the page loads.
    demo.load(fn=run_autonomous_training_and_update_ui, inputs=None, outputs=status)

demo.queue().launch(debug=True)