import torch
from PIL import Image
from transformers import Florence2ForConditionalGeneration, Florence2Processor

MODEL_ID = "ducviet00/Florence-2-large-hf"

# Global variables for lazy loading
_model = None
_processor = None
_device = None
_torch_dtype = None


def _load_model():
    """Load model and processor lazily"""
    global _model, _processor, _device, _torch_dtype
    if _model is None:
        _device = "cuda:0" if torch.cuda.is_available() else "cpu"
        _torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        print(f"Loading model {MODEL_ID} on {_device} with dtype {_torch_dtype}...")
        _model = Florence2ForConditionalGeneration.from_pretrained(
            MODEL_ID, dtype=_torch_dtype, trust_remote_code=True
        ).to(_device)  # type: ignore
        _processor = Florence2Processor.from_pretrained(
            MODEL_ID, trust_remote_code=True, use_fast=True
        )
        print("Model loaded successfully!")
    return _model, _processor, _device, _torch_dtype
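
# --- Lazy-loading note (illustrative comment, not part of the original file) ---
# Repeated calls reuse the module-level cache, so only the first call pays the
# download/initialization cost; later calls return the same objects:
#     m1, p1, _, _ = _load_model()   # loads weights
#     m2, p2, _, _ = _load_model()   # cached: m1 is m2, p1 is p2
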
def get_task_response(task_prompt: str, image: Image.Image, text_input=None):
    """Return associated task response

    Task can be:
        '<MORE_DETAILED_CAPTION>'
        '<DETAILED_CAPTION>'
        '<CAPTION>'
    """
    # Lazy load model only when needed
    model, processor, device, torch_dtype = _load_model()

    if text_input is None:
        prompt = task_prompt
    else:
        prompt = task_prompt + text_input

    # Ensure image is in RGB mode
    if image.mode != "RGB":
        image = image.convert("RGB")

    if processor is None:
        raise ValueError("processor is None")

    inputs = processor(
        text=prompt,
        images=image,
        return_tensors="pt",  # type: ignore
    ).to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
    )
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text, task=task_prompt, image_size=(image.width, image.height)
    )
    return parsed_answer[task_prompt]
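

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal way to exercise get_task_response; "photo.jpg" is a hypothetical
# local image path, and the first call triggers the lazy model load.
if __name__ == "__main__":
    image = Image.open("photo.jpg")
    print(get_task_response("<CAPTION>", image))
    print(get_task_response("<MORE_DETAILED_CAPTION>", image))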