Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import os | |
| import time | |
| from train_vlm import train_vlm_stage | |
| from transformers import AutoImageProcessor, AutoTokenizer | |
| from custom_vlm import CustomScratchVLM | |
| import torch | |
# Root directory for per-stage checkpoints (stage_1, stage_2, ... live below it).
CHECKPOINT_ROOT = "./checkpoints"
os.makedirs(CHECKPOINT_ROOT, exist_ok=True)

# Shared inference state used by the UI callbacks; replaced by load_model_for_stage.
current_stage = 0
model = image_processor = tokenizer = None

# Use the GPU whenever PyTorch reports one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"๐ฅ๏ธ Running on device: {device}")
if device == "cuda":
    print(f"๐ฎ GPU: {torch.cuda.get_device_name(0)}")
def load_model_for_stage(stage):
    """Load the checkpoint for *stage* into the module-level inference globals.

    Looks for ``{CHECKPOINT_ROOT}/stage_{stage}/config.json``. If it exists, the
    previous model is released and the model, image processor and tokenizer are
    loaded from that directory. If not, all three globals are set to ``None``.
    ``current_stage`` is set to *stage* in both cases.

    Args:
        stage: Training stage number used to build the checkpoint path.

    Raises:
        Whatever ``from_pretrained`` raises. In that case the three globals are
        left as ``None`` (never unbound), so callers see "model not loaded".
    """
    import gc

    global model, image_processor, tokenizer, current_stage
    current_stage = stage
    ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"

    if not os.path.exists(os.path.join(ckpt_path, "config.json")):
        print(f"โ ๏ธ No checkpoint for Stage {stage} โ model is not loaded.")
        model, image_processor, tokenizer = None, None, None
        return

    print(f"โ Loading FROM-SCRATCH checkpoint: Stage {stage}")
    # Rebind instead of `del model`. Deleting the global name would leave it
    # unbound if a load below raises, and chat_with_image would then fail with
    # NameError. Clearing all three makes chat requests report "not loaded"
    # while the swap is in progress.
    model, image_processor, tokenizer = None, None, None
    gc.collect()  # break lingering reference cycles so the old weights can be freed
    if device == "cuda":
        torch.cuda.empty_cache()

    # Load into locals and publish together, so the globals never hold a mixed
    # set (e.g. a new model paired with a previous stage's tokenizer).
    new_model = CustomScratchVLM.from_pretrained(ckpt_path).to(device).eval()
    new_processor = AutoImageProcessor.from_pretrained(ckpt_path)
    new_tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
    model, image_processor, tokenizer = new_model, new_processor, new_tokenizer
def chat_with_image(image, text, chat_history):
    """Answer *text* about *image* with the currently loaded VLM.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None``.
        text: The user's question.
        chat_history: Messages-format history (list of ``{"role", "content"}``
            dicts). ``None`` is treated as an empty history.

    Returns:
        A ``("", history)`` tuple. The empty string clears the question box;
        ``history`` is a new list with this turn (or an error/notice message)
        appended.
    """
    history = list(chat_history or [])

    # Snapshot the globals once. The training generator may swap them
    # (load_model_for_stage) while this request runs in another worker, and
    # every step below must use one consistent model/processor/tokenizer set.
    vlm, processor, tok = model, image_processor, tokenizer
    if vlm is None or processor is None or tok is None:
        return "", history + [{"role": "assistant", "content": "Model is not loaded or is currently training. Please wait."}]
    if image is None:
        return "", history + [{"role": "user", "content": text}, {"role": "assistant", "content": "Please upload an image."}]

    try:
        pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
        # For inference, we do not include the <IMAGE> token in the text prompt
        prompt = f"USER: \nQuestion: {text}\nASSISTANT:"
        inputs = tok(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)
        with torch.no_grad():  # inference only; don't build an autograd graph
            output_ids = vlm.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tok.eos_token_id,
            )
        # Decode only the newly generated tokens. Strip the prompt only when the
        # output actually echoes it: HF `generate` driven by `inputs_embeds`
        # returns just the new tokens, and slicing blindly would cut the answer.
        # NOTE(review): confirm which convention CustomScratchVLM.generate follows.
        generated = output_ids[0]
        prompt_len = input_ids.shape[1]
        if generated.shape[0] >= prompt_len and torch.equal(generated[:prompt_len], input_ids[0]):
            generated = generated[prompt_len:]
        response = tok.decode(generated, skip_special_tokens=True)
    except Exception as e:
        return "", history + [{"role": "user", "content": text}, {"role": "assistant", "content": f"โ ๏ธ Error: {e}"}]

    history.append({"role": "user", "content": text})
    history.append({"role": "assistant", "content": response or "[No response generated]"})
    return "", history
def _stage_checkpoint_ready(stage):
    """Return True if ``{CHECKPOINT_ROOT}/stage_{stage}/config.json`` exists."""
    return os.path.exists(os.path.join(f"{CHECKPOINT_ROOT}/stage_{stage}", "config.json"))


def run_autonomous_training_and_update_ui():
    """Train or reuse stages 1-3 in order, yielding status text for the UI.

    For each stage, an existing checkpoint is reused. Otherwise
    ``train_vlm_stage`` is run and its result is loaded, so the UI can chat
    with the newest model while later stages train. A training or loading
    failure is reported, then re-raised, which stops the run.

    Yields:
        Status strings for the status textbox.
    """
    final_stage = 3
    yield "๐ Initializing From-Scratch Trainer..."
    for stage in range(1, final_stage + 1):
        ckpt_path = f"{CHECKPOINT_ROOT}/stage_{stage}"
        if _stage_checkpoint_ready(stage):
            # Skip the load when this stage is already the live model (e.g. a
            # repeat page visit), or when the next stage is also trained and
            # would replace it immediately. Otherwise every run would reload
            # each stage from disk and briefly unload the model serving chats.
            already_loaded = model is not None and current_stage == stage
            superseded = stage < final_stage and _stage_checkpoint_ready(stage + 1)
            if already_loaded or superseded:
                yield f"โญ๏ธ Stage {stage} already trained."
            else:
                yield f"โญ๏ธ Stage {stage} already trained โ loading..."
                load_model_for_stage(stage)
            continue

        yield f"โถ๏ธ AUTO-TRAINING FROM SCRATCH: Stage {stage}"
        try:
            train_vlm_stage(stage, ckpt_path)
        except Exception as e:
            yield f"โ Stage {stage} failed: {e}"
            raise  # Stop execution on failure; bare raise keeps the original traceback

        yield f"โ Stage {stage} completed! Loading new model..."
        try:
            load_model_for_stage(stage)
        except Exception as e:
            # Training itself succeeded; report the load error as such rather
            # than labelling the whole stage as failed.
            yield f"โ Stage {stage} trained but loading failed: {e}"
            raise
    yield "๐ COCONUT-VLM Training Complete โ All 3 Stages Finished!"
# Build the Gradio UI: a training-status panel on the left, image upload and
# chat on the right. Component creation order below defines the page layout.
with gr.Blocks(title="๐ฅฅ COCONUT-VLM From Scratch") as demo:
    gr.Markdown("# ๐ฅฅ COCONUT-VLM: Autonomous Trainer (From Scratch)")
    gr.Markdown("Model is training itself **from random initialization**. You can interact with the latest trained model.")
    with gr.Row():
        # Left column: read-only box that the training generator writes status into.
        with gr.Column(scale=1):
            status = gr.Textbox(label="Training Status", value="Waiting to start...", interactive=False, lines=10)
        # Right column: image input plus the chat widgets.
        with gr.Column(scale=2):
            image_input = gr.Image(type="pil", label="Upload Image")
            # NOTE(review): chat_with_image appends messages-format dicts
            # ({"role": ..., "content": ...}). On Gradio versions where Chatbot
            # defaults to the tuples format, this needs type="messages" —
            # confirm against the installed Gradio version.
            chatbot = gr.Chatbot(label="Chat with the VLM", height=400)
            msg = gr.Textbox(label="Ask a question")
            clear = gr.Button("Clear Chat")
    # Submitting a question runs inference; the outputs clear the question box
    # and replace the chat history.
    msg.submit(chat_with_image, [image_input, msg, chatbot], [msg, chatbot])
    # Reset the image, the question box and the chat history.
    clear.click(lambda: (None, None, []), None, [image_input, msg, chatbot])
    # Start the autonomous training generator when the page loads and stream
    # its yielded status strings into the status box. Gradio fires load
    # events per page visit, so this runs again for each new visitor.
    demo.load(fn=run_autonomous_training_and_update_ui, inputs=None, outputs=status)
# Enable the request queue (used to stream generator output) and start the server.
demo.queue().launch(debug=True)