import logging
import os
import tempfile
import uuid

import gradio as gr
import torch
from PIL import Image
from huggingface_hub import upload_file
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model configuration
MID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200
# Textual marker placed in the chat message; the rendered prompt is split on
# it so that IMAGE_TOKEN_INDEX can be spliced in at that exact position.
IMAGE_PLACEHOLDER = "<image>"
DEFAULT_PROMPT = "Describe this image in detail."

# Your HF model repo where you want to upload results
HF_MODEL = "rahul7star/VideoExplain"  # change if needed

# Load model and tokenizer (lazy load)
tok = None
model = None


def load_model():
    """Return the ``(tokenizer, model)`` pair, loading both on first use.

    The pair is cached in the module-level ``tok`` / ``model`` globals so that
    later calls are cheap. CUDA with float16 is used when available,
    otherwise CPU with float32.
    """
    global tok, model
    if tok is None or model is None:
        print("Loading model...")
        tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True)
        if torch.cuda.is_available():
            device = "cuda"
            dtype = torch.float16
        else:
            device = "cpu"
            dtype = torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            MID,
            torch_dtype=dtype,
            device_map=device,
            trust_remote_code=True,
        )
        print(f"Model loaded on {device.upper()} successfully!")
    return tok, model


def upload_to_hf(image_path, summary_text):
    """Upload an image and its summary text to the Hugging Face model repo.

    Both files are placed in a freshly named ``image_<8 hex chars>`` folder
    inside ``HF_MODEL``. The access token is read from the
    ``HUGGINGFACE_HUB_TOKEN`` environment variable.

    Args:
        image_path: Local path of the image file to upload; its base name is
            reused as the file name inside the repo folder.
        summary_text: Text stored as ``summary.txt`` next to the image.

    Returns:
        A human-readable status string naming the created folder.

    Raises:
        Whatever ``huggingface_hub.upload_file`` raises on failure.
    """
    unique_folder = f"image_{uuid.uuid4().hex[:8]}"
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    logging.info(f"Creating new HF folder: {unique_folder} in repo {HF_MODEL}")

    # Upload image
    img_hf_path = f"{unique_folder}/{os.path.basename(image_path)}"
    upload_file(
        path_or_fileobj=image_path,
        path_in_repo=img_hf_path,
        repo_id=HF_MODEL,
        repo_type="model",
        token=token,
    )
    logging.info(f"✅ Uploaded image to HF: {img_hf_path}")

    # Upload summary text. The bytes are handed to upload_file directly
    # instead of going through a fixed /tmp/summary.txt, which concurrent
    # requests could overwrite before the upload happened.
    summary_hf_path = f"{unique_folder}/summary.txt"
    upload_file(
        path_or_fileobj=summary_text.encode("utf-8"),
        path_in_repo=summary_hf_path,
        repo_id=HF_MODEL,
        repo_type="model",
        token=token,
    )
    logging.info(f"✅ Uploaded summary to HF: {summary_hf_path}")

    return f"Uploaded to Hugging Face under {unique_folder}"


def _generate_caption(image, custom_prompt):
    """Run the vision-language model on ``image`` and return the caption text."""
    tok, model = load_model()

    if image.mode != "RGB":
        image = image.convert("RGB")

    # Treat an empty or whitespace-only prompt as "no prompt given".
    prompt = (custom_prompt or "").strip() or DEFAULT_PROMPT

    messages = [{"role": "user", "content": f"{IMAGE_PLACEHOLDER}\n{prompt}"}]
    rendered = tok.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )

    # Split the rendered prompt around the placeholder so the special image
    # token id can be inserted between the two tokenized halves.
    pre, marker, post = rendered.partition(IMAGE_PLACEHOLDER)
    if not marker:
        raise ValueError(
            f"Chat template output does not contain {IMAGE_PLACEHOLDER!r}"
        )

    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)

    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)

    px = model.get_vision_tower().image_processor(
        images=image, return_tensors="pt"
    )["pixel_values"]
    px = px.to(model.device, dtype=model.dtype)

    with torch.no_grad():
        out = model.generate(
            inputs=input_ids,
            attention_mask=attention_mask,
            images=px,
            max_new_tokens=128,
            do_sample=False,
        )

    generated_text = tok.decode(out[0], skip_special_tokens=True)
    # If the decoded text still carries the chat transcript, keep only what
    # follows the last "assistant" marker.
    if "assistant" in generated_text:
        return generated_text.split("assistant")[-1].strip()
    return generated_text


def caption_image(image, custom_prompt=None):
    """Generate a caption for ``image`` and upload image + caption to HF.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``) or
            ``None`` when nothing was uploaded.
        custom_prompt: Optional instruction for the model; when empty or
            whitespace-only, a default "describe the image" prompt is used.

    Returns:
        The caption followed by an upload status line, or a message
        describing what went wrong. This function never raises; failures are
        logged and reported in the returned string so the UI can show them.
    """
    if image is None:
        return "Please upload an image first."

    try:
        response = _generate_caption(image, custom_prompt)
    except Exception as e:  # UI boundary: report instead of crashing
        logging.exception("Caption generation failed")
        return f"Error generating caption: {str(e)}"

    # Upload image + caption to HF repo. A failed upload must not discard a
    # caption that was generated successfully, so it is handled separately.
    try:
        # A per-request temporary directory avoids clashes between concurrent
        # requests and is removed automatically afterwards. The file name is
        # fixed because it becomes the file name inside the HF repo folder.
        with tempfile.TemporaryDirectory() as tmp_dir:
            temp_img = os.path.join(tmp_dir, "uploaded_image.png")
            image.save(temp_img)
            upload_status = upload_to_hf(temp_img, response)
    except Exception as e:  # UI boundary: report instead of crashing
        logging.exception("Upload to Hugging Face failed")
        upload_status = f"Upload failed: {e}"

    return f"{response}\n\n---\n{upload_status}"


# Gradio UI
with gr.Blocks(title="FastVLM Image Captioning") as demo:
    gr.Markdown("# 🖼️ FastVLM Image Captioning")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            custom_prompt = gr.Textbox(
                label="Custom Prompt (Optional)",
                placeholder="Leave empty for default prompt",
                lines=2,
            )
            generate_btn = gr.Button("Generate + Upload", variant="primary")
            clear_btn = gr.ClearButton([image_input, custom_prompt])

        with gr.Column():
            output = gr.Textbox(
                label="Generated Caption + Upload Status",
                lines=8,
                show_copy_button=True,
            )

    generate_btn.click(caption_image, [image_input, custom_prompt], output)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)