import torch
from transformers import PretrainedConfig
from typing import List, Optional, Any, Tuple, Union
from pathlib import Path


class VineConfig(PretrainedConfig):
    """
    Configuration class for VINE (Video Understanding with Natural Language) model.
    """

    model_type = "vine"

    def __init__(
        self,
        model_name: str = "openai/clip-vit-base-patch32",
        hidden_dim: int = 768,
        use_hf_repo: bool = True,
        model_repo: Optional[str] = "KevinX-Penn28/testing",
        model_file: Optional[str] = None,
        local_dir: Optional[str] = str(Path(__file__).resolve().parent),
        local_filename: Optional[str] = "laser_model_v1.pkl",
        num_top_pairs: int = 18,
        segmentation_method: str = "grounding_dino_sam2",
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
        target_fps: int = 1,
        alpha: float = 0.5,
        white_alpha: float = 0.8,
        topk_cate: int = 3,
        multi_class: bool = False,
        output_logit: bool = False,
        use_pretrained_cate_weights: bool = False,
        categorical_pool: str = "mean",  # "mean" or "max"
        max_video_length: int = 100,
        bbox_min_dim: int = 1,
        visualize: bool = False,
        visualization_dir: Optional[str] = None,
        return_flattened_segments: bool = False,
        return_valid_pairs: bool = False,
        interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
        debug_visualizations: bool = False,
        device: Optional[Union[str, int]] = None,
        **kwargs: Any,
    ):
        self.model_name = model_name
        self.use_hf_repo = use_hf_repo

        # Model weights come either from a Hugging Face repo or from a local
        # file; the unused pair of attributes is cleared so that exactly one
        # source is configured.
        if use_hf_repo:
            self.model_repo = model_repo
            self.model_file = model_file
            self.local_dir = None
            self.local_filename = None
        else:
            self.model_repo = None
            self.model_file = None
            self.local_dir = local_dir
            self.local_filename = local_filename

        self.hidden_dim = hidden_dim
        self.num_top_pairs = num_top_pairs
        self.segmentation_method = segmentation_method
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.target_fps = target_fps
        self.alpha = alpha
        self.white_alpha = white_alpha
        self.topk_cate = topk_cate
        self.multi_class = multi_class
        self.output_logit = output_logit
        self.use_pretrained_cate_weights = use_pretrained_cate_weights
        self.categorical_pool = categorical_pool
        self.max_video_length = max_video_length
        self.bbox_min_dim = bbox_min_dim
        self.visualize = visualize
        self.visualization_dir = visualization_dir
        self.return_flattened_segments = return_flattened_segments
        self.return_valid_pairs = return_valid_pairs
        self.interested_object_pairs = interested_object_pairs or []
        self.debug_visualizations = debug_visualizations

        # Resolve the compute device: an int selects a CUDA ordinal, a string
        # is used verbatim, and None falls back to CUDA when available.
        if isinstance(device, int):
            self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
        else:
            self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        super().__init__(**kwargs)
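

# --- Usage sketch (illustrative only, not part of the module's public API) ---
# A minimal example of constructing VineConfig for both weight sources. The
# local directory below is a hypothetical path; all keyword names are the
# parameters defined above.
if __name__ == "__main__":
    # Default: weights resolved from the configured Hugging Face repo.
    hf_cfg = VineConfig()
    print(hf_cfg.model_repo, hf_cfg._device)

    # Alternative: weights loaded from a local pickle; __init__ clears the
    # HF repo fields so only one source remains set.
    local_cfg = VineConfig(
        use_hf_repo=False,
        local_dir="/tmp/vine_weights",  # hypothetical directory
        local_filename="laser_model_v1.pkl",
    )
    assert local_cfg.model_repo is None
    assert local_cfg.local_dir == "/tmp/vine_weights"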