Keeby-smilyai committed on
Commit 31126b4 · verified · 1 Parent(s): 5d7aeb6

Update app.py

Files changed (1)
  app.py (+74 -56)
app.py CHANGED
@@ -1,17 +1,15 @@
-# app.py
+# app.py — Fully autonomous 3-stage VLM trainer. UI is chat-only.
 import gradio as gr
 import threading
 import os
+import time
 from train_vlm import train_vlm_stage
 from transformers import LlavaForConditionalGeneration, AutoProcessor
 import torch

 # --- Config ---
-MODEL_NAME = "bczhou/TinyLLaVA-3.1B"  # or "" for faster training
-HF_USERNAME = "Smilyai-labs-research"
-YOUR_SPACE_REPO = "Smilyai-labs-research/VISION-LLM-COT"
+MODEL_NAME = "bczhou/TinyLLaVA-3.1B"  # or "llava-hf/llava-1.5-7b-hf"
 CHECKPOINT_ROOT = "./checkpoints"
-
 os.makedirs(CHECKPOINT_ROOT, exist_ok=True)

 # --- Global state ---
@@ -19,16 +17,17 @@ current_stage = 1
 model = None
 processor = None
 device = "cuda" if torch.cuda.is_available() else "cpu"
+training_status = "🚀 Initializing COCONUT-VLM Autonomous Trainer..."

 def load_model_for_stage(stage):
     global model, processor
     ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
     if os.path.exists(ckpt_path):
-        print(f"Loading checkpoint from {ckpt_path}")
+        print(f"Loading checkpoint: Stage {stage}")
         model = LlavaForConditionalGeneration.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
         processor = AutoProcessor.from_pretrained(ckpt_path)
     else:
-        print(f"No checkpoint for stage {stage}, loading base model")
+        print(f"⚠️ No checkpoint for Stage {stage}, loading base model")
         model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
         processor = AutoProcessor.from_pretrained(MODEL_NAME)

@@ -36,64 +35,83 @@ def chat_with_image(image, text, chat_history):
     if model is None or processor is None:
         load_model_for_stage(current_stage)

-    conversation = [
-        {"role": "user", "content": f"<image>\n{text}"},
-    ]
-    prompt = processor.apply_chat_template(conversation, tokenize=False)
-
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
-    output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
-    response = processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-
-    chat_history.append((text, response))
-    return "", chat_history
-
-def start_training(stage):
-    global current_stage
-    current_stage = stage
-    thread = threading.Thread(target=train_vlm_stage, args=(stage, MODEL_NAME, f"{CHECKPOINT_ROOT}/stage_{stage}"))
-    thread.start()
-    return f"▶️ Training started for Stage {stage}. Check logs."
-
-def switch_stage(stage):
-    global current_stage
-    current_stage = stage
-    load_model_for_stage(stage)
-    return f"✅ Switched to Stage {stage}. Model reloaded."
-
-# --- Gradio UI ---
-with gr.Blocks(title="🥥 VLM COCONUT Trainer") as demo:
-    gr.Markdown("# 🥥 Vision-Language COCONUT CoT Trainer (Real Training!)")
-    gr.Markdown("Train a VLM in 3 stages. Chat with the latest stage.")
+    try:
+        conversation = [
+            {"role": "user", "content": f"<image>\n{text}"},
+        ]
+        prompt = processor.apply_chat_template(conversation, tokenize=False)
+
+        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+        output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
+        response = processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+
+        chat_history.append((text, response))
+        return "", chat_history
+    except Exception as e:
+        chat_history.append((text, f"⚠️ Error: {str(e)}"))
+        return "", chat_history
+
+# --- Autonomous Training Pipeline ---
+def auto_train_pipeline():
+    global current_stage, training_status
+
+    for stage in [1, 2, 3]:
+        current_stage = stage
+        training_status = f"▶️ AUTO-TRAINING STARTED: Stage {stage}"
+        print(training_status)
+
+        try:
+            # Train this stage
+            train_vlm_stage(stage, MODEL_NAME, f"{CHECKPOINT_ROOT}/stage_{stage}")
+
+            # Update status
+            training_status = f"✅ Stage {stage} completed! Loading model..."
+            print(training_status)
+
+            # Load newly trained model
+            load_model_for_stage(stage)
+
+            # Brief pause before next stage
+            if stage < 3:
+                training_status = f"⏳ Advancing to Stage {stage + 1} in 5 seconds..."
+                print(training_status)
+                time.sleep(5)
+
+        except Exception as e:
+            training_status = f"❌ Stage {stage} failed: {str(e)}"
+            print(training_status)
+            break  # Stop pipeline on failure
+
+    training_status = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
+    print(training_status)
+
+# --- Launch training on app start ---
+def initialize_autonomous_trainer():
+    training_thread = threading.Thread(target=auto_train_pipeline, daemon=True)
+    training_thread.start()
+
+# --- Gradio UI (Chat-Only) ---
+with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
+    gr.Markdown("# 🥥 COCONUT-VLM: Autonomous Vision-Language Trainer")
+    gr.Markdown("The model trains itself in 3 stages automatically. **You can only chat.** Training is backend-only.")

     with gr.Row():
-        with gr.Column():
-            stage_btn1 = gr.Button("Stage 1: Plain CoT", variant="primary")
-            stage_btn2 = gr.Button("Stage 2: Masked Thought")
-            stage_btn3 = gr.Button("Stage 3: COCONUT Mode")
-            status = gr.Textbox(label="Status", interactive=False)
+        with gr.Column(scale=1):
+            status = gr.Textbox(label="Training Status", value="Initializing...", interactive=False)
+            gr.Markdown("💡 _Training runs automatically in the background. No buttons. No switching._")

-        with gr.Column():
+        with gr.Column(scale=2):
             image_input = gr.Image(type="pil", label="Upload Image")
             chatbot = gr.Chatbot(height=400)
             msg = gr.Textbox(label="Ask a question about the image")
             clear = gr.Button("Clear Chat")

-    # Event bindings
-    stage_btn1.click(lambda: switch_stage(1), None, status)
-    stage_btn2.click(lambda: switch_stage(2), None, status)
-    stage_btn3.click(lambda: switch_stage(3), None, status)
-
     msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
     clear.click(lambda: None, None, chatbot, queue=False)

-    gr.Markdown("## ⚙️ Start Training (Uses your GPU Grant!)")
-    train_btn1 = gr.Button("▶️ Train Stage 1")
-    train_btn2 = gr.Button("▶️ Train Stage 2")
-    train_btn3 = gr.Button("▶️ Train Stage 3")
-
-    train_btn1.click(lambda: start_training(1), None, status)
-    train_btn2.click(lambda: start_training(2), None, status)
-    train_btn3.click(lambda: start_training(3), None, status)
+    # Initialize autonomous training on launch
+    demo.load(initialize_autonomous_trainer, inputs=None, outputs=None)
+    # Poll training status every 3 seconds
+    demo.load(lambda: training_status, every=3, outputs=status)

-demo.queue(max_size=10).launch()
+demo.queue(max_size=20).launch()
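A note on the polling wiring: `demo.load(..., every=3)` needs the request queue to be enabled, which `demo.queue(max_size=20)` at the bottom of the file provides; newer Gradio releases recommend `gr.Timer` for periodic refresh, so this line may need adjusting depending on the pinned Gradio version.

The app also imports `train_vlm_stage` from a `train_vlm` module that is not part of this commit. The sketch below is a hypothetical stub, not the repository's actual implementation; the only contract `app.py` relies on is the `train_vlm_stage(stage, base_model, output_dir)` signature and that a loadable checkpoint is written to `output_dir` on success.

# train_vlm.py (hypothetical sketch -- the real module is not shown in this commit)
import os

import torch
from transformers import LlavaForConditionalGeneration, AutoProcessor


def train_vlm_stage(stage: int, base_model: str, output_dir: str) -> None:
    """Train one curriculum stage and save a checkpoint to output_dir."""
    # Load the base weights to fine-tune for this stage.
    model = LlavaForConditionalGeneration.from_pretrained(base_model, torch_dtype=torch.float16)
    processor = AutoProcessor.from_pretrained(base_model)

    # The stage-specific training loop is omitted here; per the old UI labels
    # it would cover stage 1 (plain CoT), stage 2 (masked thought), and
    # stage 3 (COCONUT mode).

    # Save so load_model_for_stage() in app.py can pick the checkpoint up.
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)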