import warnings
warnings.filterwarnings('ignore')

import subprocess, os, sys

# Install a pinned Gradio version at runtime (common Hugging Face Spaces pattern).
os.system("pip install gradio==3.36.1")
import gradio as gr

# os.system("pip install diffuser==0.6.0")
# os.system("pip install transformers==4.29.1")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Install the vendored GroundingDINO package unless running in a local debug environment.
if os.environ.get('IS_MY_DEBUG') is None:
    result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
    print(f'pip install GroundingDINO = {result}')

# result = subprocess.run(['pip', 'list'], check=True)
# print(f'pip list = {result}')

# Make the local GroundingDINO checkout importable.
sys.path.insert(0, './GroundingDINO')

import argparse

import torch
from PIL import Image

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
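
# The original file calls load_model_hf() but never defines or imports it.
# A minimal sketch of the usual Grounding DINO Hub checkpoint loader follows,
# assuming the huggingface_hub package is available; this is a reconstruction,
# not necessarily the author's exact code.
from huggingface_hub import hf_hub_download


def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
    # Build the architecture from its config file, then load weights from the Hub.
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.eval()
    return model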


def load_image(image_path):
    # Accept either a PIL image (e.g. from Gradio) or a path on disk.
    if isinstance(image_path, Image.Image):
        image_pil = image_path
    else:
        image_pil = Image.open(image_path).convert("RGB")

    # Standard Grounding DINO eval preprocessing: resize, tensorize, ImageNet-normalize.
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image_pil, image


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    # Grounding DINO expects a lower-case caption that ends with a period.
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (num_queries, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (num_queries, 4), normalized (cx, cy, w, h)

    # Keep only queries whose best token score clears the box threshold.
    filt_mask = logits.max(dim=1)[0] > box_threshold
    logits_filt = logits[filt_mask]  # (num_filt, 256)
    boxes_filt = boxes[filt_mask]  # (num_filt, 4)

    # Map each surviving query's token mask back to a phrase from the caption.
    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)
    pred_phrases = []
    for logit in logits_filt:
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer)
        if with_logits:
            pred_phrases.append(f"{pred_phrase}({logit.max().item():.2f})")
        else:
            pred_phrases.append(pred_phrase)

    return boxes_filt, pred_phrases


def run_inference(input_image, text_prompt, box_threshold, text_threshold, config_file, ckpt_repo_id, ckpt_filename):

    # Load the Grounding DINO model
    model = load_model_hf(config_file, ckpt_repo_id, ckpt_filename)

    # Load the input image
    image_pil, image = load_image(input_image)

    # Run the object detection and grounding model
    boxes, labels = get_grounding_output(model, image, text_prompt, box_threshold, text_threshold)
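    # Note: the returned boxes are normalized (cx, cy, w, h); convert with
    # box_ops.box_cxcywh_to_xyxy (imported above) and scale by the image size
    # if pixel xyxy coordinates are needed.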

    # Convert the boxes and labels to a JSON format
    result = []
    for box, label in zip(boxes, labels):
        result.append({
            "box": box.tolist(),
            "label": label
        })

    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--share", action="store_true", help="share the app")
    args = parser.parse_args()
    print(f'args = {args}')

    model_config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
    model_ckpt_repo_id = "ShilongLiu/GroundingDINO"
    model_ckpt_filename = "groundingdino_swint_ogc.pth"

    def inference_func(input_image, text_prompt):
        # Default thresholds: box 0.3, text 0.25.
        result = run_inference(input_image, text_prompt, 0.3, 0.25, model_config_file, model_ckpt_repo_id, model_ckpt_filename)
        return result

    # Create the Gradio interface for the model. The handler returns a list of
    # dicts, so a JSON output component fits the data better than a Dataframe.
    interface = gr.Interface(
        fn=inference_func,
        inputs=[
            gr.Image(type="pil", label="Input Image"),
            gr.Textbox(label="Detection Prompt")
        ],
        outputs=gr.JSON(label="Detection Results"),
        title="Object Detection and Grounding",
        description="A Gradio app to detect objects in an image and ground them to captions using Grounding DINO."
    )

    # Launch the interface; server_name, debug and share are launch() options,
    # not Interface() constructor arguments.
    interface.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
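
# Usage (script name assumed, e.g. saved as app.py):
#   python app.py            # run locally on 0.0.0.0
#   python app.py --share    # also create a public Gradio share link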