import os
import copy
import math
import warnings
import shutil
from functools import partial

import torch
import numpy as np

from .model import load_pretrained_model
from .mm_utils import (
    load_images, process_images, load_video, process_video,
    tokenizer_multimodal_token, get_model_name_from_path,
    KeywordsStoppingCriteria, DirectResize, sam_preprocess_batch,
)
from .constants import (
    NUM_FRAMES, DEFAULT_IMAGE_TOKEN, DEFAULT_VIDEO_TOKEN,
    MODAL_INDEX_MAP, STREAM_START_TOKEN, STREAM_END_TOKEN,
)
from .model.rynnec_qwen2 import Videollama3Qwen2Processor


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def model_init(model_path=None, min_visual_tokens=None, max_visual_tokens=None, **kwargs):
    model_path = "Alibaba-DAMO-Academy/RynnEC-2B" if model_path is None else model_path
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, **kwargs)

    if max_visual_tokens is not None:
        image_processor.max_tokens = max_visual_tokens
    if min_visual_tokens is not None:
        image_processor.min_tokens = min_visual_tokens

    if tokenizer.pad_token is None and tokenizer.unk_token is not None:
        tokenizer.pad_token = tokenizer.unk_token

    processor = Videollama3Qwen2Processor(image_processor, tokenizer)

    return model, processor


def mm_infer(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', **kwargs):
    mask_ids = kwargs.pop('mask_ids', None)
    masks = kwargs.pop('masks', None)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
    else:
        raise ValueError(f"Unsupported modal: {modal}")

    # 1. Text preprocess: handle modal tags and build the prompt messages.
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Modal token not found in the conversation; adding it automatically at the beginning.")
        messages[0]["content"] = modal_token + messages[0]["content"]

    # Split each message on the modal token and interleave text chunks with
    # image/video placeholders in the converted message list.
    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            "role": "user",
            "content": []
        })

        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages

    system_message = []

    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # TODO: attention mask?
    messages = system_message + messages

    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )

    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    # images = [x.to(torch_dtype).cuda(non_blocking=True) for x in data_dict["images"]]
    # grid_thws = [x.cuda(non_blocking=True) for x in data_dict["grid_thws"]]

    # 3. Generate the response conditioned on visual signals and prompts.
    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    # Move tensors to GPU and cast vision inputs to bf16.
    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
    if "pixel_values" in data_dict:
        data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
        data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
            pixel_values=data_dict["pixel_values"],
            grid_sizes=data_dict["grid_sizes"],
            merge_sizes=data_dict["merge_sizes"],
            modals=data_dict["modals"],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            masks=[masks],
            mask_ids=mask_ids,
        )

    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return outputs


def mm_infer_segmentation(images_or_videos, vlprocessor, instruct, model, tokenizer, modal='video', seg_start_idx=0, mask_threshold=0.5, **kwargs):
    image2maskids = kwargs.get('image2maskids', [])
    img_size = 1024
    sam_transform = DirectResize(img_size)

    if modal == 'image':
        modal_token = DEFAULT_IMAGE_TOKEN
        images = images_or_videos
        timestamps = None
    elif modal == 'video':
        modal_token = DEFAULT_VIDEO_TOKEN
        images, timestamps = images_or_videos
    elif modal == 'text':
        modal_token = ''
    else:
        raise ValueError(f"Unsupported modal: {modal}")

    # Resize frames for the SAM branch and stack them into a preprocessed batch.
    sam_images = []
    sam_size = None
    if len(images) > 0:
        for image in images:
            sam_image = sam_transform.apply_image(np.array(image))
            sam_images.append(sam_image)
            if sam_size is None:
                sam_size = sam_image.shape[:2]
        sam_images = np.array(sam_images)
        sam_images = torch.from_numpy(sam_images).permute(0, 3, 1, 2).contiguous()
        sam_images = sam_preprocess_batch(sam_images)

    # 1. Text preprocess: handle modal tags and build the prompt messages.
    if isinstance(instruct, str):
        messages = [{'role': 'user', 'content': instruct}]
    elif isinstance(instruct, list):
        messages = copy.deepcopy(instruct)
    else:
        raise ValueError(f"Unsupported type of instruct: {type(instruct)}")

    if all(modal_token not in message["content"] for message in messages):
        warnings.warn("Modal token not found in the conversation; adding it automatically at the beginning.")
        messages[0]["content"] = modal_token + messages[0]["content"]

    # Split each message on the modal token and interleave text chunks with
    # image/video placeholders in the converted message list.
    converted_messages = []
    for message in messages:
        chunks = message["content"].split(modal_token)
        converted_messages.append({
            "role": "user",
            "content": []
        })

        for chunk_idx in range(1, 2 * len(chunks)):
            if chunk_idx % 2 == 1:
                chunk = chunks[chunk_idx // 2].strip()
                if chunk:
                    converted_messages[-1]["content"].append({"type": "text", "text": chunk})
            else:
                if modal == 'image':
                    converted_messages[-1]["content"].append({"type": "image"})
                elif modal == 'video':
                    converted_messages[-1]["content"].append({"type": "video", "num_frames": len(images), "time": timestamps})

    messages = converted_messages

    system_message = []

    image_downsampling = kwargs.get('image_downsampling', model.config.spatial_merge_size)

    # TODO: attention mask?
    messages = system_message + messages

    data_dict = vlprocessor(
        images=images,
        text=messages,
        merge_size=image_downsampling,
        return_labels=True,
        return_tensors="pt",
    )

    torch_dtype = model.config.torch_dtype if hasattr(model.config, "torch_dtype") else torch.float16

    keywords = [tokenizer.eos_token]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, data_dict["input_ids"].unsqueeze(0))

    do_sample = kwargs.get('do_sample', False)
    temperature = kwargs.get('temperature', 0.2 if do_sample else 1.0)
    top_p = kwargs.get('top_p', 0.9 if do_sample else 1.0)
    top_k = kwargs.get('top_k', 20 if do_sample else 50)
    max_new_tokens = kwargs.get('max_new_tokens', 2048)

    # Move tensors to GPU and cast vision inputs to bf16.
    data_dict["modals"] = [modal]
    data_dict = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data_dict.items()}
    if "pixel_values" in data_dict:
        data_dict["modals"] = data_dict["modals"] * len(data_dict["grid_sizes"])
        data_dict["pixel_values"] = data_dict["pixel_values"].to(torch.bfloat16)

    with torch.inference_mode():
        output_ids, pred_masks = model.inference(
            input_ids=data_dict["input_ids"].unsqueeze(0).cuda(),
            pixel_values=data_dict["pixel_values"],
            grid_sizes=data_dict["grid_sizes"],
            merge_sizes=data_dict["merge_sizes"],
            modals=data_dict["modals"],
            sam_images=[sam_images],
            sam_size=[sam_size],
            image2maskids=[image2maskids],
            do_sample=do_sample,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
            pad_token_id=tokenizer.eos_token_id,
            seg_start_idx=seg_start_idx,
        )

    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    # Threshold the sigmoid of the predicted mask logits to get binary masks.
    pred_masks_sigmoid = pred_masks.sigmoid() > mask_threshold

    return outputs, pred_masks_sigmoid
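

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only).
# A minimal, hedged example of how the helpers above are typically wired
# together for a single-image query. The image path and question are
# placeholders, and the assumption that the processor exposes its tokenizer
# as `processor.tokenizer` may differ per release; treat this as a sketch,
# not a supported entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    disable_torch_init()

    # `model_path=None` falls back to the default "Alibaba-DAMO-Academy/RynnEC-2B".
    model, processor = model_init(model_path=None)
    model = model.cuda().eval()

    # Hypothetical single-image query; replace "example.jpg" with a real file.
    image = Image.open("example.jpg").convert("RGB")
    question = DEFAULT_IMAGE_TOKEN + "\nDescribe the object in the image."

    answer = mm_infer(
        [image],                  # list of PIL images for modal='image'
        processor,
        question,
        model,
        processor.tokenizer,      # assumption: processor keeps a reference to the tokenizer
        modal='image',
        do_sample=False,
    )
    print(answer)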