yonigozlan HF Staff commited on
Commit
fe9d46b
·
1 Parent(s): fee29b3

set device to cpu and remove empty cuda cache

Browse files
Files changed (1) hide show
  1. app.py +5 -13
app.py CHANGED
@@ -130,7 +130,7 @@ class AppState:
130
  self.inference_session = None
131
  self.model: Optional[AutoModel] = None
132
  self.processor: Optional[Sam2VideoProcessor] = None
133
- self.device: str = "cuda"
134
  self.dtype: torch.dtype = torch.bfloat16
135
  self.video_fps: float | None = None
136
  self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
@@ -154,6 +154,9 @@ class AppState:
154
  self.model_repo_id: str | None = None
155
  self.session_repo_id: str | None = None
156
 
 
 
 
157
  @property
158
  def num_frames(self) -> int:
159
  return len(self.video_frames)
@@ -189,8 +192,6 @@ def load_model_if_needed(GLOBAL_STATE: gr.State) -> tuple[AutoModel, Sam2VideoPr
189
  print(f"Loading model from {desired_repo}")
190
  device, dtype = get_device_and_dtype()
191
  # free up the gpu memory
192
- torch.cuda.empty_cache()
193
- gc.collect()
194
  model = AutoModel.from_pretrained(desired_repo)
195
  processor = Sam2VideoProcessor.from_pretrained(desired_repo)
196
  model.to(device, dtype=dtype)
@@ -225,11 +226,6 @@ def ensure_session_for_current_model(GLOBAL_STATE: gr.State) -> None:
225
  pass
226
  GLOBAL_STATE.inference_session = None
227
  gc.collect()
228
- try:
229
- if torch.cuda.is_available():
230
- torch.cuda.empty_cache()
231
- except Exception:
232
- pass
233
  GLOBAL_STATE.inference_session = processor.init_video_session(
234
  video=GLOBAL_STATE.video_frames,
235
  inference_device=device,
@@ -566,6 +562,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
566
  # f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
567
  # gr.update(value=last_frame_idx),
568
  # )
 
569
  return (
570
  GLOBAL_STATE,
571
  f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
@@ -596,11 +593,6 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, i
596
  pass
597
  GLOBAL_STATE.inference_session = None
598
  gc.collect()
599
- try:
600
- if torch.cuda.is_available():
601
- torch.cuda.empty_cache()
602
- except Exception:
603
- pass
604
  ensure_session_for_current_model(GLOBAL_STATE)
605
 
606
  # Keep current slider index if possible
 
130
  self.inference_session = None
131
  self.model: Optional[AutoModel] = None
132
  self.processor: Optional[Sam2VideoProcessor] = None
133
+ self.device: str = "cpu"
134
  self.dtype: torch.dtype = torch.bfloat16
135
  self.video_fps: float | None = None
136
  self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
 
154
  self.model_repo_id: str | None = None
155
  self.session_repo_id: str | None = None
156
 
157
+ def __repr__(self):
158
+ return f"AppState(video_frames={self.video_frames}, inference_session={self.inference_session is not None}, model={self.model is not None}, processor={self.processor is not None}, device={self.device}, dtype={self.dtype}, video_fps={self.video_fps}, masks_by_frame={self.masks_by_frame}, color_by_obj={self.color_by_obj}, clicks_by_frame_obj={self.clicks_by_frame_obj}, boxes_by_frame_obj={self.boxes_by_frame_obj}, composited_frames={self.composited_frames}, current_frame_idx={self.current_frame_idx}, current_obj_id={self.current_obj_id}, current_label={self.current_label}, current_clear_old={self.current_clear_old}, current_prompt_type={self.current_prompt_type}, pending_box_start={self.pending_box_start}, pending_box_start_frame_idx={self.pending_box_start_frame_idx}, pending_box_start_obj_id={self.pending_box_start_obj_id}, is_switching_model={self.is_switching_model}, model_repo_key={self.model_repo_key}, model_repo_id={self.model_repo_id}, session_repo_id={self.session_repo_id})"
159
+
160
  @property
161
  def num_frames(self) -> int:
162
  return len(self.video_frames)
 
192
  print(f"Loading model from {desired_repo}")
193
  device, dtype = get_device_and_dtype()
194
  # free up the gpu memory
 
 
195
  model = AutoModel.from_pretrained(desired_repo)
196
  processor = Sam2VideoProcessor.from_pretrained(desired_repo)
197
  model.to(device, dtype=dtype)
 
226
  pass
227
  GLOBAL_STATE.inference_session = None
228
  gc.collect()
 
 
 
 
 
229
  GLOBAL_STATE.inference_session = processor.init_video_session(
230
  video=GLOBAL_STATE.video_frames,
231
  inference_device=device,
 
562
  # f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
563
  # gr.update(value=last_frame_idx),
564
  # )
565
+ print("global state", GLOBAL_STATE)
566
  return (
567
  GLOBAL_STATE,
568
  f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
 
593
  pass
594
  GLOBAL_STATE.inference_session = None
595
  gc.collect()
 
 
 
 
 
596
  ensure_session_for_current_model(GLOBAL_STATE)
597
 
598
  # Keep current slider index if possible