make predictor global and remove bf16
app.py CHANGED
@@ -72,6 +72,7 @@ examples = [
 OBJ_ID = 0
 sam2_checkpoint = "checkpoints/edgetam.pt"
 model_cfg = "edgetam.yaml"
+predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
 
 
 def get_video_fps(video_path):
@@ -226,7 +227,6 @@ def preprocess_video_in(
     input_points = []
     input_labels = []
 
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     inference_state = predictor.init_state(
         offload_video_to_cpu=True,
         offload_state_to_cpu=True,
@@ -255,7 +255,6 @@ def segment_with_points(
     inference_state,
     evt: gr.SelectData,
 ):
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")
     input_points.append(evt.index)
     print(f"TRACKING INPUT POINT: {input_points}")
 
@@ -337,12 +336,13 @@ def propagate_to_all(
     input_points,
     inference_state,
 ):
-    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
-    predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
+    # torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
     if torch.cuda.get_device_properties(0).major >= 8:
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
 
+    predictor.to("cuda")
+
     if len(input_points) == 0 or video_in is None or inference_state is None:
         return None
     # run propagation throughout the video and collect the results in a dict
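
For context, a minimal sketch of the pattern this commit moves app.py toward: one module-level predictor built on CPU at import time, reused by the Gradio callbacks, and shifted to the GPU only when propagation runs. Only the lines shown in the diff above come from the actual file; the function signatures are abbreviated, and the import path and the video_path argument to init_state are assumptions based on the SAM 2 code base that EdgeTAM builds on.

import torch

# Import path assumed from the upstream SAM 2 package layout.
from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = "checkpoints/edgetam.pt"
model_cfg = "edgetam.yaml"

# Built once at module load on CPU, instead of once per callback.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cpu")


def preprocess_video_in(video_in):
    # Reuses the global predictor; the per-call build_sam2_video_predictor is gone.
    inference_state = predictor.init_state(
        offload_video_to_cpu=True,
        offload_state_to_cpu=True,
        video_path=video_in,  # assumed; the diff hunk cuts off before this argument
    )
    return inference_state


def propagate_to_all(video_in, input_points, inference_state):
    # bf16 autocast is removed; TF32 stays enabled on Ampere or newer GPUs.
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # The single global predictor is moved to the GPU only for propagation.
    predictor.to("cuda")

    if len(input_points) == 0 or video_in is None or inference_state is None:
        return None
    # run propagation throughout the video and collect the results in a dict
    ...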