set device to cpu and remove empty cuda cache
Commit fe9d46b · Parent(s): fee29b3

app.py CHANGED
@@ -130,7 +130,7 @@ class AppState:
         self.inference_session = None
         self.model: Optional[AutoModel] = None
         self.processor: Optional[Sam2VideoProcessor] = None
-        self.device: str = "
+        self.device: str = "cpu"
         self.dtype: torch.dtype = torch.bfloat16
         self.video_fps: float | None = None
         self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
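Aside: the new defaults pair a plain "cpu" device string with torch.bfloat16. Recent PyTorch builds do support bfloat16 tensors and matmuls on CPU, so the pairing is valid; a quick sanity check (not part of the app):

import torch

# bfloat16 works for CPU inference in recent PyTorch builds; this just
# confirms the new AppState defaults are a workable combination.
x = torch.randn(2, 2, device="cpu", dtype=torch.bfloat16)
print((x @ x).dtype)  # torch.bfloat16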
@@ -154,6 +154,9 @@ class AppState:
         self.model_repo_id: str | None = None
         self.session_repo_id: str | None = None
 
+    def __repr__(self):
+        return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
+
     @property
     def num_frames(self) -> int:
         return len(self.video_frames)
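Note that this __repr__ interpolates video_frames and masks_by_frame directly, so printing the state can dump entire frame and mask arrays into the log. A more compact variant might summarize the bulky fields instead; a sketch (hypothetical, not part of this commit), assuming the same attribute names:

    def __repr__(self) -> str:
        # Summarize bulky collections rather than embedding them verbatim.
        return (
            f"AppState(num_frames={len(self.video_frames)}, "
            f"has_session={self.inference_session is not None}, "
            f"device={self.device}, dtype={self.dtype}, "
            f"frames_with_masks={len(self.masks_by_frame)})"
        )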
@@ -189,8 +192,6 @@ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoPr
     print(f"Loading model from {desired_repo}")
     device, dtype = get_device_and_dtype()
     # free up the gpu memory
-    torch.cuda.empty_cache()
-    gc.collect()
     model = AutoModel.from_pretrained(desired_repo)
     processor = Sam2VideoProcessor.from_pretrained(desired_repo)
     model.to(device, dtype=dtype)
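get_device_and_dtype() itself is not shown in this diff. Consistent with the commit's CPU default and the bfloat16 dtype above, it plausibly looks something like the following (a hypothetical reconstruction; only the name and return shape are taken from the call site):

import torch

def get_device_and_dtype() -> tuple[str, torch.dtype]:
    # Hypothetical sketch: prefer CUDA when a GPU is present,
    # otherwise fall back to the CPU default this commit sets.
    if torch.cuda.is_available():
        return "cuda", torch.bfloat16
    return "cpu", torch.bfloat16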
@@ -225,11 +226,6 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
         pass
     GLOBAL_STATE.inference_session = None
     gc.collect()
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-    except Exception:
-        pass
     GLOBAL_STATE.inference_session = processor.init_video_session(
         video=GLOBAL_STATE.video_frames,
         inference_device=device,
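Dropping the guarded block should be safe even beyond the CPU-only case: as far as I know, torch.cuda.empty_cache() returns early when no CUDA context has ever been initialized, so the try/except wrapper was belt-and-braces.

import torch

# With no initialized CUDA context this is a no-op rather than an error,
# which is why the removed guard was redundant on the CPU path.
torch.cuda.empty_cache()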
@@ -566,6 +562,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
     # f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
     # gr.update(value=last_frame_idx),
     # )
+    print("global state", GLOBAL_STATE)
     return (
         GLOBAL_STATE,
         f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
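The added print builds the (now very long) AppState repr on every propagate call. If this is temporary debugging, a lazy logging call only formats the state when debug output is actually enabled; a sketch (hypothetical, not in the commit; GLOBAL_STATE is the app's state object from the diff):

import logging

logger = logging.getLogger(__name__)

# The %s argument is only formatted when a handler emits the record,
# so the large AppState repr is never built at INFO level and above.
logger.debug("global state %s", GLOBAL_STATE)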
@@ -596,11 +593,6 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
         pass
     GLOBAL_STATE.inference_session = None
     gc.collect()
-    try:
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-    except Exception:
-        pass
     ensure_session_for_current_model(GLOBAL_STATE)
 
     # Keep current slider index if possible
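Taken together, the post-commit flow loads the model and rebuilds the session with no CUDA cache housekeeping at all. A condensed sketch of that flow, reusing names from the diff (the transformers import path and the repo id are assumptions; the diff does not show the app's imports):

import gc

from transformers import AutoModel, Sam2VideoProcessor

repo_id = "facebook/sam2.1-hiera-tiny"  # placeholder checkpoint id

device, dtype = get_device_and_dtype()  # app.py helper, sketched above
model = AutoModel.from_pretrained(repo_id)
processor = Sam2VideoProcessor.from_pretrained(repo_id)
model.to(device, dtype=dtype)

# Resetting now just drops the old session and collects garbage;
# no torch.cuda.empty_cache() calls remain on this path.
gc.collect()
session = processor.init_video_session(
    video=video_frames,  # frames loaded elsewhere in app.py
    inference_device=device,
)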
|