Keeby-smilyai committed
Commit 7be6eb8 · verified · Parent: 1414591

Update app.py

Files changed (1): app.py (+32, -33)
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py — REFACTORED with a clean, custom Python loop using 'yield'
+# app.py — FIXED: Handles remote code trust and logical error on failure
 import gradio as gr
 import os
 import time
@@ -7,11 +7,11 @@ from transformers import LlavaForConditionalGeneration, AutoProcessor
 import torch
 
 # --- Config ---
-MODEL_NAME = "bczhou/TinyLLaVA-3.1B" # or "llava-hf/llava-1.5-7b-hf"
+MODEL_NAME = "bczhou/TinyLLaVA-3.1B"
 CHECKPOINT_ROOT = "./checkpoints"
 os.makedirs(CHECKPOINT_ROOT, exist_ok=True)
 
-# --- Global state for the model (needed for the chat function) ---
+# --- Global state for the model ---
 current_stage = 0
 model = None
 processor = None
@@ -25,21 +25,28 @@ def load_model_for_stage(stage):
     """Loads the appropriate model and processor for a given stage."""
     global model, processor, current_stage
 
-    # Update the global stage so the chat function knows which model to use
     current_stage = stage
 
     ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
+    # ✅ FIX 1: Added trust_remote_code=True to all .from_pretrained calls
     if os.path.exists(ckpt_path) and os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
         print(f"✅ Loading checkpoint: Stage {stage}")
-        # Free up VRAM before loading the next model
         del model
         torch.cuda.empty_cache()
-        model = LlavaForConditionalGeneration.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
-        processor = AutoProcessor.from_pretrained(ckpt_path)
+        model = LlavaForConditionalGeneration.from_pretrained(
+            ckpt_path,
+            torch_dtype=torch.float16,
+            trust_remote_code=True
+        ).to(device)
+        processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
     else:
         print(f"⚠️ No checkpoint for Stage {stage} — loading base model")
-        model = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch.float16).to(device)
-        processor = AutoProcessor.from_pretrained(MODEL_NAME)
+        model = LlavaForConditionalGeneration.from_pretrained(
+            MODEL_NAME,
+            torch_dtype=torch.float16,
+            trust_remote_code=True
+        ).to(device)
+        processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
 def chat_with_image(image, text, chat_history):
     """Handles the user's chat interaction."""
@@ -62,37 +69,34 @@ def chat_with_image(image, text, chat_history):
         chat_history.append({"role": "assistant", "content": f"⚠️ Error: {str(e)}"})
         return "", chat_history
 
-# --- The Custom Loop: Autonomous Training Pipeline ---
-# This single function runs the entire loop and 'yields' updates to the UI.
 def run_autonomous_training_and_update_ui():
     """
-    This is a generator function that runs the entire training pipeline.
-    It yields status messages that are displayed directly in the Gradio UI.
+    Generator function that runs the training pipeline and yields status updates.
     """
     yield "🚀 Initializing COCONUT-VLM Autonomous Trainer..."
+
+    # ✅ FIX 2: Added a flag to track if training failed
+    training_failed = False
 
     for stage in [1, 2, 3]:
         ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
 
-        # 1. Check if stage is already trained
         if os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
             status_message = f"⏭️ Stage {stage} already trained — loading..."
             print(status_message)
             yield status_message
             load_model_for_stage(stage)
-            time.sleep(2) # Give user time to read the message
+            time.sleep(2)
             continue
 
-        # 2. Start training for the current stage
         status_message = f"▶️ AUTO-TRAINING STARTED: Stage {stage}"
         print(status_message)
        yield status_message
 
         try:
-            # This is the long-running training task
+            # IMPORTANT: Make sure train_vlm_stage also uses trust_remote_code=True
             train_vlm_stage(stage, MODEL_NAME, ckpt_path)
 
-            # 3. Handle successful training
             status_message = f"✅ Stage {stage} completed! Loading new model..."
             print(status_message)
             yield status_message
@@ -105,19 +109,19 @@ def run_autonomous_training_and_update_ui():
             time.sleep(5)
 
         except Exception as e:
-            # 4. Handle training failure
-            status_message = f"❌ Stage {stage} failed: {str(e)}"
+            status_message = f"❌ Stage {stage} failed: {e}"
             print(status_message)
             yield status_message
-            break # Stop the entire pipeline if a stage fails
-
-    # 5. Final completion message
-    final_message = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
-    print(final_message)
-    yield final_message
+            training_failed = True # Set the flag to True on failure
+            break # Stop the entire pipeline
 
+    # ✅ FIX 2: Only show the completion message if the loop finished without failing
+    if not training_failed:
+        final_message = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
+        print(final_message)
+        yield final_message
 
-# --- Gradio UI ---
+# --- Gradio UI (No changes needed here) ---
 with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
     gr.Markdown("# 🥥 COCONUT-VLM: Autonomous Vision-Language Trainer")
     gr.Markdown("Model is training itself in 3 stages automatically. **You can only chat.** Training is backend-only.")
@@ -129,7 +133,7 @@ with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
         value="Waiting to start...",
         interactive=False,
         show_label=False,
-        lines=3 # Give it a bit more space
+        lines=3
     )
     gr.Markdown("💡 _Training runs automatically on page load. No buttons needed._")
 
@@ -139,14 +143,9 @@ with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
     msg = gr.Textbox(label="Ask a question about the image")
     clear = gr.Button("Clear Chat")
 
-    # Chat logic remains the same
     msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
     clear.click(lambda: [], inputs=None, outputs=chatbot)
 
-    # --- THE MAGIC ---
-    # On page load, run our generator function. Gradio will automatically
-    # update the 'status' textbox every time the function 'yields' a new value.
-    # This is clean, efficient, and avoids all threading/polling headaches.
     demo.load(
         fn=run_autonomous_training_and_update_ui,
         inputs=None,
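The final hunk is cut off inside the demo.load(...) call; the remaining arguments bind the generator to the status textbox. Gradio treats a generator event handler as a streaming function: each yielded value is pushed into the bound output component, which is what makes the status line update live without threads or polling. A minimal, self-contained sketch of that pattern (component and function names here are illustrative, not from this commit):

import time
import gradio as gr

def fake_pipeline():
    # Each yield is streamed straight into the bound output component.
    for msg in ["🚀 starting...", "▶️ stage 1...", "🎉 done"]:
        yield msg
        time.sleep(1)

with gr.Blocks() as demo:
    status = gr.Textbox(value="Waiting to start...", interactive=False, lines=3)
    # On page load, Gradio consumes the generator and updates 'status' once per yield.
    demo.load(fn=fake_pipeline, inputs=None, outputs=status)

if __name__ == "__main__":
    demo.launch()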
 