# vine_hf/vine_config.py
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

import torch
from transformers import PretrainedConfig


class VineConfig(PretrainedConfig):
    """
    Configuration class for the VINE (Video Understanding with Natural Language) model.

    Covers the CLIP backbone, the source of pretrained weights (a Hugging Face
    Hub repo or a local file), segmentation and frame-sampling settings,
    prediction options, and visualization/output flags.
    """

    model_type = "vine"

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        hidden_dim: int = 768,
        use_hf_repo: bool = True,
        model_repo: Optional[str] = "KevinX-Penn28/testing",
        model_file: Optional[str] = None,
        local_dir: Optional[str] = str(Path(__file__).resolve().parent),
        local_filename: Optional[str] = "laser_model_v1.pkl",
        num_top_pairs: int = 18,
        segmentation_method: str = "grounding_dino_sam2",
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
        target_fps: int = 1,
        alpha: float = 0.5,
        white_alpha: float = 0.8,
        topk_cate: int = 3,
        multi_class: bool = False,
        output_logit: bool = False,
        use_pretrained_cate_weights: bool = False,
        categorical_pool: str = "mean",  # "mean" or "max"
        max_video_length: int = 100,
        bbox_min_dim: int = 1,
        visualize: bool = False,
        visualization_dir: Optional[str] = None,
        return_flattened_segments: bool = False,
        return_valid_pairs: bool = False,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: bool = False,
        device: Optional[Union[str, int]] = None,
        **kwargs: Any,
    ):
        self.model_name = model_name
        self.use_hf_repo = use_hf_repo
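
        # Keep exactly one weight source: either a Hub repo or a local file.
        # The unused pair is cleared so serialized configs stay unambiguous.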
        if use_hf_repo:
            self.model_repo = model_repo
            self.model_file = model_file
            self.local_dir = None
            self.local_filename = None
        else:
            self.model_repo = None
            self.model_file = None
            self.local_dir = local_dir
            self.local_filename = local_filename
        self.hidden_dim = hidden_dim
        self.num_top_pairs = num_top_pairs
        self.segmentation_method = segmentation_method
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.target_fps = target_fps
        self.alpha = alpha
        self.white_alpha = white_alpha
        self.topk_cate = topk_cate
        self.multi_class = multi_class
        self.output_logit = output_logit
        self.use_pretrained_cate_weights = use_pretrained_cate_weights
        self.categorical_pool = categorical_pool
        self.max_video_length = max_video_length
        self.bbox_min_dim = bbox_min_dim
        self.visualize = visualize
        self.visualization_dir = visualization_dir
        self.return_flattened_segments = return_flattened_segments
        self.return_valid_pairs = return_valid_pairs
        self.interested_object_pairs = interested_object_pairs or []
        self.debug_visualizations = debug_visualizations
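
        # Resolve the execution device: an int is treated as a CUDA ordinal
        # (falling back to CPU when CUDA is unavailable); otherwise use the
        # given string or pick "cuda"/"cpu" automatically.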
        if isinstance(device, int):
            self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
        else:
            self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        super().__init__(**kwargs)
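

# Minimal usage sketch (an illustration, not part of the model API): round-trip
# the config through the standard `transformers` save/load machinery. The
# directory name below is arbitrary.
if __name__ == "__main__":
    config = VineConfig(num_top_pairs=10, device=0)
    config.save_pretrained("vine_config_demo")  # writes vine_config_demo/config.json
    reloaded = VineConfig.from_pretrained("vine_config_demo")
    assert reloaded.num_top_pairs == 10
    print(reloaded.model_type, reloaded._device)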