"""OWLv2-based open-vocabulary object detector."""

import logging
from typing import Sequence

import numpy as np
import torch
from transformers import Owlv2ForObjectDetection, Owlv2Processor

from models.detectors.base import DetectionResult, ObjectDetector

class Owlv2Detector(ObjectDetector):
    """Open-vocabulary object detector backed by Google's OWLv2 model."""

    MODEL_NAME = "google/owlv2-large-patch14"

    def __init__(self) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info("Loading %s onto %s", self.MODEL_NAME, self.device)
        self.processor = Owlv2Processor.from_pretrained(self.MODEL_NAME)
        # Half precision on GPU keeps memory in check; float32 on CPU.
        self.torch_dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.model = Owlv2ForObjectDetection.from_pretrained(
            self.MODEL_NAME, torch_dtype=self.torch_dtype
        )
        self.model.to(self.device)
        self.model.eval()
        self.name = "owlv2"

    def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
        """Detect objects matching the text queries in a single frame."""
        inputs = self.processor(text=queries, images=frame, return_tensors="pt")
        # Move tensors to the model's device; fall back to a plain dict if the
        # processor output does not expose a batched .to() helper.
        if hasattr(inputs, "to"):
            inputs = inputs.to(self.device)
        else:
            inputs = {
                key: value.to(self.device) if hasattr(value, "to") else value
                for key, value in inputs.items()
            }
        # Match the image tensor to the model's dtype so fp16 weights on GPU
        # are not fed float32 pixel values.
        if "pixel_values" in inputs:
            inputs["pixel_values"] = inputs["pixel_values"].to(self.torch_dtype)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Rescale predicted boxes back to the original frame size (height, width).
        processed = self.processor.post_process_object_detection(
            outputs, threshold=0.3, target_sizes=[frame.shape[:2]]
        )[0]
        boxes = processed["boxes"]
        scores = processed.get("scores", [])
        labels = processed.get("labels", [])
        # Normalise outputs: boxes as a NumPy array, scores and labels as plain lists.
        boxes_np = boxes.cpu().numpy() if hasattr(boxes, "cpu") else np.asarray(boxes)
        if hasattr(scores, "cpu"):
            scores_seq = scores.cpu().numpy().tolist()
        elif isinstance(scores, np.ndarray):
            scores_seq = scores.tolist()
        else:
            scores_seq = list(scores)
        if hasattr(labels, "cpu"):
            labels_seq = labels.cpu().numpy().tolist()
        elif isinstance(labels, np.ndarray):
            labels_seq = labels.tolist()
        else:
            labels_seq = list(labels)
        return DetectionResult(boxes=boxes_np, scores=scores_seq, labels=labels_seq)
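

if __name__ == "__main__":
    # Minimal usage sketch: exercises the detector on a blank placeholder frame
    # so the call signature is clear. Assumes the OWLv2 weights are downloadable
    # and that DetectionResult exposes the boxes/scores/labels it was built with.
    logging.basicConfig(level=logging.INFO)
    detector = Owlv2Detector()
    dummy_frame = np.zeros((480, 640, 3), dtype=np.uint8)  # H x W x 3 placeholder image
    result = detector.predict(dummy_frame, queries=["a person", "a car"])
    print("boxes:", result.boxes)
    print("scores:", result.scores)
    print("labels:", result.labels)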