import torch
from transformers import PretrainedConfig
from typing import List, Optional, Tuple, Union
from pathlib import Path


class VineConfig(PretrainedConfig):
"""
Configuration class for VINE (Video Understanding with Natural Language) model.
    VINE is a video understanding model that takes a video together with categorical
    keywords (object class names), unary keywords (actions of a single object), and
    binary keywords (relations between two objects), and returns probability
    distributions over each keyword set.

    Args:
        model_name (str): The CLIP model name to use as backbone. Default: "openai/clip-vit-base-patch32"
        hidden_dim (int): Hidden dimension size. Default: 768
        use_hf_repo (bool): Whether to load model weights from the Hugging Face Hub. Default: True
        model_repo (str, optional): Hugging Face Hub repo id for the weights (used when use_hf_repo=True). Default: "KevinX-Penn28/testing"
        model_file (str, optional): Weights filename inside the Hub repo. Default: None
        local_dir (str, optional): Local directory containing the weights (used when use_hf_repo=False). Default: the directory of this file
        local_filename (str, optional): Local weights filename. Default: "laser_model_v1.pkl"
        num_top_pairs (int): Number of top object pairs to consider. Default: 18
segmentation_method (str): Segmentation method to use ("sam2" or "grounding_dino_sam2"). Default: "grounding_dino_sam2"
box_threshold (float): Box threshold for Grounding DINO. Default: 0.35
text_threshold (float): Text threshold for Grounding DINO. Default: 0.25
target_fps (int): Target FPS for video processing. Default: 1
alpha (float): Alpha value for object extraction. Default: 0.5
white_alpha (float): White alpha value for background blending. Default: 0.8
topk_cate (int): Top-k categories to return. Default: 3
multi_class (bool): Whether to use multi-class classification. Default: False
output_logit (bool): Whether to output logits instead of probabilities. Default: False
max_video_length (int): Maximum number of frames to process. Default: 100
bbox_min_dim (int): Minimum bounding box dimension. Default: 5
visualize (bool): Whether to visualize results. Default: False
visualization_dir (str, optional): Directory to save visualizations. Default: None
debug_visualizations (bool): Whether to save debug visualizations. Default: False
return_flattened_segments (bool): Whether to return flattened segments. Default: False
return_valid_pairs (bool): Whether to return valid object pairs. Default: False
        interested_object_pairs (List[Tuple[int, int]], optional): Object index pairs of interest. Default: None
        device (str or int, optional): Compute device; an int is interpreted as a CUDA device index. Default: auto-detect ("cuda" if available, else "cpu")
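
    Example (illustrative; unset fields keep the class defaults):
        >>> config = VineConfig(segmentation_method="sam2", target_fps=2)
        >>> config.num_top_pairs
        18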
"""

    model_type = "vine"

    def __init__(
self,
model_name: str = "openai/clip-vit-base-patch32",
        hidden_dim: int = 768,
use_hf_repo: bool = True,
model_repo: Optional[str] = "KevinX-Penn28/testing",
model_file: Optional[str] = None,
local_dir: Optional[str] = str(Path(__file__).resolve().parent),
local_filename: Optional[str] = "laser_model_v1.pkl",
num_top_pairs: int = 18,
segmentation_method: str = "grounding_dino_sam2",
box_threshold: float = 0.35,
text_threshold: float = 0.25,
target_fps: int = 1,
alpha: float = 0.5,
white_alpha: float = 0.8,
topk_cate: int = 3,
multi_class: bool = False,
output_logit: bool = False,
max_video_length: int = 100,
bbox_min_dim: int = 5,
visualize: bool = False,
visualization_dir: Optional[str] = None,
return_flattened_segments: bool = False,
return_valid_pairs: bool = False,
interested_object_pairs: Optional[List[Tuple[int, int]]] = None,
debug_visualizations: bool = False,
        device: Optional[Union[str, int]] = None,
**kwargs
):
self.model_name = model_name
self.use_hf_repo = use_hf_repo
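        # Model weights come either from a Hugging Face Hub repo or from a local
        # file; null out the fields for whichever source is not in use.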
if use_hf_repo:
self.model_repo = model_repo
self.model_file = model_file
self.local_dir = None
self.local_filename = None
else:
self.model_repo = None
self.model_file = None
self.local_dir = local_dir
self.local_filename = local_filename
self.hidden_dim = hidden_dim
self.num_top_pairs = num_top_pairs
self.segmentation_method = segmentation_method
self.box_threshold = box_threshold
self.text_threshold = text_threshold
self.target_fps = target_fps
self.alpha = alpha
self.white_alpha = white_alpha
self.topk_cate = topk_cate
self.multi_class = multi_class
self.output_logit = output_logit
self.max_video_length = max_video_length
self.bbox_min_dim = bbox_min_dim
self.visualize = visualize
self.visualization_dir = visualization_dir
self.return_flattened_segments = return_flattened_segments
self.return_valid_pairs = return_valid_pairs
self.interested_object_pairs = interested_object_pairs or []
self.debug_visualizations = debug_visualizations
        # Resolve the target device: an int selects a CUDA device index, a string
        # is used as-is, and None falls back to auto-detection.
        if isinstance(device, int):
            self._device = f"cuda:{device}" if torch.cuda.is_available() else "cpu"
        else:
            self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
super().__init__(**kwargs)
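

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: construct a config
    # with a few overrides and round-trip it through PretrainedConfig's standard
    # save_pretrained / from_pretrained serialization. The /tmp path is illustrative.
    config = VineConfig(segmentation_method="sam2", target_fps=2, visualize=True)
    config.save_pretrained("/tmp/vine_config")
    reloaded = VineConfig.from_pretrained("/tmp/vine_config")
    assert reloaded.segmentation_method == "sam2"
    assert reloaded.target_fps == 2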