from typing import Union, Optional, Tuple import torch import torch.nn.functional as F from torch.optim.lr_scheduler import CosineAnnealingLR from peft import LoraConfig, get_peft_model, TaskType from tqdm import tqdm from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast from cube3d.inference.logits_postprocesses import process_logits from cube3d.inference.utils import load_config, load_model_weights, parse_structured, load_model_weights_adaption from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer from cube3d.model.transformers.cache import Cache from cube3d.model.transformers.rope import precompute_freqs_cis from cube3d.training.utils import positional_encoding from cube3d.config import HF_CACHE_DIR class Engine: def __init__( self, config_path: str, gpt_ckpt_path: str, shape_ckpt_path: str, save_gpt_ckpt_path: str, device: torch.device, mode: str ): """ Initializes the inference engine with the given configuration and checkpoint paths. Args: config_path (str): Path to the configuration file. gpt_ckpt_path (str): Path to the GPT model checkpoint file. shape_ckpt_path (str): Path to the shape model checkpoint file. device (torch.device): The device to run the models on (e.g., 'cpu' or 'cuda'). Attributes: cfg (dict): Loaded configuration from the config file. device (torch.device): The device to run the models on. gpt_model (DualStreamRoformer): The GPT model initialized and loaded with weights. shape_model (OneDAutoEncoder): The shape model initialized and loaded with weights. text_model (CLIPTextModelWithProjection): The text model initialized from a pretrained model. text_tokenizer (CLIPTokenizerFast): The tokenizer for the text model. max_new_tokens (int): Maximum number of new tokens for the shape model. min_id (int): Minimum ID for the shape model codes. max_id (int): Maximum ID for the shape model codes. """ self.cfg = load_config(config_path) self.device = device self.gpt_model = DualStreamRoformer( parse_structured(DualStreamRoformer.Config, self.cfg.gpt_model) ) #------training load if mode=='test': load_model_weights( self.gpt_model, save_gpt_ckpt_path, ) #-------traing load self.gpt_model = self.gpt_model.to(self.device) self.shape_model = OneDAutoEncoder( parse_structured(OneDAutoEncoder.Config, self.cfg.shape_model) ) load_model_weights( self.shape_model, shape_ckpt_path, ) self.shape_model = self.shape_model.eval().to(self.device) # copy vq codebook to gpt with torch.no_grad(): codebook = self.shape_model.bottleneck.block.get_codebook() codebook = self.gpt_model.shape_proj(codebook).detach() self.gpt_model.transformer.wte.weight.data[: codebook.shape[0]] = codebook device_map = self.device.type if isinstance(self.device, torch.device) else self.device self.text_model = CLIPTextModelWithProjection.from_pretrained( self.cfg.text_model_pretrained_model_name_or_path, force_download=False, device_map=device_map, cache_dir=HF_CACHE_DIR, ).eval() print("------text_model device---------", self.text_model.device) self.text_tokenizer = CLIPTokenizerFast.from_pretrained( self.cfg.text_model_pretrained_model_name_or_path, cache_dir=HF_CACHE_DIR, #force_download=False, ) self.max_new_tokens = self.shape_model.cfg.num_encoder_latents self.min_id = 0 self.max_id = self.shape_model.cfg.num_codes self.max_token_length = 110 #bottom #310 #car self.x_prembeds = None self.x_prembeds = None self.x_prembeds = None @torch.inference_mode() def prepare_conditions_with_bbox( self, cond: torch.Tensor, bounding_box_tensor: Optional[torch.Tensor] = None, ): """ Prepares condition embeddings by incorporating bounding box information. Concatenates bounding box embeddings to the existing condition tensor if the model supports bounding box projection. If no bounding box is provided, uses zero padding. Args: cond (torch.Tensor): The input condition embeddings tensor of shape (B, seq_len, dim). bounding_box_xyz (Optional[torch.Tensor], optional): The size of the bounding box as (x, y, z) dimensions represented as a tensor. If None, uses zero padding for bounding box embeddings. Returns: torch.Tensor: The condition tensor with bounding box embeddings concatenated along the sequence dimension if bounding box projection is supported, otherwise returns the original condition tensor unchanged. """ if not hasattr(self.gpt_model, "bbox_proj"): return cond if bounding_box_tensor is None: B = cond.shape[0] bounding_box_tensor = torch.zeros((B, 3), dtype=cond.dtype, device=self.device) bbox_emb = self.gpt_model.bbox_proj(bounding_box_tensor).unsqueeze(dim=1).expand(cond.shape[0], -1, -1) cond = torch.cat([cond, bbox_emb], dim=1) return cond @torch.inference_mode() def prepare_conditions_with_bboxs( self, cond: torch.Tensor, bounding_box_tensor: Optional[torch.Tensor] = None, ): """ Prepares condition embeddings by incorporating bounding box information. Concatenates bounding box embeddings to the existing condition tensor if the model supports bounding box projection. If no bounding box is provided, uses zero padding. Args: cond (torch.Tensor): The input condition embeddings tensor of shape (B, seq_len, dim). bounding_box_xyz (Optional[torch.Tensor], optional): The size of the bounding box as (x, y, z) dimensions represented as a tensor. If None, uses zero padding for bounding box embeddings. Returns: torch.Tensor: The condition tensor with bounding box embeddings concatenated along the sequence dimension if bounding box projection is supported, otherwise returns the original condition tensor unchanged. """ if not hasattr(self.gpt_model, "bbox_proj"): return cond if bounding_box_tensor is None: B = cond.shape[0] bounding_box_tensor = torch.zeros((B, 3), dtype=cond.dtype, device=self.device) bbox_emb = self.gpt_model.bbox_proj(bounding_box_tensor).unsqueeze(dim=1).expand(cond.shape[0], -1, -1) cond = torch.cat([cond, bbox_emb], dim=1) return cond @torch.inference_mode() def prepare_inputs( self, prompts: list[str], guidance_scale: float, bounding_box_xyz: Optional[Tuple[float]] = None, ): """ Prepares the input embeddings for the model based on the provided prompts and guidance scale. Args: prompts (list[str]): A list of prompt strings to be encoded. guidance_scale (float): A scaling factor for guidance. If greater than 0.0, additional processing is applied. bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box for generation as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None, uses default bounding box sizing. Returns: tuple: A tuple containing: - embed (torch.Tensor): The encoded input embeddings. - cond (torch.Tensor): The condition embeddings, which may include unconditional embeddings if guidance_scale is greater than 0.0. """ prompt_embeds = self.run_clip(prompts) # [1, 77, 1536] with torch.autocast(self.device.type, dtype=torch.bfloat16): embed = self.encode_input(prompt_embeds, self.gpt_model.shape_bos_id) # (prompt_embeds, 16384) -> [1, 1, 1536], just embedding shape_bos_id #bos embed if bounding_box_xyz is not None: cond_bbox = torch.atleast_2d(torch.tensor(bounding_box_xyz)).to(self.device) uncond_bbox = torch.zeros_like(cond_bbox).to(self.device) else: cond_bbox = None uncond_bbox = None cond = self.prepare_conditions_with_bbox(prompt_embeds, cond_bbox) if guidance_scale > 0.0: embed = torch.cat([embed, embed], dim=0) #why cat ? for chunk=2 uncond_embeds = self.run_clip([""] * len(prompts)) uncond = self.prepare_conditions_with_bbox(uncond_embeds, uncond_bbox) cond = torch.cat([cond, uncond], dim=0) return embed, cond @torch.inference_mode() def canonical_inputs( self, input_ids: torch.Tensor, mask: torch.Tensor, ): """ Prepares the input embeddings for the model based on the provided prompts and guidance scale. Args: prompts (list[str]): A list of prompt strings to be encoded. guidance_scale (float): A scaling factor for guidance. If greater than 0.0, additional processing is applied. bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box for generation as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None, uses default bounding box sizing. Returns: tuple: A tuple containing: - embed (torch.Tensor): The encoded input embeddings. - cond (torch.Tensor): The condition embeddings, which may include unconditional embeddings if guidance_scale is greater than 0.0. """ # import ipdb; ipdb.set_trace() x_num = 213 y_num = 217 z_num = 529 rot_num = 24 xyz = x_num + y_num + z_num + rot_num #mask_input = input_ids[mask==1] #cut_idx = (mask == False)[:, :, 0].int().argmax(dim=1) input_ids[:, :xyz] = 0 input_ids[:, 0] = 1 return input_ids @torch.inference_mode() def run_clip(self, text_inputs): """ Processes the given text inputs using a text tokenizer and a text model, and returns the encoded text embeddings. Args: text_inputs (str or List[str]): The input text or list of texts to be processed. Returns: torch.Tensor: The encoded text embeddings. """ #import ipdb; ipdb.set_trace() text_inputs = self.text_tokenizer( text_inputs, max_length=self.text_tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt", ) with torch.no_grad(): text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} # use full precision for text encoder with torch.autocast(device_type=self.device.type, enabled=False): encoded = self.text_model(**text_inputs) if self.gpt_model.cfg.use_pooled_text_embed: embed = encoded.text_embeds.unsqueeze(1) # [bs, 1, 512] else: embed = encoded.last_hidden_state # [bs, 77, 512] embed = self.gpt_model.encode_text(embed) return embed @torch.inference_mode() def encode_input(self, inputs: torch.Tensor, bos: int): """ Encodes the beginning of sequence (BOS) token for the given input tensor. Args: inputs (torch.Tensor): The input tensor containing sequences. bos (int): The beginning of sequence token ID. Returns: torch.Tensor: The encoded BOS token embeddings. """ b = inputs.shape[0] bos_embed = self.gpt_model.encode_token( torch.full( (b, 1), fill_value=bos, dtype=torch.long, device=self.device, ) ) return bos_embed @torch.inference_mode() def run_gpt( self, prompts: list[str], use_kv_cache: bool, guidance_scale: float = 3.0, top_p: float = None, bounding_box_xyz: Optional[Tuple[float]] = None, ): """ Generates text using a GPT model based on the provided prompts. Args: prompts (list[str]): A list of input prompts to generate text from. use_kv_cache (bool): Whether to use key-value caching for faster generation. guidance_scale (float, optional): The scale for guidance during generation. Default is 3.0. top_p (float, optional): The cumulative probability threshold for nucleus sampling. If None, argmax selection is performed (deterministic generation). Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept (stochastic generation). bounding_box_xyz (Optional[Tuple[float]], optional): The size of the bounding box for generation as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None, uses default bounding box sizing. Returns: torch.Tensor: A tensor containing the generated token IDs. """ embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz) #embed: bos output_ids = [] batch_size, input_seq_len, dim = embed.shape max_seq_len = input_seq_len + self.max_new_tokens embed_buffer = torch.zeros( (batch_size, max_seq_len, dim), dtype=embed.dtype, device=embed.device ) embed_buffer[:, :input_seq_len, :].copy_(embed) cond_len = cond.shape[1] kv_cache = None if use_kv_cache: # import ipdb; ipdb.set_trace() kv_cache = self.gpt_model.init_kv_cache( batch_size, cond_len, self.max_new_tokens + 1, # +1 for the BOS token torch.bfloat16, embed.device, ) # import ipdb; ipdb.set_trace() with torch.autocast(self.device.type, dtype=torch.bfloat16): for i in tqdm(range(self.max_new_tokens), desc=f"generating"): curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device) logits = self.gpt_model( embed_buffer, cond, kv_cache=kv_cache, curr_pos_id=curr_pos_id if use_kv_cache else None, decode=(i > 0) if use_kv_cache else False, ) if use_kv_cache: logits = logits[:, 0, ...] else: logits = logits[:, i, ...] # import ipdb; ipdb.set_trace() logits = logits[..., self.min_id : self.max_id] if guidance_scale > 0.0: logits, uncond_logits = logits.float().chunk(2, dim=0) gamma = ( guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens ) logits = (1 + gamma) * logits - gamma * uncond_logits next_id = process_logits( logits, top_p=top_p, ) output_ids.append(next_id) next_embed = self.gpt_model.encode_token(next_id) if guidance_scale > 0.0: next_embed = torch.cat([next_embed, next_embed], dim=0) embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1)) # import ipdb; ipdb.set_trace() print(logits) return torch.cat(output_ids, dim=1) @torch.inference_mode() def run_shape_decode( self, output_ids: torch.Tensor, resolution_base: float = 8.0, chunk_size: int = 100_000, ): """ Decodes the shape from the given output IDs and extracts the geometry. Args: output_ids (torch.Tensor): The tensor containing the output IDs. resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.43. chunk_size (int, optional): The chunk size for processing. Defaults to 100,000. Returns: tuple: A tuple containing the vertices and faces of the mesh. """ shape_ids = ( output_ids[:, : self.shape_model.cfg.num_encoder_latents, ...] .clamp_(0, self.shape_model.cfg.num_codes - 1) .view(-1, self.shape_model.cfg.num_encoder_latents) ) latents = self.shape_model.decode_indices(shape_ids) #where loss? mesh_v_f, _ = self.shape_model.extract_geometry( latents, resolution_base=resolution_base, chunk_size=chunk_size, use_warp=True, ) return mesh_v_f @torch.inference_mode() def t2s( self, prompts: list[str], use_kv_cache: bool, guidance_scale: float = 3.0, resolution_base: float = 8.0, chunk_size: int = 100_000, top_p: float = None, bounding_box_xyz: Optional[Tuple[float]] = None, ): """ Generates a 3D mesh from text prompts using a GPT model and shape decoder. Args: prompts (list[str]): A list of text prompts to guide the generation. use_kv_cache (bool): Whether to use key-value caching for the GPT model. guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0. resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0. chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000. top_p (float, optional): The cumulative probability threshold for nucleus sampling. If None, argmax selection is performed (deterministic generation). Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept (stochastic generation). bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the generated mesh as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None, uses default bounding box sizing. Returns: mesh_v_f: The generated 3D mesh vertices and faces. """ output_ids = self.run_gpt( prompts, use_kv_cache, guidance_scale, top_p, bounding_box_xyz ) with torch.autocast(self.device.type, dtype=torch.bfloat16): mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size) return mesh_v_f class EngineFast(Engine): def __init__( self, config_path: str, gpt_ckpt_path: str, shape_ckpt_path: str, save_gpt_ckpt_path: str, device: torch.device, mode: str ): """ Initializes the inference engine with the given configuration and checkpoint paths. Args: config_path (str): Path to the configuration file. gpt_ckpt_path (str): Path to the GPT checkpoint file. shape_ckpt_path (str): Path to the shape checkpoint file. device (torch.device): The device to run the inference on (e.g., CPU or CUDA). """ assert ( device.type == "cuda" ), "EngineFast is only supported on cuda devices, please use Engine on non-cuda devices" super().__init__(config_path, gpt_ckpt_path, shape_ckpt_path, save_gpt_ckpt_path, device, mode) # CUDA Graph params self.graph = torch.cuda.CUDAGraph() self.embed_buffer = torch.Tensor() self.cond_buffer = torch.Tensor() self.logits_buffer = torch.Tensor() self.curr_pos_id = torch.tensor([0], dtype=torch.long, device=self.device) self.kv_cache: list[Cache] = [] #self._warmup_and_capture_graph() def _warmup_and_capture_graph(self): """ Warms up the model by running a series of forward passes and captures the CUDA graph for efficient execution. This method performs the following steps: 1. Prepares the input embeddings and conditions using a warmup prompt. 2. Initializes buffers for embeddings and conditions. 3. Initializes the key-value cache for the GPT model. 4. Runs a series of warmup passes to prefill the model and generate logits. 5. Captures the CUDA graph for the model's forward pass to optimize future executions. """ warmup_prompt = "A cube" embed, cond = self.prepare_inputs([warmup_prompt], guidance_scale=3.0) batch_size, input_seq_len, dim = embed.shape max_seq_len = input_seq_len + self.max_new_tokens self.embed_buffer = torch.zeros( (batch_size, max_seq_len, dim), dtype=embed.dtype, device=self.device ) self.embed_buffer[:, :input_seq_len, :].copy_(embed) self.cond_buffer = torch.empty_like(cond) self.cond_buffer.copy_(cond) cond_len = self.cond_buffer.shape[1] # Initialize kv_cache for the first time self.kv_cache = self.gpt_model.init_kv_cache( batch_size, cond_len, self.max_new_tokens + 1, # +1 for the BOS token torch.bfloat16, self.device, ) num_warmup_passes = 10 with torch.autocast(self.device.type, dtype=torch.bfloat16): self._set_curr_pos_id(0) _ = self._prefill_and_return_logits() for x in range(1, num_warmup_passes): self._set_curr_pos_id(x) self.logits_buffer = self.gpt_model( embed=self.embed_buffer, cond=self.cond_buffer, kv_cache=self.kv_cache, curr_pos_id=self.curr_pos_id, #decode=True, decode=False ) side_stream = torch.cuda.Stream(device=self.device) with torch.cuda.graph(self.graph, stream=side_stream): with torch.autocast(self.device.type, dtype=torch.bfloat16): self.logits_buffer = self.gpt_model( embed=self.embed_buffer, cond=self.cond_buffer, kv_cache=self.kv_cache, curr_pos_id=self.curr_pos_id, decode=True, ) def _reset_kv_cache(self): """ Resets the key-value cache by setting all key and value states to zero. This method iterates through each cache in the `kv_cache` attribute and calls the `zero_()` method on both `key_states` and `value_states` to reset them to their initial state. """ for cache in self.kv_cache: cache.key_states.zero_() cache.value_states.zero_() def _prefill_and_return_logits(self) -> torch.Tensor: """ Prefills the model's key-value cache and returns the logits. This method resets the key-value cache and then performs a forward pass through the GPT model in eager mode to prefill the logits. Returns: torch.Tensor: The prefilled logits tensor with the first dimension removed. """ self._reset_kv_cache() # Prefill is always eager prefill_logits = self.gpt_model( embed=self.embed_buffer, cond=self.cond_buffer, kv_cache=self.kv_cache, curr_pos_id=self.curr_pos_id, decode=False, ) return prefill_logits[:, 0, ...] def _set_curr_pos_id(self, pos: int): """ Set the current position ID. This method updates the `curr_pos_id` attribute with the given position. Args: pos (int): The position ID to set. """ self.curr_pos_id.copy_( torch.tensor([pos], dtype=torch.long, device=self.device) ) def run_gpt( self, prompts: list[str], use_kv_cache: bool, guidance_scale: float = 3.0, top_p: float = None, bounding_box_xyz: Optional[Tuple[float]] = None, ): """ Runs the GPT model to generate text based on the provided prompts. Args: prompts (list[str]): A list of input prompts for the GPT model. Only a single prompt is supported. use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used) guidance_scale (float, optional): The scale factor for guidance. Default is 3.0. top_p (float, optional): The cumulative probability threshold for nucleus sampling. If None, argmax selection is performed. Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept. bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the generated mesh as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None, uses default bounding box sizing. Returns: torch.Tensor: A tensor containing the generated output token IDs. Raises: AssertionError: If the batch size is greater than 1. """ embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz) assert len(prompts) == 1, "batch size > 1 not support for EngineFast" batch_size, input_seq_len, _ = embed.shape self.embed_buffer.zero_() self.embed_buffer[:, :input_seq_len, :].copy_(embed) assert self.cond_buffer.shape == cond.shape self.cond_buffer.copy_(cond) output_ids = torch.zeros( (batch_size // 2, self.max_new_tokens), dtype=torch.int, device=self.device ) with torch.autocast(self.device.type, dtype=torch.bfloat16): self._set_curr_pos_id(0) logits = self._prefill_and_return_logits() # import ipdb; ipdb.set_trace() logits = logits[..., self.min_id : self.max_id] #[2, 16387] if guidance_scale > 0.0: logits, uncond_logits = logits.float().chunk(2, dim=0) gamma = guidance_scale logits = (1 + gamma) * logits - gamma * uncond_logits next_id = process_logits(logits, top_p=top_p) output_ids[:, 0] = next_id.squeeze() next_embed = self.gpt_model.encode_token(next_id) next_embed = next_embed.repeat(2, 1, 1) self.embed_buffer[:, input_seq_len, :].copy_(next_embed.squeeze(1)) for i in tqdm(range(1, self.max_new_tokens), desc=f"generating"): self._set_curr_pos_id(i) self.graph.replay() logits = self.logits_buffer[:, 0, ...] logits = logits[..., self.min_id : self.max_id] if guidance_scale > 0.0: logits, uncond_logits = logits.float().chunk(2, dim=0) gamma = ( guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens ) logits = (1 + gamma) * logits - gamma * uncond_logits next_id = process_logits(logits, top_p=top_p) output_ids[:, i] = next_id.squeeze() next_embed = self.gpt_model.encode_token(next_id) next_embed = next_embed.repeat(2, 1, 1) self.embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1)) print(logits) return output_ids def pad_id_and_attn(self, inputs_ids, attention_mask): # same # reserve one space for `bos`, the pad_id will be replaced to `bos` place_holder = torch.ones_like(inputs_ids[:, [0]]) # batch x 1 # prepare input_ids and attention_mask for transformers #input_ids[attention_mask.bool()] += 3 # 0 - num_tokens to 3 - num_tokens + 3, total: 0 - num_tokens + 3, num: numtokens + 4 #input_ids[~attention_mask.bool()] = self.padding_token_id # 2 # in transformers pad token id is only used for init nn.embedding which we won't use # input_ids = torch.cat( # (place_holder * self.shape_bos_id, input_ids, place_holder * self.pad_id), # dim=1 # ) inputs_ids = torch.cat( #(place_holder * self.gpt_model.shape_bos_id, input_ids, place_holder * self.gpt_model.shape_eos_id), (place_holder * self.gpt_model.shape_bos_id, inputs_ids), dim=1 ) #input_ids[torch.arange(0, input_ids.shape[0]), attention_mask.sum(dim=1).long()+1] = self.eos_token_id # #bos: begin of sequence, eos: end of sequence, pad: padding token #import ipdb; ipdb.set_trace() #input_ids[attention_mask.sum(dim=1).long()+1] = self.gpt_model.shape_eos_id # attention_mask = torch.cat( (place_holder, place_holder, attention_mask, ), dim=1 ) # length return inputs_ids, attention_mask def precompute_freqs_cis_position(self, b, x_l, y_l, z_l, device): """ Set the current position ID. This method updates the `curr_pos_id` attribute with the given position. Args: pos (int): The position ID to set. """ x_ids = torch.arange(x_l, dtype=torch.long, device=device) # shape (t) x_ids = x_ids.unsqueeze_(0).expand(b, -1) x_freqs_cis = precompute_freqs_cis( dim=self.gpt_model.cfg.n_embd // self.gpt_model.cfg.n_head * 4, # 128 t=x_ids, theta=self.gpt_model.cfg.rope_theta, #10000.0 ) y_ids = torch.arange(y_l, dtype=torch.long, device=device) # shape (t) y_ids = y_ids.unsqueeze_(0).expand(b, -1) y_freqs_cis = precompute_freqs_cis( dim=self.gpt_model.cfg.n_embd // self.gpt_model.cfg.n_head * 4, # 128*4 t=y_ids, theta=self.gpt_model.cfg.rope_theta, #10000.0 ) z_ids = torch.arange(z_l, dtype=torch.long, device=device) # shape (t) z_ids = z_ids.unsqueeze_(0).expand(b, -1) z_freqs_cis = precompute_freqs_cis( dim=self.gpt_model.cfg.n_embd // self.gpt_model.cfg.n_head * 4, # 128 t=z_ids, theta=self.gpt_model.cfg.rope_theta, #10000.0 ) return x_freqs_cis, y_freqs_cis, z_freqs_cis def fwd_gpt( self, prompts: list[str], inputs_ids: list[torch.Tensor], latent: list[torch.Tensor], use_kv_cache: bool, guidance_scale: float = 3.0, top_p: float = None, bounding_box_xyz: Optional[Tuple[float]] = None, strategy: int = None, mode: str = 'train' ): """ Runs the GPT model to generate text based on the provided prompts. Args: prompts (list[str]): A list of input prompts for the GPT model. Only a single prompt is supported. use_kv_cache (bool): Flag indicating whether to use key-value caching. (Currently not used) guidance_scale (float, optional): The scale factor for guidance. Default is 3.0. top_p (float, optional): The cumulative probability threshold for nucleus sampling. If None, argmax selection is performed. Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept. bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the generated mesh as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None, uses default bounding box sizing. Returns: torch.Tensor: A tensor containing the generated output token IDs. Raises: AssertionError: If the batch size is greater than 1. """ #_, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz) #assert len(prompts) == 1, "batch size > 1 not support for EngineFast" #why? #batch_size, input_seq_len, _ = embed.shape with torch.no_grad(): attention_mask = inputs_ids != -1 cut_idx = (attention_mask == False)[:, :, -3].int().argmax(dim=1) #dat_id = inputs_ids[:,:,self.gpt_model.xyz:self.gpt_model.xyz+self.gpt_model.dat_num].argmax(-1) dat_id = inputs_ids[:,:,-6].long() dat_id = torch.where(torch.arange(dat_id.shape[1], device=dat_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.dat_num, dat_id) inputs_embeds = self.gpt_model.dte(dat_id) # x_id = inputs_ids[:,:,24:self.gpt_model.x+24].argmax(-1) # y_id = inputs_ids[:,:,self.gpt_model.x:self.gpt_model.xy].argmax(-1) # z_id = inputs_ids[:,:,self.gpt_model.xy:self.gpt_model.xyz].argmax(-1) # coord_ids = torch.cat([x_id.unsqueeze(-1), y_id.unsqueeze(-1), z_id.unsqueeze(-1)], dim=-1) # max_vals = torch.tensor([self.gpt_model.x_num - 1, self.gpt_model.y_num - 1, self.gpt_model.z_num - 1], # dtype=torch.float32, # device=coord_ids.device) # normliz_coord = coord_ids.float() / max_vals.view(1, 1, 3) * 2 - 1 # # pos_embeds = positional_encoding(normliz_coord, 128) #embeds_from_id = self.gpt_model.encode_embed(inputs_ids[:, :, self.gpt_model.xyz:self.gpt_model.xyz + self.gpt_model.dat_num].float()) #embeds_from_id = self.gpt_model.encode_embed(inputs_ids[:, :, 24:self.gpt_model.xyz + self.gpt_model.dat_num].float()) #embeds_from_id = self.gpt_model.encode_embed(inputs_ids[:, :, 24:self.gpt_model.xyz].float()) #flatten rot id r_id = inputs_ids[:,:,0] r_id = torch.where(torch.arange(r_id.shape[1], device=r_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.rot_num, r_id) #flatten postion id x_id = inputs_ids[:,:,-5] y_id = inputs_ids[:,:,-4] z_id = inputs_ids[:,:,-3] x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.x_num, x_id) y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.y_num, y_id) z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.z_num, z_id) inputs_ids[:, :, 0] = r_id.clone() inputs_ids[:, :, -6] = dat_id.clone() inputs_ids[:, :, -5] = x_id.clone() inputs_ids[:, :, -4] = y_id.clone() inputs_ids[:, :, -3] = z_id.clone() #mask token strategy = strategy if mode=='test' else torch.randint(0, 4, (1,)).item() if strategy == 0: x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None], self.gpt_model.x_num+1, x_id) y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] < cut_idx[:,None], self.gpt_model.y_num+1, y_id) z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] < cut_idx[:,None], self.gpt_model.z_num+1, z_id) mask = None elif strategy == 1: x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None], self.gpt_model.x_num+1, x_id) y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] < cut_idx[:,None], self.gpt_model.y_num+1, y_id) z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] < cut_idx[:,None], self.gpt_model.z_num+1, z_id) r_id = torch.where(torch.arange(r_id.shape[1], device=r_id.device)[None,:] < cut_idx[:,None], self.gpt_model.rot_num+1, r_id) mask = None elif strategy == 2: x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None], self.gpt_model.x_num+1, x_id) y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] < cut_idx[:,None], self.gpt_model.y_num+1, y_id) z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] < cut_idx[:,None], self.gpt_model.z_num+1, z_id) mask = (torch.arange(r_id.shape[1], device=r_id.device)[None,:] < cut_idx[:,None]) & (torch.rand(r_id.shape, device=r_id.device) > torch.empty(1, device=r_id.device).uniform_(0.0, 1.0).item()) r_id = torch.where(mask, self.gpt_model.rot_num+1, r_id) else: mask = (torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None]) & (torch.rand(x_id.shape, device=x_id.device) > torch.empty(1, device=r_id.device).uniform_(0.0, 1.).item()) x_id = torch.where(mask, self.gpt_model.x_num+1, x_id) y_id = torch.where(mask, self.gpt_model.y_num+1, y_id) z_id = torch.where(mask, self.gpt_model.z_num+1, z_id) #print(strategy) rembeds_from_id = self.gpt_model.rte(r_id) xembeds_from_id = self.gpt_model.xte(x_id) yembeds_from_id = self.gpt_model.yte(y_id) zembeds_from_id = self.gpt_model.zte(z_id) embeds_from_id = torch.stack([inputs_embeds.clone(), rembeds_from_id, yembeds_from_id, xembeds_from_id, zembeds_from_id], dim=2) # [b, 310, 3, 1536] #embeds_from_id = torch.stack([yembeds_from_id, xembeds_from_id, zembeds_from_id], dim=2) embeds_from_id = embeds_from_id.view(xembeds_from_id.shape[0], xembeds_from_id.shape[1] * 5, xembeds_from_id.shape[2]) # [b, 930, 1536] #inputs_embeds = self.gpt_model.encode_token(latent) #position embedding #inputs_embeds = torch.cat([pos_embeds, inputs_embeds], dim=-1) inputs_embeds = self.prepare_conditions_with_bboxs(inputs_embeds, bounding_box_xyz) #add token number padding #sequence_length = inputs_ids.shape[1] #pad_sequence = torch.ones((inputs_ids.shape[0], sequence_length), dtype=torch.long, device=inputs_ids.device) * self.gpt_model.dat_num #self.gpt_model.padding_id #pad_sequence_embed = self.gpt_model.encode_token(pad_sequence) #[b, 1536] #!!!--------litte wrong #embeds_from_id[~attention_mask[:,:,:inputs_embeds.shape[2]]] = pad_sequence_embed[~attention_mask[:,:,:inputs_embeds.shape[2]]] #add bos place_holder = torch.ones_like(inputs_ids[:, 0, 0]).long() # batch x 1 bos_embed = self.gpt_model.encode_token(place_holder * self.gpt_model.shape_bos_id) #[1, 1536] embeds_from_id = torch.cat([bos_embed[:, None, :], embeds_from_id], dim=1) inputs_embeds = bos_embed.unsqueeze(1) #exchange # ex = inputs_embeds.clone() # inputs_embeds = self.prepare_conditions_with_bboxs(embeds_from_id, bounding_box_xyz) # embeds_from_id = torch.cat([bos_embed[:, None, :], ex], dim=1) # Prefill is always eager prefill_logits = self.gpt_model( embed=embeds_from_id, #_repeat, cond=inputs_embeds, #_repeat, kv_cache=None, curr_pos_id=None, decode=False, ) logits = prefill_logits[..., self.min_id : self.max_id] # if guidance_scale > 0.0: # logits, uncond_logits = logits.float().chunk(2, dim=0) # gamma = guidance_scale # # seq_len = logits.size(1) # # gamma_list = guidance_scale * (seq_len - torch.arange(seq_len)) / seq_len # # # shape: [seq_len] # logits = (1 + gamma) * logits - gamma * uncond_logits return logits, inputs_ids, strategy, mask, cut_idx def t2t( self, prompts: list[str], inputs_ids: list[torch.Tensor], latent: list[torch.Tensor], use_kv_cache: bool, guidance_scale: float = 3.0, resolution_base: float = 8.0, chunk_size: int = 100_000, top_p: float = None, bounding_box_xyz: Optional[Tuple[float]] = None, strategy: int = None, mode: str = None ): """ Generates a 3D mesh from text prompts using a GPT model. Args: prompts (list[str]): A list of text prompts to guide the generation. use_kv_cache (bool): Whether to use key-value caching for the GPT model. guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0. resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0. chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000. top_p (float, optional): The cumulative probability threshold for nucleus sampling. If None, argmax selection is performed (deterministic generation). Otherwise, smallest set of tokens with cumulative probability ≥ top_p are kept (stochastic generation). bounding_box_xyz (Tuple[float] | None, optional): The size of the bounding box for the generated mesh as (x, y, z) dimensions. Each value must be between 0 and 1.925. If None, uses default bounding box sizing. Returns: output_ids: The generated 3D mesh tokens. """ logits = self.fwd_gpt( prompts, inputs_ids, latent, use_kv_cache, guidance_scale, top_p, bounding_box_xyz, strategy, mode ) return logits def configure_optimizers( self, train_config ): """ This long function is unfortunately doing something very simple and is being very defensive: We are separating out all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights). We are then returning the PyTorch optimizer object. """ # separate out all parameters to those that will and won't experience regularizing weight decay decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, ) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) # import ipdb; ipdb.set_trace() for mn, m in self.gpt_model.named_modules(): #print(mn, m) if mn!='lm_head': continue for pn, p in m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # full param name # random note: because named_modules and named_parameters are recursive # we will see the same tensors p many many times. but doing it this way # allows us to know which parent module any tensor p belongs to... if pn.endswith('bias'): # all biases will not be decayed no_decay.add(fpn) elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) elif '_norm.weight' in pn: # no_decay.add(fpn) #import ipdb; ipdb.set_trace() # validate that we considered every parameter param_dict = {pn: p for pn, p in self.gpt_model.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay # assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) # assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ # % (str(param_dict.keys() - union_params), ) # create the pytorch optimizer object optim_groups = [ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, ] optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) return optimizer def configure_optimizers_lora( self, train_config ): """ This long function is unfortunately doing something very simple and is being very defensive: We are separating out all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights). We are then returning the PyTorch optimizer object. """ optim_groups = (p for p in self.gpt_model.parameters() if p.requires_grad) optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) return optimizer def configure_optimizers_lora_linear( self, train_config ): """ This long function is unfortunately doing something very simple and is being very defensive: We are separating out all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights). We are then returning the PyTorch optimizer object. """ # separate out all parameters to those that will and won't experience regularizing weight decay decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, ) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.gpt_model.named_modules(): #print(mn, m) if mn!='ldr_head' or mn!='ldr_proj' or mn!='dte' or mn!='xte' or mn!='yte' or mn!='zte' or mn!='rte': continue for pn, p in m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # full param name # random note: because named_modules and named_parameters are recursive # we will see the same tensors p many many times. but doing it this way # allows us to know which parent module any tensor p belongs to... if pn.endswith('bias'): # all biases will not be decayed no_decay.add(fpn) elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) elif '_norm.weight' in pn: # no_decay.add(fpn) # validate that we considered every parameter param_dict = {pn: p for pn, p in self.gpt_model.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay lora_optim_groups = [p for p in self.gpt_model.parameters() if p.requires_grad] optim_groups = [ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, {"params": lora_optim_groups}, ] optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) scheduler = CosineAnnealingLR( optimizer, T_max=train_config.max_iters, eta_min=train_config.learning_rate * 0.01 ) return optimizer, scheduler def configure_optimizers_scratch_linear( self, train_config ): """ This long function is unfortunately doing something very simple and is being very defensive: We are separating out all parameters of the model into two buckets: those that will experience weight decay for regularization and those that won't (biases, and layernorm/embedding weights). We are then returning the PyTorch optimizer object. """ # separate out all parameters to those that will and won't experience regularizing weight decay decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, ) blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) for mn, m in self.gpt_model.named_modules(): #print(mn, m) # if mn!='ldr_head' or mn!='ldr_proj' or mn!='dte' or mn!='xte' or mn!='yte' or mn!='zte' or mn!='rte': # continue for pn, p in m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # full param name # random note: because named_modules and named_parameters are recursive # we will see the same tensors p many many times. but doing it this way # allows us to know which parent module any tensor p belongs to... if pn.endswith('bias'): # all biases will not be decayed no_decay.add(fpn) elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): # weights of blacklist modules will NOT be weight decayed no_decay.add(fpn) elif '_norm.weight' in pn: # no_decay.add(fpn) # validate that we considered every parameter param_dict = {pn: p for pn, p in self.gpt_model.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ % (str(param_dict.keys() - union_params), ) optim_groups = [ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, ] optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) scheduler = CosineAnnealingLR( optimizer, T_max=train_config.max_iters, eta_min=train_config.learning_rate * 0.01 ) return optimizer, scheduler