# Copyright (c) 2025. Your modifications here.
# A wrapper for sam2 functions
from collections import OrderedDict

import torch
from tqdm import tqdm

from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
from sam2.sam2_video_predictor import SAM2VideoPredictor as _SAM2VideoPredictor
from sam2.utils.misc import concat_points, fill_holes_in_mask_scores

from sam_utils import load_video_frames_v2, load_video_frames


class SAM2VideoPredictor(_SAM2VideoPredictor):
    """SAM2 video predictor with custom frame-loading entry points.

    ``init_state`` loads frames from ``video_path`` with
    ``sam_utils.load_video_frames``; ``init_state_v2`` builds them from an
    already-available ``frames`` object with ``sam_utils.load_video_frames_v2``.
    Both return an inference-state dict with the same layout.
    """

    @torch.inference_mode()
    def init_state(
        self,
        video_path,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
        frame_names=None,
    ):
        """Initialize an inference state from the video at ``video_path``.

        Args:
            video_path: Source passed to ``sam_utils.load_video_frames``.
            offload_video_to_cpu: Keep the loaded video frames in CPU memory.
            offload_state_to_cpu: Store the per-frame tracking state in CPU memory.
            async_loading_frames: Forwarded to the frame loader.
            frame_names: Forwarded to the frame loader.

        Returns:
            The inference-state dict consumed by the base predictor's methods.
        """
        images, video_height, video_width = load_video_frames(
            video_path=video_path,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            frame_names=frame_names,
        )
        return self._build_inference_state(
            images,
            video_height,
            video_width,
            offload_video_to_cpu,
            offload_state_to_cpu,
        )

    @torch.inference_mode()
    def init_state_v2(
        self,
        frames,
        offload_video_to_cpu=False,
        offload_state_to_cpu=False,
        async_loading_frames=False,
        frame_names=None,
    ):
        """Initialize an inference state from in-memory ``frames``.

        Identical to ``init_state`` except that frames are obtained through
        ``sam_utils.load_video_frames_v2(frames=...)`` instead of a path.

        Returns:
            The inference-state dict consumed by the base predictor's methods.
        """
        images, video_height, video_width = load_video_frames_v2(
            frames=frames,
            image_size=self.image_size,
            offload_video_to_cpu=offload_video_to_cpu,
            async_loading_frames=async_loading_frames,
            frame_names=frame_names,
        )
        return self._build_inference_state(
            images,
            video_height,
            video_width,
            offload_video_to_cpu,
            offload_state_to_cpu,
        )

    def _build_inference_state(
        self,
        images,
        video_height,
        video_width,
        offload_video_to_cpu,
        offload_state_to_cpu,
    ):
        """Assemble a fresh inference state around already-loaded frames.

        Shared by ``init_state`` and ``init_state_v2`` so both produce exactly
        the same state layout. Also warms up the visual backbone on frame 0.
        """
        # Use the device the model actually lives on; a bare "cuda" device
        # resolves to the *current* CUDA device, which is wrong when the model
        # was placed on another GPU index (or not on CUDA at all).
        compute_device = next(self.parameters()).device

        inference_state = {}
        inference_state["images"] = images
        inference_state["num_frames"] = len(images)
        # whether to offload the video frames to CPU memory
        # turning on this option saves the GPU memory with only a very small overhead
        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
        # whether to offload the inference state to CPU memory
        # turning on this option saves the GPU memory at the cost of a lower tracking fps
        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
        # and from 24 to 21 when tracking two objects)
        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
        # the original video height and width, used for resizing final output scores
        inference_state["video_height"] = video_height
        inference_state["video_width"] = video_width
        inference_state["device"] = compute_device
        if offload_state_to_cpu:
            inference_state["storage_device"] = torch.device("cpu")
        else:
            inference_state["storage_device"] = compute_device
        # inputs on each frame
        inference_state["point_inputs_per_obj"] = {}
        inference_state["mask_inputs_per_obj"] = {}
        # visual features on a small number of recently visited frames for quick interactions
        inference_state["cached_features"] = {}
        # values that don't change across frames (so we only need to hold one copy of them)
        inference_state["constants"] = {}
        # mapping between client-side object id and model-side object index
        inference_state["obj_id_to_idx"] = OrderedDict()
        inference_state["obj_idx_to_id"] = OrderedDict()
        inference_state["obj_ids"] = []
        # A storage to hold the model's tracking results and states on each frame
        inference_state["output_dict"] = {
            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
        }
        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
        inference_state["output_dict_per_obj"] = {}
        # A temporary storage to hold new outputs when user interact with a frame
        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
        inference_state["temp_output_dict_per_obj"] = {}
        # Frames that already holds consolidated outputs from click or mask inputs
        # (we directly use their consolidated outputs during tracking)
        inference_state["consolidated_frame_inds"] = {
            "cond_frame_outputs": set(),  # set containing frame indices
            "non_cond_frame_outputs": set(),  # set containing frame indices
        }
        # metadata for each tracking frame (e.g. which direction it's tracked)
        inference_state["tracking_has_started"] = False
        inference_state["frames_already_tracked"] = {}
        # Warm up the visual backbone and cache the image feature on frame 0
        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
        return inference_state