| import os | |
| import cv2 | |
| import torch | |
| import traceback | |
| import numpy as np | |
| import gradio as gr | |
| from itertools import chain | |
| from huggingface_hub import hf_hub_download | |
| from segment_anything import SamPredictor, sam_model_registry | |
# Module-level settings read by the functions below: checkpoint file name,
# key into `sam_model_registry`, and the torch device the model is moved to.
model_type = "vit_h"
sam_checkpoint = "sam_vit_h.pth"
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Download the checkpoint into the current directory at import time. The
# access token comes from the `model_token` environment variable and is None
# when that variable is not set.
hf_hub_download(
    repo_id="vmoras/sam_api",
    filename=sam_checkpoint,
    token=os.environ.get('model_token'),
    local_dir="./",
)
# Shared SAM model: loaded on first use and reused afterwards, so the
# checkpoint is not re-read from disk for every uploaded image.
_sam_model = None


def _get_sam_model():
    """Return the shared SAM model, loading the checkpoint on the first call."""
    global _sam_model
    if _sam_model is None:
        model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        model.to(device=device)
        _sam_model = model
    return _sam_model


def set_predictor(image):
    """
    Creates a Sam predictor object based on a given image and model.

    The model itself is loaded once and shared; only the lightweight
    `SamPredictor` wrapper (which is what `set_image` is called on) is created
    per image.

    Args:
        image: The uploaded image, or None when nothing has been uploaded.

    Returns:
        A three-item list `[image, predictor, status]`. On success the status
        is 'Done'. When `image` is None, the first two items are None and the
        status asks the user to upload an image.
    """
    if image is None:
        # Nothing to embed; report it instead of failing inside set_image.
        return [None, None, 'No image received. Upload an image and try again.']

    predictor = SamPredictor(_get_sam_model())
    predictor.set_image(image)
    return [image, predictor, 'Done']
def get_polygon(points, image, predictor):
    """
    Returns the points of the polygon given a bounding box and a prediction
    made by Sam.

    Args:
        points: Two `[x, y]` corners of the bounding box, in either order.
        image: The image the predictor was prepared with; only its shape is
            used, to size the returned mask.
        predictor: The object returned by `set_predictor`, exposing `predict`.

    Returns:
        A `(polygon, mask)` pair:
        * `([], None)` when no image/predictor has been set yet;
        * `([], binary_mask)` when the predicted mask contains no contour;
        * otherwise the largest outer contour as a list of `[x, y]` integer
          pairs, plus an image-shaped uint8 mask with that polygon filled.
    """
    if image is None or predictor is None:
        # 'Send Image' has not completed, so there is nothing to segment.
        return [], None

    # Order the corners so the box is always (x_min, y_min, x_max, y_max),
    # whichever corner the user clicked first.
    (xa, ya), (xb, yb) = points
    input_box = np.array([min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)])

    masks, _, _ = predictor.predict(
        box=input_box[None, :],
        multimask_output=False,
    )
    binary_mask = masks[0].astype(np.uint8)

    # Only outer boundaries are of interest; holes inside the segment are not
    # part of the polygon outline.
    contours, _ = cv2.findContours(
        binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        return [], binary_mask

    # The mask may contain several disconnected regions; keep the largest one
    # rather than whichever happens to be listed first.
    largest = max(contours, key=cv2.contourArea)
    polygon = largest.reshape(-1, 2).tolist()

    # NOTE(review): the fill colour (0, 255, 0) assumes a 3-channel image --
    # confirm against the image component's configuration.
    mask = np.zeros(image.shape, dtype='uint8')
    cv2.fillPoly(mask, [largest], (0, 255, 0))
    return polygon, mask
def add_bbox(bbox, evt: gr.SelectData):
    """
    Records the position carried by a select event in the bounding box.

    The first corner slot is written while it still holds `[0, 0]`; once it
    holds anything else, the second corner slot is written instead. The list
    is updated in place and returned twice, once per output.
    """
    clicked = [evt.index[0], evt.index[1]]
    slot = 0 if bbox[0] == [0, 0] else 1
    bbox[slot] = clicked
    return bbox, bbox
def clear_bbox(bbox):
    """
    Produces an empty bounding box.

    The incoming `bbox` is ignored and left untouched; a new
    `[[0, 0], [0, 0]]` list is built and returned twice, once per output.
    """
    blank = [[0, 0] for _ in range(2)]
    return blank, blank
# UI definition. All components and event listeners are created inside the
# Blocks context; the app is launched after the context closes.
with gr.Blocks() as demo:
    # The footnote marker is written as `\\*` so the rendered Markdown gets a
    # backslash-escaped asterisk without using an invalid Python escape.
    gr.Markdown(
        """
        # Instructions
        1. Upload the image and press 'Send Image'.
        2. Wait until the word 'Done' appears on the 'Status' box.
        3. Click on the image where the upper left corner of the bbox should be.
        4. Click on the image where the lower right corner of the bbox should be.
        5. Check the coordinates using the 'bbox' box.
        6. Click on 'Send bounding box'.
        7. On the right side you will see the binary mask '\\*'.
        8. On the lower side you will see the points that made up the polygon '\\*'.
        9. Click on 'Clear bbox' to send another bounding box and repeat the steps from the third step.
        10. Repeat steps 3 to 9 until all the segments for this image are done.
        11. Click on the right corner of the image to remove it and repeat all the steps with the next
        image.
        '\\*' If the binary mask is all black and the polygon is an empty list, it means the program did
        not find any segment in the bbox. Make the bbox a little bit bigger if that happens.
        """)

    # Per-session state: the submitted image, the object returned by
    # `set_predictor` as its second item (stored under the name `embedding`),
    # and the two bounding-box corners.
    image = gr.State()
    embedding = gr.State()
    bbox = gr.State([[0, 0], [0, 0]])

    with gr.Row():
        input_image = gr.Image(label='Image')
        mask = gr.Image(label='Mask')

    with gr.Row():
        with gr.Column():
            output_status = gr.Textbox(label='Status')
        with gr.Column():
            predictor_button = gr.Button('Send Image')

    with gr.Row():
        with gr.Column():
            bbox_box = gr.Textbox(label="bbox")
        with gr.Column():
            bbox_button = gr.Button('Clear bbox')

    with gr.Row():
        with gr.Column():
            polygon = gr.Textbox(label='Polygon')
        with gr.Column():
            points_button = gr.Button('Send bounding box')

    # 'Send Image': build the predictor for the uploaded image.
    predictor_button.click(
        set_predictor,
        input_image,
        [image, embedding, output_status],
    )

    # 'Send bounding box': run the prediction and show polygon + mask.
    points_button.click(
        get_polygon,
        [bbox, image, embedding],
        [polygon, mask],
    )

    # 'Clear bbox': reset both corners.
    bbox_button.click(
        clear_bbox,
        bbox,
        [bbox, bbox_box],
    )

    # Clicking on the image stores a bounding-box corner.
    input_image.select(
        add_bbox,
        bbox,
        [bbox, bbox_box]
    )

# Credentials are read with os.environ[...] so a missing `user` or `password`
# variable raises KeyError before launch is called.
demo.launch(debug=True, auth=(os.environ['user'], os.environ['password']))