# Multimodal OCR Space, running on ZeroGPU.
import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
import numpy as np
from PIL import Image
import cv2
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
)
# Constants
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
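
# MAX_INPUT_TOKEN_LENGTH can be overridden through the environment so the prompt budget is
# tunable per deployment; as written it is defined but not yet wired into the processor calls
# below (it would typically be passed as max_length= together with truncation).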

# Load public OCR models
MODEL_ID_V = "nanonets/Nanonets-OCR-s"
processor_v = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_V,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
).to(device).eval()

MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID_X, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device).eval()

MODEL_ID_M = "reducto/RolmOCR"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device).eval()

MODEL_ID_W = "prithivMLmods/Lh41-1042-Magellanic-7B-0711"
processor_w = AutoProcessor.from_pretrained(MODEL_ID_W, trust_remote_code=True)
model_w = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_W, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device).eval()
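
# All four checkpoints are Qwen2-VL / Qwen2.5-VL style vision-language models; each is paired
# with its own AutoProcessor and loaded once at startup in bfloat16 on the selected device, so
# a request only needs to look up the (processor, model) pair matching the radio selection.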

def downsample_video(video_path):
    """Sample 10 evenly spaced frames from the video as (PIL image, timestamp) pairs."""
    vidcap = cv2.VideoCapture(video_path)
    total = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = vidcap.get(cv2.CAP_PROP_FPS) or 30.0  # defensive fallback if the container reports no FPS
    frames = []
    for i in np.linspace(0, total - 1, 10, dtype=int):
        vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ok, img = vidcap.read()
        if ok:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR; PIL expects RGB
            frames.append((Image.fromarray(img), round(i / fps, 2)))
    vidcap.release()
    return frames
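
# Example: downsample_video("videos/1.mp4") returns up to 10 (PIL.Image, timestamp-in-seconds)
# pairs sampled evenly across the clip; the timestamps are interleaved with the frames in the
# video prompt so the model can refer back to specific moments.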

@spaces.GPU
def generate_image(model_name, text, image, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Stream a response for a single uploaded image using the selected model."""
    mapping = {
        "Nanonets-OCR-s": (processor_v, model_v),
        "Qwen2-VL-OCR-2B": (processor_x, model_x),
        "RolmOCR-7B": (processor_m, model_m),
        "Lh41-1042-Magellanic-7B-0711": (processor_w, model_w),
    }
    if model_name not in mapping:
        yield "Invalid model selected.", "Invalid model."
        return
    processor, model = mapping[model_name]
    if image is None:
        yield "Please upload an image.", ""
        return
    msg = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": text}]}]
    prompt = processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # Run generation in a worker thread; the streamer yields tokens as they are produced.
    thread = Thread(target=model.generate, kwargs={
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    })
    thread.start()
    out = ""
    for token in streamer:
        out += token.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield out, out
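
# Both generators share the same streaming pattern: model.generate() runs in a background
# Thread while TextIteratorStreamer yields decoded tokens on the caller's side, and every
# partial string is yielded twice, once for the raw textbox and once for the Markdown view.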

@spaces.GPU
def generate_video(model_name, text, video_path, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
    """Stream a response over sampled video frames using the selected model."""
    mapping = {
        "Nanonets-OCR-s": (processor_v, model_v),
        "Qwen2-VL-OCR-2B": (processor_x, model_x),
        "RolmOCR-7B": (processor_m, model_m),
        "Lh41-1042-Magellanic-7B-0711": (processor_w, model_w),
    }
    if model_name not in mapping:
        yield "Invalid model selected.", "Invalid model."
        return
    processor, model = mapping[model_name]
    if video_path is None:
        yield "Please upload a video.", ""
        return
    frames = downsample_video(video_path)
    messages = [
        {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
        {"role": "user", "content": [{"type": "text", "text": text}]},
    ]
    # Interleave a timestamp label with each sampled frame in the user turn.
    for img, ts in frames:
        messages[1]["content"].append({"type": "text", "text": f"Frame {ts}:"})
        messages[1]["content"].append({"type": "image", "image": img})
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,  # needed so the result can be unpacked into model.generate()
        return_tensors="pt",
    ).to(device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate, kwargs={
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    })
    thread.start()
    out = ""
    for token in streamer:
        out += token.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield out, out

# Examples
image_examples = [
    ["Extract the content", "images/4.png"],
    ["Explain the scene", "images/3.jpg"],
    ["Perform OCR on the image", "images/1.jpg"],
]

video_examples = [
    ["Explain the Ad in Detail", "videos/1.mp4"],
]
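
# The example rows assume sample files are shipped with the Space under images/ and videos/;
# selecting a row pre-fills the query textbox and the corresponding upload component.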

css = """
.submit-btn { background-color: #2980b9 !important; color: white !important; }
.submit-btn:hover { background-color: #3498db !important; }
.canvas-output { border: 2px solid #4682B4; border-radius: 10px; padding: 20px; }
"""

with gr.Blocks(css=css, theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# **Multimodal OCR**")
    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Image Inference"):
                    img_q = gr.Textbox(label="Query Input", placeholder="Enter prompt")
                    img_up = gr.Image(type="pil", label="Upload Image")
                    img_btn = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(examples=image_examples, inputs=[img_q, img_up])
                with gr.TabItem("Video Inference"):
                    vid_q = gr.Textbox(label="Query Input")
                    vid_up = gr.Video(label="Upload Video")
                    vid_btn = gr.Button("Submit", elem_classes="submit-btn")
                    gr.Examples(examples=video_examples, inputs=[vid_q, vid_up])
        with gr.Column(elem_classes="canvas-output"):
            gr.Markdown("## Output")
            out_raw = gr.Textbox(interactive=False, lines=2, show_copy_button=True)
            with gr.Accordion("Formatted Output", open=False):
                out_md = gr.Markdown()
    model_choice = gr.Radio(
        choices=["Nanonets-OCR-s", "Qwen2-VL-OCR-2B", "RolmOCR-7B", "Lh41-1042-Magellanic-7B-0711"],
        label="Select Model",
        value="Nanonets-OCR-s"
    )
    # Shared generation controls; both submit buttons read from the same labeled sliders.
    max_new_tokens = gr.Slider(minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, label="Max new tokens")
    temperature = gr.Slider(minimum=0.1, maximum=4.0, step=0.1, value=0.6, label="Temperature")
    top_p = gr.Slider(minimum=0.05, maximum=1.0, step=0.05, value=0.9, label="Top-p (nucleus sampling)")
    top_k = gr.Slider(minimum=1, maximum=1000, step=1, value=50, label="Top-k")
    repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.2, label="Repetition penalty")

    img_btn.click(
        generate_image,
        inputs=[model_choice, img_q, img_up, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[out_raw, out_md],
    )
    vid_btn.click(
        generate_video,
        inputs=[model_choice, vid_q, vid_up, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=[out_raw, out_md],
    )

if __name__ == "__main__":
    demo.queue(max_size=30).launch(share=True, mcp_server=True, ssr_mode=False, show_error=True)