| |
|
| | import torch |
| | import requests |
| | from PIL import Image |
| | from io import BytesIO |
| | from pathlib import Path |
| | from typing import Union, List, Dict, Any |
| | import sys |
| |
|
| | |
| | OCULUS_ROOT = Path(__file__).parent |
| | sys.path.insert(0, str(OCULUS_ROOT)) |
| |
|
| | try: |
| | from oculus_unified_model import OculusForConditionalGeneration |
| | except ImportError: |
| | |
| | from Oculus.oculus_unified_model import OculusForConditionalGeneration |
| |
|
class OculusPredictor:
    """
    Easy-to-use interface for the Oculus Unified Model.
    Supports Object Detection, VQA, and Captioning.
    """

    def __init__(self, model_path: Union[str, None] = None, device: str = "cpu"):
        """Load the unified model plus any optional task-specific weights.

        Args:
            model_path: Directory of a saved checkpoint. When ``None``, the
                newest local detection checkpoint is auto-selected.
            device: Device string used when loading auxiliary head weights.
        """
        self.device = device

        # Resolve a default checkpoint: prefer the v2 detection weights,
        # falling back to the original detection checkpoint directory.
        if model_path is None:
            base_dir = OCULUS_ROOT / "checkpoints" / "oculus_detection_v2"
            if (base_dir / "final").exists():
                model_path = str(base_dir / "final")
            else:
                model_path = str(OCULUS_ROOT / "checkpoints" / "oculus_detection" / "final")

        print(f"Loading Oculus model from: {model_path}")
        self.model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Optional detection head stored alongside the main checkpoint.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources. If heads.pth holds plain state
        # dicts, passing weights_only=True would be safer (requires a torch
        # version that supports it — confirm before changing).
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location=device)
            self.model.detection_head.load_state_dict(heads['detection'])
            print("✓ Detection heads loaded")

        # Swap in the instruction-tuned VQA model when a local copy exists.
        instruct_path = OCULUS_ROOT / "checkpoints" / "oculus_instruct_v1" / "vqa_model"
        if instruct_path.exists():
            from transformers import BlipForQuestionAnswering
            self.model.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(instruct_path)
            print("✓ Instruction-tuned VQA model loaded")

        print("✓ Model loaded successfully")

    def load_image(self, image_source: Union[str, Image.Image]) -> Image.Image:
        """Load an image from a local path, an HTTP(S) URL, or a PIL object.

        Always returns an RGB-converted image.

        Raises:
            requests.HTTPError: If a URL fetch returns an error status.
        """
        if isinstance(image_source, Image.Image):
            return image_source.convert("RGB")

        if image_source.startswith("http"):
            # Fix: an explicit timeout prevents the request from hanging
            # indefinitely, and raise_for_status() surfaces HTTP errors
            # instead of letting PIL fail cryptically on an error-page body.
            response = requests.get(
                image_source,
                headers={'User-Agent': 'Mozilla/5.0'},
                timeout=30,
            )
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        return Image.open(image_source).convert("RGB")

    def detect(self, image_source: Union[str, Image.Image], prompt: str = "Detect objects", threshold: float = 0.2) -> Dict[str, Any]:
        """
        Run object detection.
        Returns: {'boxes': [[x1,y1,x2,y2], ...], 'labels': [...], 'confidences': [...]}
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="box", prompt=prompt, threshold=threshold)

        # image.size is (width, height) — included so callers can scale boxes.
        return {
            'boxes': output.boxes,
            'labels': output.labels,
            'confidences': output.confidences,
            'image_size': image.size
        }

    def ask(self, image_source: Union[str, Image.Image], question: str) -> str:
        """Ask a question about the image (VQA)."""
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="text", prompt=question)
        return output.text

    def caption(self, image_source: Union[str, Image.Image]) -> str:
        """Generate a caption by prompting the VQA path with a caption prefix."""
        return self.ask(image_source, "A photo of")
|