Keeby-smilyai committed
Commit 1414591 · verified · 1 Parent(s): 7ac9339

Update app.py

Files changed (1)
  1. app.py +67 -68
app.py CHANGED
@@ -1,9 +1,8 @@
-# app.py — FIXED: Gradio 4.x compatible, no deprecation warnings, auto-trains stages
+# app.py — REFACTORED with a clean, custom Python loop using 'yield'
 import gradio as gr
-import threading
 import os
 import time
-from train_vlm import train_vlm_stage
+from train_vlm import train_vlm_stage # Assuming this file exists and works
 from transformers import LlavaForConditionalGeneration, AutoProcessor
 import torch
 
@@ -12,22 +11,29 @@ MODEL_NAME = "bczhou/TinyLLaVA-3.1B" # or "llava-hf/llava-1.5-7b-hf"
 CHECKPOINT_ROOT = "./checkpoints"
 os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
 
-# --- Global state ---
-current_stage = 1
+# --- Global state for the model (needed for the chat function) ---
+current_stage = 0
 model = None
 processor = None
 device = "cuda" if torch.cuda.is_available() else "cpu"
-training_status = "🚀 Initializing COCONUT-VLM Autonomous Trainer..."
 
 print(f"🖥️ Running on device: {device}")
 if device == "cuda":
     print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
 
 def load_model_for_stage(stage):
-    global model, processor
+    """Loads the appropriate model and processor for a given stage."""
+    global model, processor, current_stage
+
+    # Update the global stage so the chat function knows which model to use
+    current_stage = stage
+
     ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
     if os.path.exists(ckpt_path) and os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
         print(f"✅ Loading checkpoint: Stage {stage}")
+        # Free up VRAM before loading the next model
+        del model
+        torch.cuda.empty_cache()
         model = LlavaForConditionalGeneration.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
         processor = AutoProcessor.from_pretrained(ckpt_path)
     else:
@@ -36,21 +42,18 @@ def load_model_for_stage(stage):
         processor = AutoProcessor.from_pretrained(MODEL_NAME)
 
 def chat_with_image(image, text, chat_history):
+    """Handles the user's chat interaction."""
     if model is None or processor is None:
-        load_model_for_stage(current_stage)
+        return "", chat_history + [{"role": "assistant", "content": "Model is not loaded yet. Please wait for training to start."}]
 
     try:
-        # Format input for model
-        conversation = [
-            {"role": "user", "content": f"<image>\n{text}"},
-        ]
+        conversation = [{"role": "user", "content": f"<image>\n{text}"}]
         prompt = processor.apply_chat_template(conversation, tokenize=False)
 
         inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
         output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
         response = processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
-        # Append as OpenAI-style messages (fixes deprecation warning)
         chat_history.append({"role": "user", "content": text})
         chat_history.append({"role": "assistant", "content": response})
         return "", chat_history
@@ -59,99 +62,95 @@ def chat_with_image(image, text, chat_history):
         chat_history.append({"role": "assistant", "content": f"⚠️ Error: {str(e)}"})
         return "", chat_history
 
-# --- Autonomous Training Pipeline ---
-def auto_train_pipeline():
-    global current_stage, training_status
+# --- The Custom Loop: Autonomous Training Pipeline ---
+# This single function runs the entire loop and 'yields' updates to the UI.
+def run_autonomous_training_and_update_ui():
+    """
+    This is a generator function that runs the entire training pipeline.
+    It yields status messages that are displayed directly in the Gradio UI.
+    """
+    yield "🚀 Initializing COCONUT-VLM Autonomous Trainer..."
 
     for stage in [1, 2, 3]:
-        current_stage = stage
-        training_status = f"▶️ AUTO-TRAINING STARTED: Stage {stage}"
-        print(training_status)
-
         ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
-        # Skip if already trained
+
+        # 1. Check if stage is already trained
         if os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
-            training_status = f"⏭️ Stage {stage} already trained — loading..."
-            print(training_status)
+            status_message = f"⏭️ Stage {stage} already trained — loading..."
+            print(status_message)
+            yield status_message
             load_model_for_stage(stage)
-            time.sleep(3)
+            time.sleep(2) # Give user time to read the message
             continue
 
+        # 2. Start training for the current stage
+        status_message = f"▶️ AUTO-TRAINING STARTED: Stage {stage}"
+        print(status_message)
+        yield status_message
+
         try:
+            # This is the long-running training task
             train_vlm_stage(stage, MODEL_NAME, ckpt_path)
 
-            training_status = f"✅ Stage {stage} completed! Loading model..."
-            print(training_status)
+            # 3. Handle successful training
+            status_message = f"✅ Stage {stage} completed! Loading new model..."
+            print(status_message)
+            yield status_message
            load_model_for_stage(stage)
 
             if stage < 3:
-                training_status = f"⏳ Advancing to Stage {stage + 1} in 5 seconds..."
-                print(training_status)
+                status_message = f"⏳ Advancing to Stage {stage + 1} in 5 seconds..."
+                print(status_message)
+                yield status_message
                 time.sleep(5)
 
         except Exception as e:
-            training_status = f"❌ Stage {stage} failed: {str(e)}"
-            print(training_status)
-            break
-
-    training_status = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
-    print(training_status)
+            # 4. Handle training failure
+            status_message = f"❌ Stage {stage} failed: {str(e)}"
+            print(status_message)
+            yield status_message
+            break # Stop the entire pipeline if a stage fails
 
-# --- Launch training on app start ---
-def initialize_autonomous_trainer():
-    training_thread = threading.Thread(target=auto_train_pipeline, daemon=True)
-    training_thread.start()
-
-    # Start the status update process
-    return training_status
+    # 5. Final completion message
+    final_message = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
+    print(final_message)
+    yield final_message
 
-# --- Status update function ---
-def update_status():
-    # Return the current status and trigger the next update
-    time.sleep(0.5) # Small delay to prevent CPU overload
-    return training_status, gr.update(autoscroll=True) # Also autoscroll chat window
 
-# --- Gradio UI (Chat-Only) ---
+# --- Gradio UI ---
 with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
     gr.Markdown("# 🥥 COCONUT-VLM: Autonomous Vision-Language Trainer")
     gr.Markdown("Model is training itself in 3 stages automatically. **You can only chat.** Training is backend-only.")
-
-    # We'll create a hidden component to trigger status updates
-    hidden_dummy = gr.State()
 
     with gr.Row():
         with gr.Column(scale=1):
             status = gr.Textbox(
                 label="Training Status",
-                value="Initializing...",
+                value="Waiting to start...",
                 interactive=False,
-                show_label=False
+                show_label=False,
+                lines=3 # Give it a bit more space
             )
-            gr.Markdown("💡 _Training runs automatically. No buttons. No switching._")
+            gr.Markdown("💡 _Training runs automatically on page load. No buttons needed._")
 
         with gr.Column(scale=2):
             image_input = gr.Image(type="pil", label="Upload Image")
-
-            # ✅ FIXED: Set type="messages" to avoid deprecation warning
             chatbot = gr.Chatbot(height=400, type="messages")
-
            msg = gr.Textbox(label="Ask a question about the image")
             clear = gr.Button("Clear Chat")
 
-    # Chat logic - FIXED: Changed output to chatbot
+    # Chat logic remains the same
     msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
-    clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)
+    clear.click(lambda: [], inputs=None, outputs=chatbot)
 
-    # ✅ FIXED: Combined initialization and status updates
+    # --- THE MAGIC ---
+    # On page load, run our generator function. Gradio will automatically
+    # update the 'status' textbox every time the function 'yields' a new value.
+    # This is clean, efficient, and avoids all threading/polling headaches.
    demo.load(
-        fn=initialize_autonomous_trainer,
-        inputs=None,
-        outputs=status,
-        then=[
-            {"fn": update_status, "outputs": [status, chatbot]},
-            {"fn": update_status, "outputs": [status, chatbot]},
-        ],
-        every=1.5
+        fn=run_autonomous_training_and_update_ui,
+        inputs=None,
+        outputs=status
    )
 
-demo.queue(max_size=20).launch()
+demo.queue().launch()
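
For reference, the yield-driven update pattern that the new demo.load call relies on can be exercised in isolation. Below is a minimal, self-contained sketch, assuming a Gradio version with queuing enabled (generator handlers stream each yielded value to their output component); the fake_pipeline generator and its time.sleep calls are hypothetical stand-ins for the real train_vlm_stage work.

import time
import gradio as gr

def fake_pipeline():
    """Hypothetical stand-in for run_autonomous_training_and_update_ui()."""
    yield "🚀 Starting..."
    for stage in [1, 2, 3]:
        time.sleep(1) # stands in for a long-running training stage
        yield f"✅ Stage {stage} done"
    yield "🎉 All stages finished"

with gr.Blocks() as demo:
    status = gr.Textbox(label="Training Status", interactive=False)
    # Each yield replaces the textbox value; no threads or polling needed.
    demo.load(fn=fake_pipeline, inputs=None, outputs=status)

demo.queue().launch()

Because the generator runs inside Gradio's own event loop, the status box updates as each stage finishes, which is what lets the refactor drop the threading import and the every=1.5 polling of the previous version.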