import torch
from transformers import PretrainedConfig
from typing import Any, List, Optional, Tuple, Union
from pathlib import Path


class VineConfig(PretrainedConfig):
    """
    Configuration class for the VINE (Video Understanding with Natural Language) model.

    Stores the CLIP backbone choice, the source of the pretrained weights (a Hugging
    Face Hub repo or a local file), segmentation and detection thresholds, video
    sampling settings, and output/visualization options.
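
    Example (illustrative sketch; the argument values below are assumptions, not
    required defaults):

        >>> config = VineConfig(
        ...     use_hf_repo=False,
        ...     local_filename="laser_model_v1.pkl",
        ...     device="cpu",
        ... )
        >>> config._device
        'cpu'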
    """

    model_type = "vine"

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        hidden_dim: int = 768,
        # When True, load weights from the Hugging Face Hub repo below; otherwise
        # read them from the local directory/filename.
        use_hf_repo: bool = True,
        model_repo: Optional[str] = "KevinX-Penn28/testing",
        model_file: Optional[str] = None,
        local_dir: Optional[str] = str(Path(__file__).resolve().parent),
        local_filename: Optional[str] = "laser_model_v1.pkl",
        num_top_pairs: int = 18,
        segmentation_method: str = "grounding_dino_sam2",
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
        target_fps: int = 1,
        alpha: float = 0.5,
        white_alpha: float = 0.8,
        topk_cate: int = 3,
        multi_class: bool = False,
        output_logit: bool = False,
        use_pretrained_cate_weights: bool = False,
        categorical_pool: str = "mean",  # "mean" or "max"
        max_video_length: int = 100,
        bbox_min_dim: int = 1,
        visualize: bool = False,
        visualization_dir: Optional[str] = None,
        return_flattened_segments: bool = False,
        return_valid_pairs: bool = False,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: bool = False,
        # "cuda"/"cpu", a CUDA device index, or None to auto-select by availability.
        device: Optional[Union[str, int]] = None,
        **kwargs: Any,
    ):
        self.model_name = model_name
        self.use_hf_repo = use_hf_repo
        # Keep only the fields for the selected weight source; the other source's
        # fields are set to None.
        if use_hf_repo:
            self.model_repo = model_repo
            self.model_file = model_file
            self.local_dir = None
            self.local_filename = None
        else:
            self.model_repo = None
            self.model_file = None
            self.local_dir = local_dir
            self.local_filename = local_filename

        self.hidden_dim = hidden_dim
        self.num_top_pairs = num_top_pairs
        self.segmentation_method = segmentation_method
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.target_fps = target_fps
        self.alpha = alpha
        self.white_alpha = white_alpha
        self.topk_cate = topk_cate
        self.multi_class = multi_class
        self.output_logit = output_logit
        self.use_pretrained_cate_weights = use_pretrained_cate_weights
        self.categorical_pool = categorical_pool
        self.max_video_length = max_video_length
        self.bbox_min_dim = bbox_min_dim
        self.visualize = visualize
        self.visualization_dir = visualization_dir
        self.return_flattened_segments = return_flattened_segments
        self.return_valid_pairs = return_valid_pairs
        self.interested_object_pairs = interested_object_pairs or []
        self.debug_visualizations = debug_visualizations

        # Resolve the device: an integer is treated as a CUDA index, a string is
        # used as-is, and None falls back to CUDA when available, otherwise CPU.
        if isinstance(device, int):
            self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
        else:
            self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        super().__init__(**kwargs)
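

if __name__ == "__main__":
    # Illustrative sketch only (assumed workflow, not part of the module's API):
    # round-trip the config through the standard PretrainedConfig save/load helpers.
    import tempfile

    config = VineConfig(num_top_pairs=10, visualize=True, device="cpu")
    with tempfile.TemporaryDirectory() as tmp_dir:
        config.save_pretrained(tmp_dir)  # writes config.json
        restored = VineConfig.from_pretrained(tmp_dir)
    print(restored.num_top_pairs, restored.visualize)  # -> 10 True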