# Copyright (c) 2025. Your modifications here.
# This file wraps and extends sam2.utils.misc for custom modifications.

import os

import numpy as np
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm

from sam2.utils import misc as sam2_misc
from sam2.utils.misc import *  # noqa: F401,F403 -- re-export the wrapped module's API
from sam2.utils.misc import AsyncVideoFrameLoader, _load_img_as_tensor
from sam2.build_sam import _load_checkpoint


def _load_img_v2_as_tensor(img, image_size):
    """Convert one in-memory HxWx3 uint8 frame into a resized CHW float tensor,
    returning the tensor together with the frame's original height and width."""
    img_pil = Image.fromarray(img.astype(np.uint8))
    img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
    if img_np.dtype == np.uint8:  # np.uint8 is expected for JPEG images
        img_np = img_np / 255.0  # scale pixel values to [0, 1]
    else:
        raise RuntimeError(f"Unknown image dtype: {img_np.dtype}")
    img = torch.from_numpy(img_np).permute(2, 0, 1)  # HWC -> CHW
    video_width, video_height = img_pil.size  # the original video size
    return img, video_height, video_width
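
# A minimal usage sketch for `_load_img_v2_as_tensor` (illustrative only; the
# 480x640 frame shape below is an assumption, not a requirement):
#
#   frame = np.zeros((480, 640, 3), dtype=np.uint8)   # one HxWx3 video frame
#   tensor, h, w = _load_img_v2_as_tensor(frame, image_size=1024)
#   # tensor.shape == (3, 1024, 1024); (h, w) == (480, 640)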

def load_video_frames(
    video_path,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    frame_names=None,
):
    """
    Load the video frames from a directory of image files ("<frame_index>.jpg"
    format; JPEG and PNG extensions are accepted).

    The frames are resized to image_size x image_size and are loaded to GPU if
    `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.

    You can load frames asynchronously by setting `async_loading_frames` to `True`.
    """
    if isinstance(video_path, str) and os.path.isdir(video_path):
        jpg_folder = video_path
    else:
        raise NotImplementedError("Only a directory of image frames is supported at the moment")
    if frame_names is None:
        frame_names = [
            p
            for p in os.listdir(jpg_folder)
            if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG", ".png"]
        ]
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

    num_frames = len(frame_names)
    if num_frames == 0:
        raise RuntimeError(f"no images found in {jpg_folder}")
    img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

    if async_loading_frames:
        lazy_images = AsyncVideoFrameLoader(
            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
        )
        return lazy_images, lazy_images.video_height, lazy_images.video_width

    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
    for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
    if not offload_video_to_cpu:
        images = images.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()
    # normalize by mean and std
    images -= img_mean
    images /= img_std
    return images, video_height, video_width
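
# A minimal usage sketch for `load_video_frames` (the directory path is an
# assumption for illustration; the folder should contain 0.jpg, 1.jpg, ...):
#
#   images, video_height, video_width = load_video_frames(
#       video_path="./my_video_frames",
#       image_size=1024,
#       offload_video_to_cpu=True,   # keep the frame tensor on CPU
#   )
#   # images.shape == (num_frames, 3, 1024, 1024), normalized by the ImageNet
#   # mean/std defaults above.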


def load_video_frames_v2(
    frames,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    async_loading_frames=False,
    frame_names=None,
):
    """
    Load video frames from an in-memory sequence of HxWx3 uint8 arrays
    (e.g. frames already decoded from a video), instead of a directory of
    JPEG files.

    The frames are resized to image_size x image_size and are loaded to GPU if
    `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.

    `async_loading_frames` and `frame_names` are accepted for signature
    compatibility with `load_video_frames` but are not used here.
    """
    num_frames = len(frames)
    if num_frames == 0:
        raise RuntimeError("no frames were provided")
    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
    for n, frame in enumerate(tqdm(frames, desc="video frame")):
        images[n], video_height, video_width = _load_img_v2_as_tensor(frame, image_size)
    if not offload_video_to_cpu:
        images = images.cuda()
        img_mean = img_mean.cuda()
        img_std = img_std.cuda()
    # normalize by mean and std
    images -= img_mean
    images /= img_std
    return images, video_height, video_width
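
# A minimal usage sketch for `load_video_frames_v2`, assuming frames decoded
# with OpenCV (the video path is an assumption; note that OpenCV yields BGR
# frames, so they are converted to RGB first):
#
#   import cv2
#   cap = cv2.VideoCapture("input.mp4")
#   frames = []
#   while True:
#       ok, frame_bgr = cap.read()
#       if not ok:
#           break
#       frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
#   cap.release()
#   images, h, w = load_video_frames_v2(frames, image_size=1024, offload_video_to_cpu=True)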

def build_sam2_video_predictor(
    config_file,
    ckpt_path=None,
    device="cuda",
    mode="eval",
    hydra_overrides_extra=[],
    apply_postprocessing=True,
):
    hydra_overrides = [
        "++model._target_=video_predictor.SAM2VideoPredictor",
    ]
    if apply_postprocessing:
        hydra_overrides_extra = hydra_overrides_extra.copy()
        hydra_overrides_extra += [
            # dynamically fall back to multi-mask if the single mask is not stable
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
            "++model.binarize_mask_from_pts_for_mem_enc=true",
            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
            "++model.fill_hole_area=8",
        ]

    hydra_overrides.extend(hydra_overrides_extra)
    # Read config and init model
    cfg = compose(config_name=config_file, overrides=hydra_overrides)
    OmegaConf.resolve(cfg)
    model = instantiate(cfg.model, _recursive_=True)
    _load_checkpoint(model, ckpt_path)
    model = model.to(device)
    if mode == "eval":
        model.eval()
    return model
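

# A minimal usage sketch for `build_sam2_video_predictor` (the config name and
# checkpoint path are assumptions; Hydra must already be initialized with the
# sam2 config search path, as happens when the `sam2` package is imported, and
# the custom `video_predictor.SAM2VideoPredictor` is assumed to keep the
# upstream `init_state` API):
#
#   predictor = build_sam2_video_predictor(
#       config_file="sam2_hiera_l.yaml",
#       ckpt_path="./checkpoints/sam2_hiera_large.pt",
#       device="cuda",
#   )
#   with torch.inference_mode():
#       state = predictor.init_state(video_path="./my_video_frames")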