Keeby-smilyai committed
Commit b798184 · verified · 1 Parent(s): e5aed01

Update app.py

Files changed (1):
  1. app.py +16 -23
app.py CHANGED
@@ -1,8 +1,8 @@
-# app.py — FIXED: Handles remote code trust and logical error on failure
+# app.py
 import gradio as gr
 import os
 import time
-from train_vlm import train_vlm_stage # Assuming this file exists and works
+from train_vlm import train_vlm_stage
 from transformers import LlavaForConditionalGeneration, AutoProcessor
 import torch
 
@@ -28,14 +28,15 @@ def load_model_for_stage(stage):
     current_stage = stage
 
     ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
-    # ✅ FIX 1: Added trust_remote_code=True to all .from_pretrained calls
     if os.path.exists(ckpt_path) and os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
         print(f"✅ Loading checkpoint: Stage {stage}")
+        # Free up memory before loading the next model
         del model
-        torch.cuda.empty_cache()
+        if device == "cuda":
+            torch.cuda.empty_cache()
         model = LlavaForConditionalGeneration.from_pretrained(
             ckpt_path,
-            torch_dtype=torch.float16,
+            torch_dtype=torch.float16 if device == "cuda" else torch.bfloat16,
             trust_remote_code=True
         ).to(device)
         processor = AutoProcessor.from_pretrained(ckpt_path, trust_remote_code=True)
@@ -43,7 +44,7 @@ def load_model_for_stage(stage):
         print(f"⚠️ No checkpoint for Stage {stage} — loading base model")
         model = LlavaForConditionalGeneration.from_pretrained(
             MODEL_NAME,
-            torch_dtype=torch.float16,
+            torch_dtype=torch.float16 if device == "cuda" else torch.bfloat16,
             trust_remote_code=True
         ).to(device)
         processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
@@ -51,11 +52,11 @@ def load_model_for_stage(stage):
 def chat_with_image(image, text, chat_history):
     """Handles the user's chat interaction."""
     if model is None or processor is None:
-        return "", chat_history.append({"role": "assistant", "content": "Model is not loaded yet. Please wait for training to start."})
+        return "", chat_history + [{"role": "assistant", "content": "Model is not loaded yet. Please wait for training to start."}]
 
     try:
         conversation = [{"role": "user", "content": f"<image>\n{text}"}]
-        prompt = processor.apply_chat_template(conversation, tokenize=False)
+        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
 
         inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
         output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
@@ -70,12 +71,9 @@ def chat_with_image(image, text, chat_history):
     return "", chat_history
 
 def run_autonomous_training_and_update_ui():
-    """
-    Generator function that runs the training pipeline and yields status updates.
-    """
+    """Generator function that runs the training pipeline and yields status updates."""
     yield "🚀 Initializing COCONUT-VLM Autonomous Trainer..."
 
-    # ✅ FIX 2: Added a flag to track if training failed
     training_failed = False
 
     for stage in [1, 2, 3]:
@@ -94,7 +92,6 @@ def run_autonomous_training_and_update_ui():
         yield status_message
 
         try:
-            # IMPORTANT: Make sure train_vlm_stage also uses trust_remote_code=True
            train_vlm_stage(stage, MODEL_NAME, ckpt_path)
 
            status_message = f"✅ Stage {stage} completed! Loading new model..."
@@ -112,16 +109,15 @@ def run_autonomous_training_and_update_ui():
             status_message = f"❌ Stage {stage} failed: {e}"
             print(status_message)
             yield status_message
-            training_failed = True # Set the flag to True on failure
-            break # Stop the entire pipeline
+            training_failed = True
+            break
 
-    # ✅ FIX 2: Only show the completion message if the loop finished without failing
     if not training_failed:
         final_message = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
         print(final_message)
         yield final_message
 
-# --- Gradio UI (No changes needed here) ---
+# --- Gradio UI ---
 with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
     gr.Markdown("# 🥥 COCONUT-VLM: Autonomous Vision-Language Trainer")
     gr.Markdown("Model is training itself in 3 stages automatically. **You can only chat.** Training is backend-only.")
@@ -129,13 +125,10 @@ with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
     with gr.Row():
         with gr.Column(scale=1):
             status = gr.Textbox(
-                label="Training Status",
-                value="Waiting to start...",
-                interactive=False,
-                show_label=False,
-                lines=3
+                label="Training Status", value="Waiting to start...", interactive=False,
+                show_label=False, lines=3, max_lines=5
             )
-            gr.Markdown("💡 _Training runs automatically on page load. No buttons needed._")
+            gr.Markdown("💡 _Training runs automatically on page load._")
 
         with gr.Column(scale=2):
             image_input = gr.Image(type="pil", label="Upload Image")
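
A note on the chat_with_image fix above: list.append mutates the list in place and returns None, so the old early return handed None back to Gradio as the chat history. Concatenation builds and returns a new list instead. A minimal standalone sketch (hypothetical messages, not the app's actual state):

    history = [{"role": "user", "content": "hi"}]

    # Old pattern: append returns None, so the caller received None as the history.
    result = history.append({"role": "assistant", "content": "Model is not loaded yet."})
    print(result)  # None

    # New pattern: concatenation returns a fresh list containing the new message.
    history = [{"role": "user", "content": "hi"}]
    result = history + [{"role": "assistant", "content": "Model is not loaded yet."}]
    print(result)  # [{'role': 'user', ...}, {'role': 'assistant', ...}]

The related apply_chat_template change passes add_generation_prompt=True, which appends the assistant turn marker to the templated prompt so the model generates a reply rather than continuing the user's message.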
 
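
The dtype and cache changes make model loading safe on CPU-only hosts, where float16 kernels are often missing or slow. A minimal sketch of the pattern, assuming device is derived from torch.cuda.is_available() (its definition is outside these hunks) and using a hypothetical model id in place of MODEL_NAME:

    import torch
    from transformers import LlavaForConditionalGeneration, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # float16 is the usual GPU choice; bfloat16 degrades more gracefully on CPU.
    dtype = torch.float16 if device == "cuda" else torch.bfloat16

    model = LlavaForConditionalGeneration.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",  # hypothetical model id for illustration
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)
    processor = AutoProcessor.from_pretrained(
        "llava-hf/llava-1.5-7b-hf",  # hypothetical model id for illustration
        trust_remote_code=True,
    )

    # Clearing the CUDA allocator cache is only meaningful on a CUDA device,
    # hence the guard added in the diff.
    if device == "cuda":
        torch.cuda.empty_cache()

Guarding empty_cache() also documents intent: freeing GPU memory between stage reloads matters only when a GPU is actually in use.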