import warnings
warnings.filterwarnings('ignore')
import os, subprocess, sys
os.system("pip install gradio==3.36.1")
import gradio as gr
from loguru import logger
# os.system("pip install diffuser==0.6.0")
# os.system("pip install transformers==4.29.1")
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if os.environ.get('IS_MY_DEBUG') is None:
    result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
    print(f'pip install GroundingDINO = {result}')
    # result = subprocess.run(['pip', 'list'], check=True)
    # print(f'pip list = {result}')

sys.path.insert(0, './GroundingDINO')
import argparse

import numpy as np
import pandas as pd
import torch
from PIL import Image
# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
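
# NOTE: run_inference() below calls load_model_hf(), which was missing from this
# file. The definition here is a minimal sketch following the standard
# huggingface_hub checkpoint-download pattern used by the Grounded-SAM demos;
# treat it as an assumed implementation, not the Space author's original.
from huggingface_hub import hf_hub_download


def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    # Build the model from its config file, then pull the checkpoint from the Hub.
    model_args = SLConfig.fromfile(model_config_path)
    model_args.device = device
    model = build_model(model_args)
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.eval()
    return model
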
def load_image(image_path):
    # Accept either a PIL image (e.g. from Gradio) or a path on disk.
    if isinstance(image_path, Image.Image):
        image_pil = image_path
    else:
        image_pil = Image.open(image_path).convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # (3, h, w)
    return image_pil, image
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    # Grounding DINO expects a lower-case caption terminated with a period.
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # Keep only the queries whose best token score clears the box threshold.
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes[filt_mask]  # (num_filt, 4)

    # Map each surviving query's token activations back to phrases in the caption.
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({logit.max().item():.2f})")
        else:
            pred_phrases.append(pred_phrase)
    return boxes_filt, pred_phrases
def run_inference(input_image, text_prompt, box_threshold, text_threshold, config_file, ckpt_repo_id, ckpt_filename):
    # Load the Grounding DINO model (reloaded on every call; cache it if latency matters)
    model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename)
    # Load and preprocess the input image
    image_pil, image = load_image(input_image)
    # Run the object detection and grounding model
    boxes, labels = get_grounding_output(model, image, text_prompt, box_threshold, text_threshold)
    # Convert the boxes and labels to a list of records
    result = []
    for box, label in zip(boxes, labels):
        result.append({
            "box": box.tolist(),
            "label": label
        })
    return result
if __name__ == "__main__":
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
args = parser.parse_args()
print(f'args = {args}')
model_config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
model_ckpt_repo_id = "ShilongLiu/GroundingDINO"
model_ckpt_filenmae = "groundingdino_swint_ogc.pth"
def inference_func(input_image, text_prompt):
result = run_inference(input_image, text_prompt, 0.3, 0.25, model_config_file, model_ckpt_repo_id, model_ckpt_filenmae)
return result
# Create the Gradio interface for the model
interface = gr.Interface(
fn=inference_func,
inputs=[
gr.inputs.Image(label="Input Image"),
gr.inputs.Textbox(label="Detection Prompt")
],
outputs=gr.outputs.Dataframe(type="pandas"),
title="Object Detection and Grounding",
description="A Gradio app to detect objects in an image and ground them to captions using Grounding DINO.",
server_name='0.0.0.0',
debug=args.debug,
share=args.share
)
# Launch the interface
interface.launch() |