Keeby-smilyai committed
Commit f746dfb · verified · 1 Parent(s): 622c2dd

Create app.py

Files changed (1)
app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
+ # app.py
+ import gradio as gr
+ import threading
+ import os
+ from train_vlm import train_vlm_stage
+ from transformers import LlavaForConditionalGeneration, AutoProcessor
+ import torch
+
+ # --- Config ---
+ MODEL_NAME = "bczhou/TinyLLaVA-3.1B"  # or "" for faster training
+ HF_USERNAME = "Smilyai-labs-research"
+ YOUR_SPACE_REPO = "Smilyai-labs-research/VISION-LLM-COT"
+ CHECKPOINT_ROOT = "./checkpoints"
+
+ os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
+
+ # --- Global state ---
+ current_stage = 1
+ model = None
+ processor = None
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
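+ # The three stages mirror the UI buttons below: Stage 1 = plain CoT,
+ # Stage 2 = masked thought, Stage 3 = COCONUT mode; each stage's
+ # checkpoint is written to CHECKPOINT_ROOT/stage_<n>.
+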
+ def load_model_for_stage(stage):
+     """Load the checkpoint for the given stage, falling back to the base model."""
+     global model, processor
+     ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
+     # float16 is only safe on GPU; fall back to float32 on CPU
+     dtype = torch.float16 if device == "cuda" else torch.float32
+     if os.path.exists(ckpt_path):
+         print(f"Loading checkpoint from {ckpt_path}")
+         model = LlavaForConditionalGeneration.from_pretrained(ckpt_path, torch_dtype=dtype).to(device)
+         processor = AutoProcessor.from_pretrained(ckpt_path)
+     else:
+         print(f"No checkpoint for stage {stage}, loading base model")
+         model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=dtype).to(device)
+         processor = AutoProcessor.from_pretrained(MODEL_NAME)
+
+ def chat_with_image(image, text, chat_history):
+     if model is None or processor is None:
+         load_model_for_stage(current_stage)
+     if image is None:
+         chat_history.append((text, "Please upload an image first."))
+         return "", chat_history
+
+     conversation = [
+         {"role": "user", "content": f"<image>\n{text}"},
+     ]
+     # add_generation_prompt=True opens the assistant turn in the template
+     prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+
+     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+     with torch.no_grad():  # inference only; don't track gradients
+         output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
+     # decode only the newly generated tokens, not the prompt
+     response = processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+
+     chat_history.append((text, response))
+     return "", chat_history
+
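+ # train_vlm_stage runs in a background thread of this same process, so on
+ # a single GPU it competes with chat inference for memory.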
+ def start_training(stage):
+     global current_stage
+     current_stage = stage
+     thread = threading.Thread(
+         target=train_vlm_stage,
+         args=(stage, MODEL_NAME, f"{CHECKPOINT_ROOT}/stage_{stage}"),
+     )
+     thread.start()
+     return f"▶️ Training started for Stage {stage}. Check logs."
+
+ def switch_stage(stage):
+     global current_stage
+     current_stage = stage
+     load_model_for_stage(stage)
+     return f"✅ Switched to Stage {stage}. Model reloaded."
+
+ # --- Gradio UI ---
+ with gr.Blocks(title="🥥 VLM COCONUT Trainer") as demo:
+     gr.Markdown("# 🥥 Vision-Language COCONUT CoT Trainer (Real Training!)")
+     gr.Markdown("Train a VLM in 3 stages. Chat with the latest stage.")
+
+     with gr.Row():
+         with gr.Column():
+             stage_btn1 = gr.Button("Stage 1: Plain CoT", variant="primary")
+             stage_btn2 = gr.Button("Stage 2: Masked Thought")
+             stage_btn3 = gr.Button("Stage 3: COCONUT Mode")
+             status = gr.Textbox(label="Status", interactive=False)
+
+         with gr.Column():
+             image_input = gr.Image(type="pil", label="Upload Image")
+             chatbot = gr.Chatbot(height=400)
+             msg = gr.Textbox(label="Ask a question about the image")
+             clear = gr.Button("Clear Chat")
+
+     # Event bindings
+     stage_btn1.click(lambda: switch_stage(1), None, status)
+     stage_btn2.click(lambda: switch_stage(2), None, status)
+     stage_btn3.click(lambda: switch_stage(3), None, status)
+
+     msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+     gr.Markdown("## ⚙️ Start Training (Uses your GPU Grant!)")
+     train_btn1 = gr.Button("▶️ Train Stage 1")
+     train_btn2 = gr.Button("▶️ Train Stage 2")
+     train_btn3 = gr.Button("▶️ Train Stage 3")
+
+     train_btn1.click(lambda: start_training(1), None, status)
+     train_btn2.click(lambda: start_training(2), None, status)
+     train_btn3.click(lambda: start_training(3), None, status)
+
+ demo.queue(max_size=10).launch()
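
app.py imports train_vlm_stage from a train_vlm module that is not part of this commit, so only its call signature can be inferred from the code above: a stage number, a base model name, and an output directory that load_model_for_stage() later reads. A minimal sketch of that contract, with the stage-specific training loop left as a placeholder rather than the repo's actual implementation:

# train_vlm.py (sketch; the real module is not in this commit)
from transformers import LlavaForConditionalGeneration, AutoProcessor
import torch

def train_vlm_stage(stage: int, model_name: str, output_dir: str) -> None:
    """Run one curriculum stage and leave a checkpoint at output_dir."""
    model = LlavaForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained(model_name)
    # ... stage-specific training (plain CoT / masked thought / COCONUT) goes here ...
    model.save_pretrained(output_dir)      # load_model_for_stage() looks for this path
    processor.save_pretrained(output_dir)  # app.py also reloads the processor from here

Saving both the model and the processor matters: load_model_for_stage() calls AutoProcessor.from_pretrained(ckpt_path), so a checkpoint directory missing the processor files would fail to load.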