Keeby-smilyai committed
Commit 51d26e2 · verified · 1 Parent(s): 31126b4

Update app.py

Files changed (1): app.py +31 -16
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py — Fully autonomous 3-stage VLM trainer. UI is chat-only.
+# app.py — FIXED: Gradio 4.x compatible, no deprecation warnings, auto-trains stages
 import gradio as gr
 import threading
 import os
@@ -19,10 +19,14 @@ processor = None
 device = "cuda" if torch.cuda.is_available() else "cpu"
 training_status = "🚀 Initializing COCONUT-VLM Autonomous Trainer..."
 
+print(f"🖥️ Running on device: {device}")
+if device == "cuda":
+    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
+
 def load_model_for_stage(stage):
     global model, processor
     ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
-    if os.path.exists(ckpt_path):
+    if os.path.exists(ckpt_path) and os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
         print(f"✅ Loading checkpoint: Stage {stage}")
         model = LlavaForConditionalGeneration.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
         processor = AutoProcessor.from_pretrained(ckpt_path)
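
Why the extra file check in this hunk: `os.path.exists(ckpt_path)` alone passes as soon as the directory is created, even if training crashed before any weights were written, while requiring `adapter_model.safetensors` confirms the stage actually saved its adapter. A minimal sketch of the same guard with a base-model fallback; `checkpoint_is_complete`, `resolve_model_source`, and the model id are illustrative names, since the hunk truncates the function's `else` branch:

import os

# Hypothetical base-model id; the real MODEL_NAME is defined outside this diff.
MODEL_NAME = "llava-hf/llava-1.5-7b-hf"
CHECKPOINT_ROOT = "checkpoints"

def checkpoint_is_complete(ckpt_path: str) -> bool:
    # A checkpoint directory can exist while training is still writing into it,
    # so require the saved adapter weights before treating it as loadable.
    return os.path.isdir(ckpt_path) and os.path.exists(
        os.path.join(ckpt_path, "adapter_model.safetensors")
    )

def resolve_model_source(stage: int) -> str:
    ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
    # Fall back to the base model when the stage checkpoint is incomplete.
    return ckpt_path if checkpoint_is_complete(ckpt_path) else MODEL_NAME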
@@ -36,6 +40,7 @@ def chat_with_image(image, text, chat_history):
     load_model_for_stage(current_stage)
 
     try:
+        # Format input for model
         conversation = [
             {"role": "user", "content": f"<image>\n{text}"},
         ]
@@ -45,10 +50,13 @@ def chat_with_image(image, text, chat_history):
         output = model.generate(**inputs, max_new_tokens=256, do_sample=True, temperature=0.7)
         response = processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
 
-        chat_history.append((text, response))
+        # Append as OpenAI-style messages (fixes deprecation warning)
+        chat_history.append({"role": "user", "content": text})
+        chat_history.append({"role": "assistant", "content": response})
         return "", chat_history
     except Exception as e:
-        chat_history.append((text, f"⚠️ Error: {str(e)}"))
+        chat_history.append({"role": "user", "content": text})
+        chat_history.append({"role": "assistant", "content": f"⚠️ Error: {str(e)}"})
         return "", chat_history
 
 # --- Autonomous Training Pipeline ---
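
For readers unfamiliar with the tuple-to-dict change above: `gr.Chatbot` deprecated the `(user, bot)` tuple history in favor of OpenAI-style role/content dicts, opted into with `type="messages"` (see the UI hunk further down). A self-contained sketch of the pattern; the echo handler is illustrative, not the app's model call:

import gradio as gr

def respond(message, history):
    # With type="messages", history is a list of {"role": ..., "content": ...}
    # dicts instead of (user, bot) tuples.
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": f"Echo: {message}"})
    return "", history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    box = gr.Textbox()
    box.submit(respond, [box, chatbot], [box, chatbot])

if __name__ == "__main__":
    demo.launch()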
@@ -60,18 +68,22 @@ def auto_train_pipeline():
         training_status = f"▶️ AUTO-TRAINING STARTED: Stage {stage}"
         print(training_status)
 
+        ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
+        # Skip if already trained
+        if os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
+            training_status = f"⏭️ Stage {stage} already trained — loading..."
+            print(training_status)
+            load_model_for_stage(stage)
+            time.sleep(3)
+            continue
+
         try:
-            # Train stage
-            train_vlm_stage(stage, MODEL_NAME, f"{CHECKPOINT_ROOT}/stage_{stage}")
+            train_vlm_stage(stage, MODEL_NAME, ckpt_path)
 
-            # Update status
             training_status = f"✅ Stage {stage} completed! Loading model..."
             print(training_status)
-
-            # Load newly trained model
             load_model_for_stage(stage)
 
-            # Brief pause before next stage
             if stage < 3:
                 training_status = f"⏳ Advancing to Stage {stage + 1} in 5 seconds..."
                 print(training_status)
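
The skip-and-`continue` block above is what makes the pipeline resumable: on a restart, stages whose adapter weights already exist are loaded rather than retrained, and the `break` on failure (next hunk) prevents later stages from training on a missing predecessor. A stripped-down sketch of that control flow; `train_stage` is a stand-in for the app's `train_vlm_stage`:

import os
import time

def run_stages(train_stage, root="checkpoints", stages=(1, 2, 3)):
    # Resumable loop: a stage whose adapter weights exist is skipped, so a
    # crashed or restarted run resumes at the first unfinished stage.
    for stage in stages:
        ckpt_path = f"{root}/stage_{stage}"
        if os.path.exists(os.path.join(ckpt_path, "adapter_model.safetensors")):
            print(f"Stage {stage} already trained, skipping")
            continue
        try:
            train_stage(stage, ckpt_path)
        except Exception as e:
            print(f"Stage {stage} failed: {e}")
            break  # later stages build on earlier ones, so stop the pipeline
        time.sleep(5)  # brief pause between stages, as in the app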
@@ -80,7 +92,7 @@
         except Exception as e:
             training_status = f"❌ Stage {stage} failed: {str(e)}"
             print(training_status)
-            break  # Stop pipeline on failure
+            break
 
     training_status = "🎉 COCONUT-VLM Training Complete — All 3 Stages Finished!"
     print(training_status)
@@ -98,20 +110,23 @@ with gr.Blocks(title="🥥 COCONUT-VLM Autonomous Trainer") as demo:
     with gr.Row():
         with gr.Column(scale=1):
            status = gr.Textbox(label="Training Status", value="Initializing...", interactive=False)
-            gr.Markdown("💡 _Training runs automatically in background. No buttons. No switching._")
+            gr.Markdown("💡 _Training runs automatically. No buttons. No switching._")
 
         with gr.Column(scale=2):
             image_input = gr.Image(type="pil", label="Upload Image")
-            chatbot = gr.Chatbot(height=400)
+
+            # ✅ FIXED: Set type="messages" to avoid deprecation warning
+            chatbot = gr.Chatbot(height=400, type="messages")
+
             msg = gr.Textbox(label="Ask a question about the image")
             clear = gr.Button("Clear Chat")
 
+    # Chat logic
     msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
     clear.click(lambda: None, None, chatbot, queue=False)
 
-    # Initialize autonomous training on launch
+    # ✅ FIXED: Use Gradio 4.x compatible .load() with every=
     demo.load(initialize_autonomous_trainer, inputs=None, outputs=None)
-    # Poll training status every 3 seconds
-    demo.load(lambda: training_status, every=3, outputs=status)
+    demo.load(lambda: training_status, inputs=None, outputs=status, every=3)  # ← Now compatible
 
 demo.queue(max_size=20).launch()
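
On the two `.load()` calls at the end: the first starts the training thread when the page loads, and the second re-runs its lambda every 3 seconds to push the module-level `training_status` into the status textbox, which is how a background thread's progress reaches a button-free UI. A minimal sketch of the pattern, assuming Gradio 4.x where event listeners accept `every=` (seconds) and a queue is enabled; the one-shot guard is an addition here, since `demo.load` fires once per browser session:

import threading
import time

import gradio as gr

status = "starting..."
_started = False  # demo.load fires per session; start the worker only once

def worker():
    global status
    for step in range(1, 4):
        time.sleep(2)  # stand-in for one training stage
        status = f"stage {step}/3 done"
    status = "all stages finished"

def start_once():
    global _started
    if not _started:
        _started = True
        threading.Thread(target=worker, daemon=True).start()

with gr.Blocks() as demo:
    box = gr.Textbox(label="Status", interactive=False)
    demo.load(start_once, inputs=None, outputs=None)
    # Poll the shared status every 3 seconds and push it to the textbox.
    demo.load(lambda: status, inputs=None, outputs=box, every=3)

if __name__ == "__main__":
    demo.queue().launch()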
 