# zye0616's picture
# mission detection with summary
# a8d3381
import logging
from typing import Sequence
import numpy as np
import torch
from transformers import Owlv2ForObjectDetection, Owlv2Processor
from models.detectors.base import DetectionResult, ObjectDetector
class Owlv2Detector(ObjectDetector):
    """Open-vocabulary object detector backed by the OWLv2 checkpoint.

    The checkpoint named by ``MODEL_NAME`` is loaded once at construction
    time (float16 on CUDA, float32 on CPU) and reused for every call to
    :meth:`predict`.
    """

    MODEL_NAME = "google/owlv2-large-patch14"
    # Minimum confidence for a detection to be kept by post-processing.
    # Subclasses may override this to trade recall for precision.
    SCORE_THRESHOLD = 0.3

    def __init__(self) -> None:
        """Load the processor and model and move the model to the device."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
        self.processor = Owlv2Processor.from_pretrained(self.MODEL_NAME)
        # Half precision only on GPU; it is kept on the instance so that
        # predict() can cast its inputs to the same dtype as the weights.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.model = Owlv2ForObjectDetection.from_pretrained(
            self.MODEL_NAME, torch_dtype=self.dtype
        )
        self.model.to(self.device)
        self.model.eval()
        self.name = "owlv2"

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects in ``frame`` that match the text ``queries``.

        Args:
            frame: Image array; its first two dimensions are used as the
                target size for rescaling the predicted boxes.
            queries: Text prompts to look for. A single string is treated
                as one prompt.

        Returns:
            A ``DetectionResult`` holding the boxes as a float32 numpy
            array, and the scores and labels as Python lists.
        """
        # The processor only accepts ``str`` or ``list`` for text, so any
        # other sequence (e.g. a tuple) has to be normalised first.
        prompts = [queries] if isinstance(queries, str) else list(queries)
        if not prompts:
            # Nothing to look for: skip the forward pass, which would
            # otherwise fail inside the processor on an empty list.
            return DetectionResult(
                boxes=np.zeros((0, 4), dtype=np.float32), scores=[], labels=[]
            )

        inputs = self.processor(text=prompts, images=frame, return_tensors="pt")
        model_inputs = self._prepare_inputs(inputs)

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        # NOTE(review): OWLv2 pads images to a square before resizing; some
        # transformers versions expect the padded size here instead of the
        # raw frame size - confirm against the installed version.
        processed = self.processor.post_process_object_detection(
            outputs,
            threshold=self.SCORE_THRESHOLD,
            target_sizes=[frame.shape[:2]],
        )[0]

        boxes = processed["boxes"]
        if hasattr(boxes, "cpu"):
            # Always hand back float32 so callers see the same dtype on
            # CPU and on a half-precision GPU model.
            boxes_np = boxes.detach().cpu().float().numpy()
        else:
            boxes_np = np.asarray(boxes)

        return DetectionResult(
            boxes=boxes_np,
            scores=self._to_list(processed.get("scores", [])),
            labels=self._to_list(processed.get("labels", [])),
        )

    def _prepare_inputs(self, inputs) -> dict:
        """Move processor outputs to the model device and matching dtype."""
        prepared = {}
        for key, value in inputs.items():
            if torch.is_tensor(value):
                if value.is_floating_point():
                    # Pixel values come out as float32; a float16 model
                    # rejects them unless they are cast to its dtype.
                    value = value.to(device=self.device, dtype=self.dtype)
                else:
                    # Token ids and masks must keep their integer dtype.
                    value = value.to(self.device)
            elif hasattr(value, "to"):
                value = value.to(self.device)
            prepared[key] = value
        return prepared

    @staticmethod
    def _to_list(values) -> list:
        """Convert a tensor, numpy array or other iterable to a Python list."""
        if hasattr(values, "cpu"):
            return values.detach().cpu().tolist()
        if isinstance(values, np.ndarray):
            return values.tolist()
        return list(values)