Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -84,15 +84,6 @@ def draw_contours_on_image(img, index_mask, color_mask, brightness_factor=1.6, a
     return np.clip(blended, 0, 255).astype("uint8")
 
 
-def extract_first_frame_from_video(video):
-    cap = cv2.VideoCapture(video)
-    success, frame = cap.read()
-    cap.release()
-    if success:
-        return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
-    return None
-
-
 def extract_points_from_mask(mask_pil):
     mask = np.asarray(mask_pil)[..., 0]
     coords = np.nonzero(mask)
@@ -100,26 +91,6 @@ def extract_points_from_mask(mask_pil):
 
     return coords
 
-def add_contour(img, mask, color=(1., 1., 1.)):
-    img = img.copy()
-
-    mask = mask.astype(np.uint8) * 255
-    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
-    cv2.drawContours(img, contours, -1, color, thickness=8)
-
-    return img
-
-
-def load_first_frame(video_path):
-    cap = cv2.VideoCapture(video_path)
-    ret, frame = cap.read()
-    cap.release()
-    if not ret:
-        raise gr.Error("Could not read the video file.")
-    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-    image = Image.fromarray(frame)
-    return image
-
 
 def clear_masks():
     return [], [], [], []
@@ -146,11 +117,11 @@ def apply_sam(image, input_points):
 
 
 @spaces.GPU(duration=120)
-def run(mode, images, timestamps, masks, mask_ids, instruction, mask_output_video):
+def run(mode, images, timestamps, masks, mask_ids, instruction, mask_output_video, mask_threshold):
     if mode == "QA":
         response = run_text_inference(images, timestamps, masks, mask_ids, instruction)
     else:
-        response, mask_output_video = run_seg_inference(images, timestamps, instruction)
+        response, mask_output_video = run_seg_inference(images, timestamps, instruction, mask_threshold)
     return response, mask_output_video
 
 
@@ -181,7 +152,7 @@ def run_text_inference(images, timestamps, masks, mask_ids, instruction):
     return output
 
 
-def run_seg_inference(images, timestamps, instruction):
+def run_seg_inference(images, timestamps, instruction, mask_threshold):
     output, masks = mm_infer_segmentation(
         (images, timestamps),
         seg_processor,
@@ -190,6 +161,7 @@ def run_seg_inference(images, timestamps, instruction):
         tokenizer=processor.tokenizer,
         do_sample=False,
         modal='video',
+        mask_threshold=mask_threshold,
     )
 
     w, h = images[0].size
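The threshold added here is only forwarded to mm_infer_segmentation; how it is consumed is not visible in this diff. A minimal sketch, assuming the call yields per-frame mask logits that are binarized against mask_threshold (the helper name and return format below are assumptions, not the app's code):

```python
# Hypothetical sketch: assumes mm_infer_segmentation returns per-frame mask logits.
import numpy as np

def binarize_masks(mask_logits, mask_threshold=0.5):
    """Turn raw mask logits into uint8 binary masks (1 = foreground)."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(mask_logits, dtype=np.float32)))  # sigmoid
    return (probs > mask_threshold).astype(np.uint8)
```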
@@ -255,7 +227,7 @@ if __name__ == "__main__":
     <h1 align="center"><img src="https://github.com/alibaba-damo-academy/RynnEC/blob/main/assets/logo.jpg?raw=true" style="vertical-align: middle; width: 45px; height: auto;"> RynnEC Demo</h1>
     <h5 align="center" style="margin: 0;">Feel free to click on anything that grabs your interest!</h5>
     <h5 align="center" style="margin: 0;">If this demo pleases you, please give us a star ⭐ on Github or ❤️ on this space.</h5>
-    <div style="display: flex; justify-content:
+    <div style="display: flex; justify-content: center; margin-top: 10px;">
     <a href="https://huggingface.co/Alibaba-DAMO-Academy/RynnEC-2B"><img src="https://img.shields.io/badge/🤗-Checkpoints-FBD49F.svg" style="margin-right: 5px;"></a>
     <a href="https://huggingface.co/datasets/Alibaba-DAMO-Academy/RynnEC-Bench"><img src="https://img.shields.io/badge/🤗-Benchmark-FBD49F.svg" style="margin-right: 5px;"></a>
     <a href="https://www.youtube.com/watch?v=vsMxbzsmrQc"><img src="https://img.shields.io/badge/Video-36600E?logo=youtube&logoColor=green" style="margin-right: 5px;"></a>
@@ -265,7 +237,6 @@ if __name__ == "__main__":
 
     TIPS = """
     ### 💡 Tips:
-
     🧸 Upload a video, and select a frame using the slider.
 
     ✏️ Use the drawing tool to highlight the areas you're interested in.
@@ -274,51 +245,51 @@
 
     🗑️ Click the button 'Clear Masks' to clear the current generated masks.
 
+    ⚙️ If you change the settings, you need to re-upload the video to apply the new settings.
     """
 
-    … (previous UI-layout block; not legible in this render)
-        frames, timestamps = load_video(video_path, fps=1, max_frames=128)
+    gr.HTML(HEADER)
+
+    with gr.Tab("Demo"):
+        with gr.Row():
+            with gr.Column():
+                video_input = gr.Video(label="Video", interactive=True)
+                frame_idx = gr.Slider(minimum=0, maximum=0, value=0, step=1, label="Select Frame", interactive=False)
+                selected_frame = gr.ImageEditor(
+                    label="Annotate Frame",
+                    type="pil",
+                    sources=[],
+                    interactive=True,
+                )
+                generate_mask_btn_video = gr.Button("1️⃣ Generate Mask", visible=True, variant="primary")
+                gr.Examples([f"./demo/videos/{i+1}.mp4" for i in range(4)], inputs=video_input, label="Examples")
+
+            with gr.Column():
+                mode_video = gr.Radio(label="Mode", choices=["QA", "Seg"], value="QA")
+                mask_output_video = gr.Gallery(label="Referred Masks", object_fit='scale-down')
+
+                query_video = gr.Textbox(label="Question", value="What's the function of <object0>?", interactive=True, visible=True)
+                response_video = gr.Textbox(label="Answer", interactive=False)
+
+                submit_btn_video = gr.Button("Generate Caption", variant="primary", visible=False)
+                submit_btn_video1 = gr.Button("2️⃣ Generate Answer", variant="primary", visible=True)
+                description_video = gr.Textbox(label="Output", visible=False)
+
+                clear_masks_btn_video = gr.Button("Clear Masks", variant="secondary")
+
+    with gr.Tab("Settings"):
+        fps = gr.Slider(label="FPS", minimum=1, maximum=30, value=1, step=1)
+        max_frames = gr.Slider(label="Max Frames", minimum=1, maximum=128, value=80, step=1)
+        mask_threshold = gr.Slider(label="Mask Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.01)
+
+    gr.Markdown(TIPS)
+
+    frames = gr.State(value=[])
+    timestamps = gr.State(value=[])
+    mask_ids = gr.State(value=[])
+
+    def on_video_upload(video_path, fps, max_frames):
+        frames, timestamps = load_video(video_path, fps=fps, max_frames=max_frames)
         frames = [Image.fromarray(x.transpose(1, 2, 0)) for x in frames]
         return frames, timestamps, frames[0], gr.update(value=0, maximum=len(frames) - 1, interactive=True)
 
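The new Settings tab is what the added tip refers to: fps and max_frames are read only when on_video_upload fires, so changed slider values take effect on the next upload. A stripped-down sketch of that wiring, with component names following the diff and the loader stubbed out:

```python
import gradio as gr

def on_video_upload(video_path, fps, max_frames):
    # Stub: the real app calls load_video(video_path, fps=fps, max_frames=max_frames)
    return f"loaded {video_path!r} at {fps} fps, capped at {max_frames} frames"

with gr.Blocks() as demo:
    with gr.Tab("Demo"):
        video_input = gr.Video(label="Video")
        status = gr.Textbox(label="Status")
    with gr.Tab("Settings"):
        fps = gr.Slider(1, 30, value=1, step=1, label="FPS")
        max_frames = gr.Slider(1, 128, value=80, step=1, label="Max Frames")

    # Slider values are snapshotted only when this event fires, i.e. on (re-)upload.
    video_input.change(on_video_upload, inputs=[video_input, fps, max_frames], outputs=[status])

demo.launch()
```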
@@ -328,13 +299,15 @@
     def to_seg_mode():
         return (
             *[gr.update(visible=False) for _ in range(4)],
-            []
+            [],
+            "Please segment the rubbish bin.",
         )
 
     def to_qa_mode():
         return (
             *[gr.update(visible=True) for _ in range(4)],
-            []
+            [],
+            "What's the function of <object0>?",
        )
 
     def on_mode_change(mode):
@@ -342,8 +315,8 @@
             return to_qa_mode()
         return to_seg_mode()
 
-    mode_video.change(on_mode_change, inputs=[mode_video], outputs=[frame_idx, selected_frame, generate_mask_btn_video, response_video, mask_output_video])
-    video_input.change(on_video_upload, inputs=[video_input], outputs=[frames, timestamps, selected_frame, frame_idx])
+    mode_video.change(on_mode_change, inputs=[mode_video], outputs=[frame_idx, selected_frame, generate_mask_btn_video, response_video, mask_output_video, query_video])
+    video_input.change(on_video_upload, inputs=[video_input, fps, max_frames], outputs=[frames, timestamps, selected_frame, frame_idx])
     frame_idx.change(on_frame_idx_change, inputs=[frame_idx, frames], outputs=[selected_frame])
 
     generate_mask_btn_video.click(
@@ -354,7 +327,7 @@
 
     submit_btn_video1.click(
         fn=run,
-        inputs=[mode_video, frames, timestamps, mask_raw_list_video, mask_ids, query_video, mask_output_video],
+        inputs=[mode_video, frames, timestamps, mask_raw_list_video, mask_ids, query_video, mask_output_video, mask_threshold],
         outputs=[response_video, mask_output_video],
         api_name="describe_video"
     )
@@ -372,10 +345,8 @@
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
     sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
-    # sam_model = sam_processor = None
     disable_torch_init()
     model, processor = model_init(args_cli.model_path)
     seg_model, seg_processor = model_init(args_cli.seg_model_path)
-    # model = processor = None
 
     demo.launch()
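apply_sam itself is untouched by this commit and its body is not shown above. For reference only, a minimal point-prompt call with the same facebook/sam-vit-huge checkpoint, following the standard transformers SAM usage; the frame path and click coordinates are placeholders, and the app's actual helper may differ:

```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge")

image = Image.open("frame.png").convert("RGB")  # placeholder frame
input_points = [[[450, 600]]]                   # one (x, y) click for this image

inputs = sam_processor(image, input_points=input_points, return_tensors="pt")
with torch.no_grad():
    outputs = sam_model(**inputs)

# Upscale the low-resolution predicted masks back to the original frame size.
masks = sam_processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
best_mask = masks[0][0][outputs.iou_scores.argmax()]  # keep the highest-IoU proposal
```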