Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -38,9 +38,10 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
|
|
| 38 |
boxes = output_dict["boxes"][keep].tolist()
|
| 39 |
scores = output_dict["scores"][keep].tolist()
|
| 40 |
labels = output_dict["labels"][keep].tolist()
|
|
|
|
| 41 |
if id2label is not None:
|
| 42 |
labels = [id2label[x] for x in labels]
|
| 43 |
-
|
| 44 |
plt.figure(figsize=(16, 10))
|
| 45 |
plt.imshow(pil_img)
|
| 46 |
ax = plt.gca()
|
|
@@ -49,7 +50,7 @@ def visualize_prediction(pil_img, output_dict, threshold=0.7, id2label=None):
|
|
| 49 |
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
|
| 50 |
ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
|
| 51 |
plt.axis("off")
|
| 52 |
-
return fig2img(plt.gcf())
|
| 53 |
|
| 54 |
def detect_objects(model_name,image_input,threshold):
|
| 55 |
print(type(image_input))
|
|
@@ -71,9 +72,9 @@ def detect_objects(model_name,image_input,threshold):
|
|
| 71 |
processed_outputs = make_prediction(image, feature_extractor, model)
|
| 72 |
|
| 73 |
#Visualize prediction
|
| 74 |
-
viz_img = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
|
| 75 |
|
| 76 |
-
return viz_img
|
| 77 |
|
| 78 |
def set_example_image(example: list) -> dict:
|
| 79 |
return gr.Image.update(value=example[0])
|
|
@@ -116,11 +117,14 @@ with demo:
|
|
| 116 |
with gr.Row():
|
| 117 |
example_images = gr.Dataset(components=[img_input],
|
| 118 |
samples=[["airport.jpg"],['football-match.jpg']])
|
| 119 |
-
|
| 120 |
img_but = gr.Button('Detect')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
url_but.click(detect_objects,inputs=[options,url_input,slider_input],outputs=img_output_from_url,queue=True)
|
| 123 |
-
img_but.click(detect_objects,inputs=[options,img_input,slider_input],outputs=img_output_from_upload,queue=True)
|
| 124 |
example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
|
| 125 |
example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input])
|
| 126 |
|
|
@@ -128,4 +132,4 @@ with demo:
|
|
| 128 |
#gr.Markdown("")
|
| 129 |
|
| 130 |
|
| 131 |
-
demo.launch(enable_queue=True)
|
|
|
|
| 38 |
boxes = output_dict["boxes"][keep].tolist()
|
| 39 |
scores = output_dict["scores"][keep].tolist()
|
| 40 |
labels = output_dict["labels"][keep].tolist()
|
| 41 |
+
print(labels)
|
| 42 |
if id2label is not None:
|
| 43 |
labels = [id2label[x] for x in labels]
|
| 44 |
+
res = dict(zip(labels, scores))
|
| 45 |
plt.figure(figsize=(16, 10))
|
| 46 |
plt.imshow(pil_img)
|
| 47 |
ax = plt.gca()
|
|
|
|
| 50 |
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=color, linewidth=3))
|
| 51 |
ax.text(xmin, ymin, f"{label}: {score:0.2f}", fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
|
| 52 |
plt.axis("off")
|
| 53 |
+
return fig2img(plt.gcf()),res
|
| 54 |
|
| 55 |
def detect_objects(model_name,image_input,threshold):
|
| 56 |
print(type(image_input))
|
|
|
|
| 72 |
processed_outputs = make_prediction(image, feature_extractor, model)
|
| 73 |
|
| 74 |
#Visualize prediction
|
| 75 |
+
viz_img,labels = visualize_prediction(image, processed_outputs, threshold, model.config.id2label)
|
| 76 |
|
| 77 |
+
return viz_img,labels
|
| 78 |
|
| 79 |
def set_example_image(example: list) -> dict:
|
| 80 |
return gr.Image.update(value=example[0])
|
|
|
|
| 117 |
with gr.Row():
|
| 118 |
example_images = gr.Dataset(components=[img_input],
|
| 119 |
samples=[["airport.jpg"],['football-match.jpg']])
|
|
|
|
| 120 |
img_but = gr.Button('Detect')
|
| 121 |
+
|
| 122 |
+
with gr.TabItem('Labels'):
|
| 123 |
+
with gr.Row():
|
| 124 |
+
label = gr.Label(label = 'Labels')
|
| 125 |
|
| 126 |
+
url_but.click(detect_objects,inputs=[options,url_input,slider_input],outputs=[img_output_from_url,label],queue=True)
|
| 127 |
+
img_but.click(detect_objects,inputs=[options,img_input,slider_input],outputs=[img_output_from_upload,label],queue=True)
|
| 128 |
example_images.click(fn=set_example_image,inputs=[example_images],outputs=[img_input])
|
| 129 |
example_url.click(fn=set_example_url,inputs=[example_url],outputs=[url_input])
|
| 130 |
|
|
|
|
| 132 |
#gr.Markdown("")
|
| 133 |
|
| 134 |
|
| 135 |
+
demo.launch(enable_queue=True,show_api=False)
|