yonigozlan (HF Staff) committed
Commit: af1cc2b · Parent: e684124

use deepcopy on inference session instead

Files changed (1): app.py (+4, -8)
app.py CHANGED
@@ -1,5 +1,6 @@
 import colorsys
 import gc
+from copy import deepcopy
 from typing import Optional

 import gradio as gr
@@ -513,7 +514,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):

     processor = GLOBAL_STATE.processor
     model = GLOBAL_STATE.model
-    inference_session = GLOBAL_STATE.inference_session
+    inference_session = deepcopy(GLOBAL_STATE.inference_session)
     # set inference device to cuda to use zero gpu
     inference_session.inference_device = "cuda"
     inference_session.cache.inference_device = "cuda"
@@ -556,8 +557,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
     inference_session.cache.inference_device = "cpu"
     gc.collect()
     torch.cuda.empty_cache()
-    test = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."
-    inference_session.reset_inference_session()
+    text = f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects."

     # Final status; ensure slider points to last processed frame
     # yield (
@@ -565,11 +565,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
     # gr.update(value=last_frame_idx),
     # )
     print("global state", GLOBAL_STATE)
-    return (
-        GLOBAL_STATE,
-        f"Propagated masks across {processed} frames for {len(inference_session.obj_ids)} objects.",
-        gr.update(value=last_frame_idx),
-    )
+    return GLOBAL_STATE, text, gr.update(value=last_frame_idx)


 def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
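
Why the change: propagate_masks previously worked on GLOBAL_STATE.inference_session directly and reset it afterwards with reset_inference_session(); it now runs on a deepcopy, so moving the session to CUDA (to use ZeroGPU) and propagating masks never mutates the session stored in the Gradio state. Below is a minimal sketch of this copy-on-entry pattern; SessionLike and propagate are hypothetical stand-ins for illustration, not the app's real session class or handler.

from copy import deepcopy

class SessionLike:
    """Hypothetical stand-in for the app's inference session (illustration only)."""
    def __init__(self):
        self.inference_device = "cpu"  # device the session is pinned to
        self.frames_processed = 0      # state that propagation mutates

def propagate(shared: SessionLike) -> tuple[SessionLike, str]:
    # Work on a private copy so the shared, CPU-resident session kept in
    # the Gradio state is untouched by the GPU-bound call.
    session = deepcopy(shared)
    session.inference_device = "cuda"
    session.frames_processed += 1      # mutations stay local to the copy
    return session, f"Propagated masks across {session.frames_processed} frames."

shared = SessionLike()
_, status = propagate(shared)
# The shared session is unchanged and can be reused on the next call.
assert shared.inference_device == "cpu" and shared.frames_processed == 0

The copy costs some memory, but it keeps the stored session reusable across calls, which is presumably why the explicit reset_inference_session() call could be dropped.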