# Author: 0xZohar
# Last commit message: "Fix: Safe device selection for CPU/GPU compatibility"
# Commit 4885d4a (verified)
from typing import Union, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from peft import LoraConfig, get_peft_model, TaskType
from tqdm import tqdm
from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast
from cube3d.inference.logits_postprocesses import process_logits
from cube3d.inference.utils import load_config, load_model_weights, parse_structured, load_model_weights_adaption
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
from cube3d.model.gpt.dual_stream_roformer import DualStreamRoformer
from cube3d.model.transformers.cache import Cache
from cube3d.model.transformers.rope import precompute_freqs_cis
from cube3d.training.utils import positional_encoding
from cube3d.config import HF_CACHE_DIR
class Engine:
    """
    Text-to-shape engine built from three models.

    - A CLIP text encoder embeds the prompts.
    - A DualStreamRoformer GPT predicts shape tokens.
    - A OneDAutoEncoder decodes shape tokens into a mesh.
    """

    def __init__(
        self,
        config_path: str,
        gpt_ckpt_path: str,
        shape_ckpt_path: str,
        save_gpt_ckpt_path: str,
        device: torch.device,
        mode: str
    ):
        """
        Initializes the inference engine with the given configuration and checkpoint paths.

        Args:
            config_path (str): Path to the configuration file.
            gpt_ckpt_path (str): Path to the GPT model checkpoint file. Not read
                by this constructor (see ``save_gpt_ckpt_path``).
            shape_ckpt_path (str): Path to the shape model checkpoint file.
            save_gpt_ckpt_path (str): GPT checkpoint that is loaded when
                ``mode == 'test'``.
            device (torch.device | str): The device to run the models on (e.g.
                'cpu', 'cuda' or 'cuda:1'). Strings are converted with
                ``torch.device``.
            mode (str): ``'test'`` loads the GPT weights from
                ``save_gpt_ckpt_path``. Any other value keeps the weights the
                GPT was constructed with; callers may load weights separately.

        Attributes:
            cfg: Loaded configuration from the config file.
            device (torch.device): The device the models run on.
            gpt_model (DualStreamRoformer): The GPT model. The first rows of
                its token-embedding table are overwritten with the projected
                VQ codebook.
            shape_model (OneDAutoEncoder): The shape model, loaded from
                ``shape_ckpt_path`` and put in eval mode.
            text_model (CLIPTextModelWithProjection): The pretrained text
                encoder, in eval mode.
            text_tokenizer (CLIPTokenizerFast): The tokenizer for the text model.
            max_new_tokens (int): Number of shape tokens generated per sample
                (``shape_model.cfg.num_encoder_latents``).
            min_id (int): Inclusive start of the logit slice used for sampling.
            max_id (int): Exclusive end of that slice
                (``shape_model.cfg.num_codes``).
        """
        self.cfg = load_config(config_path)
        # Accept torch.device or a plain string. Every later use relies on
        # `self.device.type`, which a str does not have.
        self.device = torch.device(device)

        self.gpt_model = DualStreamRoformer(
            parse_structured(DualStreamRoformer.Config, self.cfg.gpt_model)
        )
        # Only test mode restores trained GPT weights here.
        # NOTE(review): `gpt_ckpt_path` is never loaded — confirm that is intended.
        if mode == 'test':
            load_model_weights(
                self.gpt_model,
                save_gpt_ckpt_path,
            )
        self.gpt_model = self.gpt_model.to(self.device)

        self.shape_model = OneDAutoEncoder(
            parse_structured(OneDAutoEncoder.Config, self.cfg.shape_model)
        )
        load_model_weights(
            self.shape_model,
            shape_ckpt_path,
        )
        self.shape_model = self.shape_model.eval().to(self.device)

        # Copy the VQ codebook, projected by the GPT's `shape_proj`, into the
        # first rows of the GPT token-embedding table.
        with torch.no_grad():
            codebook = self.shape_model.bottleneck.block.get_codebook()
            codebook = self.gpt_model.shape_proj(codebook).detach()
            self.gpt_model.transformer.wte.weight.data[: codebook.shape[0]] = codebook

        # Pass the full device string (e.g. "cuda:1"), not just `device.type`.
        # Otherwise the text encoder would be placed on the default CUDA device
        # instead of the device index the other models use.
        self.text_model = CLIPTextModelWithProjection.from_pretrained(
            self.cfg.text_model_pretrained_model_name_or_path,
            force_download=False,
            device_map=str(self.device),
            cache_dir=HF_CACHE_DIR,
        ).eval()
        print("------text_model device---------", self.text_model.device)
        self.text_tokenizer = CLIPTokenizerFast.from_pretrained(
            self.cfg.text_model_pretrained_model_name_or_path,
            cache_dir=HF_CACHE_DIR,
        )

        self.max_new_tokens = self.shape_model.cfg.num_encoder_latents
        self.min_id = 0
        self.max_id = self.shape_model.cfg.num_codes
        # 110 for "bottom"; the original note also mentions 310 for "car".
        # Presumably these are per-dataset values — confirm with callers.
        self.max_token_length = 110
        self.x_prembeds = None

    @torch.inference_mode()
    def prepare_conditions_with_bbox(
        self,
        cond: torch.Tensor,
        bounding_box_tensor: Optional[torch.Tensor] = None,
    ):
        """
        Append a bounding-box embedding to the condition sequence.

        If the GPT has no ``bbox_proj`` module, ``cond`` is returned unchanged.
        Otherwise ``bbox_proj(bounding_box_tensor)`` is expanded to the batch
        size of ``cond`` and concatenated as one extra position on dim 1.

        Args:
            cond (torch.Tensor): Condition embeddings of shape (B, seq_len, dim).
            bounding_box_tensor (Optional[torch.Tensor], optional): Bounding-box
                sizes of shape (1, 3) or (B, 3). If None, a (B, 3) zero tensor
                is used.

        Returns:
            torch.Tensor: ``cond`` with one bbox position appended (B, seq_len + 1,
            dim), or ``cond`` itself when bbox projection is not supported.
        """
        if not hasattr(self.gpt_model, "bbox_proj"):
            return cond
        if bounding_box_tensor is None:
            B = cond.shape[0]
            bounding_box_tensor = torch.zeros((B, 3), dtype=cond.dtype, device=self.device)
        bbox_emb = self.gpt_model.bbox_proj(bounding_box_tensor).unsqueeze(dim=1).expand(cond.shape[0], -1, -1)
        cond = torch.cat([cond, bbox_emb], dim=1)
        return cond

    @torch.inference_mode()
    def prepare_conditions_with_bboxs(
        self,
        cond: torch.Tensor,
        bounding_box_tensor: Optional[torch.Tensor] = None,
    ):
        """
        Alias of :meth:`prepare_conditions_with_bbox`, kept for existing callers.

        Args:
            cond (torch.Tensor): Condition embeddings of shape (B, seq_len, dim).
            bounding_box_tensor (Optional[torch.Tensor], optional): Bounding-box
                sizes; None means zeros.

        Returns:
            torch.Tensor: Same result as ``prepare_conditions_with_bbox``.
        """
        return self.prepare_conditions_with_bbox(cond, bounding_box_tensor)

    @torch.inference_mode()
    def prepare_inputs(
        self,
        prompts: list[str],
        guidance_scale: float,
        bounding_box_xyz: Optional[Tuple[float]] = None,
    ):
        """
        Prepares the BOS embedding and the condition embeddings for the prompts.

        Args:
            prompts (list[str]): A list of prompt strings to be encoded.
            guidance_scale (float): If greater than 0.0, unconditional (empty
                prompt) conditions are appended on the batch dim and the BOS
                embedding is duplicated to match.
            bounding_box_xyz (Optional[Tuple[float]], optional): The size of the
                bounding box for generation as (x, y, z) dimensions. Each value
                must be between 0 and 1.925. If None, uses default (zero)
                bounding-box conditioning.

        Returns:
            tuple: A tuple containing:
                - embed (torch.Tensor): The BOS embeddings.
                - cond (torch.Tensor): The condition embeddings, followed by the
                  unconditional ones when guidance_scale > 0.0.
        """
        prompt_embeds = self.run_clip(prompts)
        # BOS embedding for the shape sequence, computed under bf16 autocast.
        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            embed = self.encode_input(prompt_embeds, self.gpt_model.shape_bos_id)

        if bounding_box_xyz is not None:
            # Match the condition dtype, so integer tuples and non-fp32 models
            # work with `bbox_proj`.
            cond_bbox = torch.atleast_2d(
                torch.as_tensor(
                    bounding_box_xyz, dtype=prompt_embeds.dtype, device=self.device
                )
            )
            uncond_bbox = torch.zeros_like(cond_bbox)
        else:
            cond_bbox = None
            uncond_bbox = None

        cond = self.prepare_conditions_with_bbox(prompt_embeds, cond_bbox)
        if guidance_scale > 0.0:
            # Stack [conditional; unconditional] on the batch dim. run_gpt
            # splits the logits back apart with chunk(2).
            embed = torch.cat([embed, embed], dim=0)
            uncond_embeds = self.run_clip([""] * len(prompts))
            uncond = self.prepare_conditions_with_bbox(uncond_embeds, uncond_bbox)
            cond = torch.cat([cond, uncond], dim=0)
        return embed, cond

    @torch.inference_mode()
    def canonical_inputs(
        self,
        input_ids: torch.Tensor,
        mask: torch.Tensor,
    ):
        """
        Blank out the leading columns of ``input_ids`` in place.

        Sets ``input_ids[:, :983]`` to 0, where 983 = 213 + 217 + 529 + 24
        (the x/y/z/rotation counts hard-coded below). Then sets
        ``input_ids[:, 0]`` to 1.

        Args:
            input_ids (torch.Tensor): Tensor indexed as ``[batch, column, ...]``.
                It is modified in place.
            mask (torch.Tensor): Not used.

        Returns:
            torch.Tensor: ``input_ids`` itself (the same, mutated tensor).
        """
        x_num = 213
        y_num = 217
        z_num = 529
        rot_num = 24
        xyz = x_num + y_num + z_num + rot_num
        input_ids[:, :xyz] = 0
        input_ids[:, 0] = 1
        return input_ids

    @torch.inference_mode()
    def run_clip(self, text_inputs):
        """
        Tokenize and encode text, then project it with ``gpt_model.encode_text``.

        Args:
            text_inputs (str or List[str]): The input text or list of texts to
                be processed. Texts are padded and truncated to the tokenizer's
                ``model_max_length``.

        Returns:
            torch.Tensor: The encoded text embeddings. The pooled embedding is
            used when ``gpt_model.cfg.use_pooled_text_embed`` is set; otherwise
            the last hidden state is used.
        """
        text_inputs = self.text_tokenizer(
            text_inputs,
            max_length=self.text_tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()}
            # use full precision for text encoder
            with torch.autocast(device_type=self.device.type, enabled=False):
                encoded = self.text_model(**text_inputs)
            if self.gpt_model.cfg.use_pooled_text_embed:
                embed = encoded.text_embeds.unsqueeze(1)  # [bs, 1, 512]
            else:
                embed = encoded.last_hidden_state  # [bs, 77, 512]
            embed = self.gpt_model.encode_text(embed)
        return embed

    @torch.inference_mode()
    def encode_input(self, inputs: torch.Tensor, bos: int):
        """
        Embed one BOS token per row of ``inputs``.

        Args:
            inputs (torch.Tensor): Only its batch size (dim 0) is used.
            bos (int): The beginning-of-sequence token ID.

        Returns:
            torch.Tensor: ``gpt_model.encode_token`` applied to a (B, 1) tensor
            filled with ``bos``.
        """
        b = inputs.shape[0]
        bos_embed = self.gpt_model.encode_token(
            torch.full(
                (b, 1),
                fill_value=bos,
                dtype=torch.long,
                device=self.device,
            )
        )
        return bos_embed

    @torch.inference_mode()
    def run_gpt(
        self,
        prompts: list[str],
        use_kv_cache: bool,
        guidance_scale: float = 3.0,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float]] = None,
    ):
        """
        Autoregressively sample shape token ids for the given prompts.

        Args:
            prompts (list[str]): A list of input prompts.
            use_kv_cache (bool): Whether to use key-value caching for faster generation.
            guidance_scale (float, optional): Classifier-free guidance weight at
                the first step. It decays linearly towards 0 over the sequence.
                Default is 3.0.
            top_p (float, optional): The cumulative probability threshold for
                nucleus sampling. If None, argmax selection is performed
                (deterministic generation). Otherwise, the smallest set of
                tokens with cumulative probability ≥ top_p is kept (stochastic
                generation).
            bounding_box_xyz (Optional[Tuple[float]], optional): The size of the
                bounding box for generation as (x, y, z) dimensions. Each value
                must be between 0 and 1.925. If None, uses default bounding box
                sizing.

        Returns:
            torch.Tensor: Generated token ids, shape (len(prompts), max_new_tokens).
        """
        embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
        output_ids = []
        batch_size, input_seq_len, dim = embed.shape
        max_seq_len = input_seq_len + self.max_new_tokens
        # Pre-allocated sequence: BOS prefix, then one slot per generated token.
        embed_buffer = torch.zeros(
            (batch_size, max_seq_len, dim), dtype=embed.dtype, device=embed.device
        )
        embed_buffer[:, :input_seq_len, :].copy_(embed)
        cond_len = cond.shape[1]
        kv_cache = None
        if use_kv_cache:
            kv_cache = self.gpt_model.init_kv_cache(
                batch_size,
                cond_len,
                self.max_new_tokens + 1,  # +1 for the BOS token
                torch.bfloat16,
                embed.device,
            )
        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            for i in tqdm(range(self.max_new_tokens), desc="generating"):
                curr_pos_id = torch.tensor([i], dtype=torch.long, device=embed.device)
                logits = self.gpt_model(
                    embed_buffer,
                    cond,
                    kv_cache=kv_cache,
                    curr_pos_id=curr_pos_id if use_kv_cache else None,
                    decode=(i > 0) if use_kv_cache else False,
                )
                # With a KV cache, output position 0 is the current step.
                # Without one, the whole buffer is re-run and position i is read.
                if use_kv_cache:
                    logits = logits[:, 0, ...]
                else:
                    logits = logits[:, i, ...]
                logits = logits[..., self.min_id : self.max_id]
                if guidance_scale > 0.0:
                    logits, uncond_logits = logits.float().chunk(2, dim=0)
                    gamma = (
                        guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
                    )
                    logits = (1 + gamma) * logits - gamma * uncond_logits
                next_id = process_logits(
                    logits,
                    top_p=top_p,
                )
                output_ids.append(next_id)
                next_embed = self.gpt_model.encode_token(next_id)
                if guidance_scale > 0.0:
                    next_embed = torch.cat([next_embed, next_embed], dim=0)
                embed_buffer[:, i + input_seq_len, :].copy_(next_embed.squeeze(1))
        return torch.cat(output_ids, dim=1)

    @torch.inference_mode()
    def run_shape_decode(
        self,
        output_ids: torch.Tensor,
        resolution_base: float = 8.0,
        chunk_size: int = 100_000,
    ):
        """
        Decodes the shape from the given output IDs and extracts the geometry.

        The first ``num_encoder_latents`` ids of each row are clamped to
        ``[0, num_codes - 1]``. This is done out of place, so ``output_ids``
        itself is not modified. The ids are then decoded to latents, and a
        mesh is extracted.

        Args:
            output_ids (torch.Tensor): The tensor containing the output IDs.
            resolution_base (float, optional): The base resolution for geometry
                extraction. Defaults to 8.0.
            chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.

        Returns:
            The first element returned by ``shape_model.extract_geometry``
            (the mesh vertices and faces).
        """
        num_latents = self.shape_model.cfg.num_encoder_latents
        shape_ids = (
            output_ids[:, :num_latents, ...]
            .clamp(0, self.shape_model.cfg.num_codes - 1)
            .reshape(-1, num_latents)
        )
        latents = self.shape_model.decode_indices(shape_ids)
        mesh_v_f, _ = self.shape_model.extract_geometry(
            latents,
            resolution_base=resolution_base,
            chunk_size=chunk_size,
            use_warp=True,
        )
        return mesh_v_f

    @torch.inference_mode()
    def t2s(
        self,
        prompts: list[str],
        use_kv_cache: bool,
        guidance_scale: float = 3.0,
        resolution_base: float = 8.0,
        chunk_size: int = 100_000,
        top_p: Optional[float] = None,
        bounding_box_xyz: Optional[Tuple[float]] = None,
    ):
        """
        Generates a 3D mesh from text prompts using a GPT model and shape decoder.

        Args:
            prompts (list[str]): A list of text prompts to guide the generation.
            use_kv_cache (bool): Whether to use key-value caching for the GPT model.
            guidance_scale (float, optional): The scale of guidance for the GPT model. Default is 3.0.
            resolution_base (float, optional): The base resolution for the shape decoder. Default is 8.0.
            chunk_size (int, optional): The chunk size for processing the shape decoding. Default is 100,000.
            top_p (float, optional): The cumulative probability threshold for
                nucleus sampling. If None, argmax selection is performed
                (deterministic generation). Otherwise, the smallest set of
                tokens with cumulative probability ≥ top_p is kept (stochastic
                generation).
            bounding_box_xyz (Tuple[float] | None, optional): The size of the
                bounding box for the generated mesh as (x, y, z) dimensions.
                Each value must be between 0 and 1.925. If None, uses default
                bounding box sizing.

        Returns:
            mesh_v_f: The generated 3D mesh vertices and faces.
        """
        output_ids = self.run_gpt(
            prompts, use_kv_cache, guidance_scale, top_p, bounding_box_xyz
        )
        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            mesh_v_f = self.run_shape_decode(output_ids, resolution_base, chunk_size)
        return mesh_v_f
class EngineFast(Engine):
def __init__(
    self,
    config_path: str,
    gpt_ckpt_path: str,
    shape_ckpt_path: str,
    save_gpt_ckpt_path: str,
    device: torch.device,
    mode: str
):
    """
    Build a CUDA-only engine that replays a captured CUDA graph for decoding.

    Args:
        config_path (str): Path to the configuration file.
        gpt_ckpt_path (str): Path to the GPT checkpoint file (forwarded to ``Engine``).
        shape_ckpt_path (str): Path to the shape checkpoint file.
        save_gpt_ckpt_path (str): GPT checkpoint path forwarded to ``Engine``.
        device (torch.device): Must be a CUDA device.
        mode (str): Forwarded to ``Engine``.

    Raises:
        AssertionError: If ``device.type`` is not ``"cuda"``. This is checked
            before any model is built.

    No graph is captured here. The static buffers start as empty placeholders.
    """
    assert (
        device.type == "cuda"
    ), "EngineFast is only supported on cuda devices, please use Engine on non-cuda devices"
    super().__init__(
        config_path=config_path,
        gpt_ckpt_path=gpt_ckpt_path,
        shape_ckpt_path=shape_ckpt_path,
        save_gpt_ckpt_path=save_gpt_ckpt_path,
        device=device,
        mode=mode,
    )

    # CUDA-graph state. Capture happens in `_warmup_and_capture_graph`, which
    # is not called from here.
    self.graph = torch.cuda.CUDAGraph()
    self.embed_buffer, self.cond_buffer, self.logits_buffer = (
        torch.Tensor(),
        torch.Tensor(),
        torch.Tensor(),
    )
    self.curr_pos_id = torch.tensor([0], dtype=torch.long, device=self.device)
    self.kv_cache: list[Cache] = []
def _warmup_and_capture_graph(self):
    """
    Allocate the static buffers, warm the model up, and capture one decode
    step into ``self.graph``.

    Steps, in order:
      1. Build BOS embeddings and conditions for the fixed prompt "A cube"
         with ``guidance_scale=3.0``. The buffers therefore hold the
         [conditional; unconditional] batch.
      2. Allocate ``embed_buffer`` (the BOS prefix followed by zeros for
         ``max_new_tokens`` positions) and ``cond_buffer`` (a copy of the
         conditions).
      3. Create the model's KV cache in bfloat16.
      4. Under bf16 autocast, run one eager prefill at position 0, then nine
         more forward passes at positions 1..9.
      5. Capture a bf16-autocast forward with ``decode=True`` on a side
         stream. ``logits_buffer`` becomes the graph's output tensor.

    NOTE(review): the warm-up passes use ``decode=False`` while the captured
    step uses ``decode=True``. The original had ``decode=True`` commented out
    in the warm-up. PyTorch recommends warming up the same code path that is
    captured — confirm this is intended.
    """
    bos_embed, cond = self.prepare_inputs(["A cube"], guidance_scale=3.0)
    batch, prefix_len, width = bos_embed.shape

    self.embed_buffer = torch.zeros(
        (batch, prefix_len + self.max_new_tokens, width),
        dtype=bos_embed.dtype,
        device=self.device,
    )
    self.embed_buffer[:, :prefix_len, :].copy_(bos_embed)
    self.cond_buffer = torch.empty_like(cond)
    self.cond_buffer.copy_(cond)

    self.kv_cache = self.gpt_model.init_kv_cache(
        batch,
        self.cond_buffer.shape[1],
        self.max_new_tokens + 1,  # one extra slot for the BOS token
        torch.bfloat16,
        self.device,
    )

    # The same buffer objects are fed to every call below; only their
    # contents (and curr_pos_id, updated in place) change between calls.
    static_inputs = dict(
        embed=self.embed_buffer,
        cond=self.cond_buffer,
        kv_cache=self.kv_cache,
        curr_pos_id=self.curr_pos_id,
    )

    warmup_steps = 10
    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        self._set_curr_pos_id(0)
        self._prefill_and_return_logits()
        for pos in range(1, warmup_steps):
            self._set_curr_pos_id(pos)
            self.logits_buffer = self.gpt_model(**static_inputs, decode=False)

    capture_stream = torch.cuda.Stream(device=self.device)
    with torch.cuda.graph(self.graph, stream=capture_stream):
        with torch.autocast(self.device.type, dtype=torch.bfloat16):
            self.logits_buffer = self.gpt_model(**static_inputs, decode=True)
def _reset_kv_cache(self):
    """
    Zero every key and value tensor in ``self.kv_cache``, in place.

    Only the contents are cleared. The tensor objects themselves, and any
    references to them, stay the same.
    """
    for layer_cache in self.kv_cache:
        for state in (layer_cache.key_states, layer_cache.value_states):
            state.zero_()
def _prefill_and_return_logits(self) -> torch.Tensor:
    """
    Clear the KV cache, then run one eager (non-graph) prefill pass.

    The GPT is called with the static ``embed_buffer``, ``cond_buffer``,
    ``kv_cache`` and ``curr_pos_id``, and with ``decode=False``.

    Returns:
        torch.Tensor: The model output at index 0 of dim 1.
    """
    self._reset_kv_cache()
    full_logits = self.gpt_model(
        embed=self.embed_buffer,
        cond=self.cond_buffer,
        kv_cache=self.kv_cache,
        curr_pos_id=self.curr_pos_id,
        decode=False,
    )
    first_position = full_logits[:, 0]
    return first_position
def _set_curr_pos_id(self, pos: int):
    """
    Overwrite ``self.curr_pos_id`` with ``pos``, in place.

    The tensor object is reused rather than replaced, so anything holding a
    reference to it sees the new position.

    Args:
        pos (int): The position ID to store.
    """
    self.curr_pos_id.fill_(pos)
def run_gpt(
    self,
    prompts: list[str],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    top_p: Optional[float] = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
):
    """
    Sample shape token ids for a single prompt, replaying the captured CUDA
    graph for every step after the first.

    The first step is an eager prefill (``_prefill_and_return_logits``). Each
    later step updates ``curr_pos_id`` in place, replays ``self.graph``, and
    reads the result from ``self.logits_buffer``. ``__init__`` does not
    capture the graph. If the static buffers are still the empty placeholders
    it created, ``_warmup_and_capture_graph`` is run first.

    Args:
        prompts (list[str]): Exactly one prompt.
        use_kv_cache (bool): Ignored; this path always uses the KV cache.
        guidance_scale (float, optional): Classifier-free guidance weight. Must
            be > 0, because the static buffers hold the conditional and
            unconditional batch together. The weight is used as-is for the
            first token and decays linearly afterwards. Default is 3.0.
        top_p (float, optional): The cumulative probability threshold for
            nucleus sampling. If None, argmax selection is performed.
            Otherwise, the smallest set of tokens with cumulative
            probability ≥ top_p is kept.
        bounding_box_xyz (Tuple[float] | None, optional): The size of the
            bounding box for the generated mesh as (x, y, z) dimensions. Each
            value must be between 0 and 1.925. If None, uses default bounding
            box sizing.

    Returns:
        torch.Tensor: int32 tensor of generated ids, shape (1, max_new_tokens).

    Raises:
        AssertionError: If more than one prompt is given, if ``guidance_scale``
            is not > 0, or if the conditions do not match the captured buffer
            shape.
    """
    # Validate before doing any (expensive) text encoding.
    assert len(prompts) == 1, "batch size > 1 not support for EngineFast"
    assert guidance_scale > 0.0, (
        "EngineFast requires guidance_scale > 0 (its CUDA-graph buffers hold "
        "the conditional and unconditional batch together)"
    )
    # The graph is not captured in __init__; do it lazily on first use, while
    # the buffers are still empty placeholders.
    if self.embed_buffer.numel() == 0:
        self._warmup_and_capture_graph()

    embed, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
    batch_size, input_seq_len, _ = embed.shape

    # Refill the static buffers in place so the captured graph sees the new inputs.
    self.embed_buffer.zero_()
    self.embed_buffer[:, :input_seq_len, :].copy_(embed)
    assert self.cond_buffer.shape == cond.shape
    self.cond_buffer.copy_(cond)

    # One output row per prompt: dim 0 of the buffers is [conditional; unconditional].
    output_ids = torch.zeros(
        (batch_size // 2, self.max_new_tokens), dtype=torch.int, device=self.device
    )

    def consume(step_logits, gamma, pos):
        """Apply guidance with weight `gamma`, sample token `pos`, and write its embedding back."""
        step_logits = step_logits[..., self.min_id : self.max_id]
        cond_logits, uncond_logits = step_logits.float().chunk(2, dim=0)
        mixed = (1 + gamma) * cond_logits - gamma * uncond_logits
        next_id = process_logits(mixed, top_p=top_p)
        output_ids[:, pos] = next_id.squeeze()
        # Duplicate the embedding for the conditional and unconditional rows.
        next_embed = self.gpt_model.encode_token(next_id).repeat(2, 1, 1)
        self.embed_buffer[:, pos + input_seq_len, :].copy_(next_embed.squeeze(1))

    with torch.autocast(self.device.type, dtype=torch.bfloat16):
        self._set_curr_pos_id(0)
        consume(self._prefill_and_return_logits(), guidance_scale, 0)
        for i in tqdm(range(1, self.max_new_tokens), desc="generating"):
            self._set_curr_pos_id(i)
            self.graph.replay()
            gamma = guidance_scale * (self.max_new_tokens - i) / self.max_new_tokens
            consume(self.logits_buffer[:, 0, ...], gamma, i)
    return output_ids
def pad_id_and_attn(self, inputs_ids, attention_mask):
    """
    Prepend a BOS column to ``inputs_ids`` and two columns of ones to
    ``attention_mask``.

    Args:
        inputs_ids (torch.Tensor): Token ids with batch on dim 0. Presumably
            2-D (batch, seq) — TODO confirm.
        attention_mask (torch.Tensor): Mask concatenated along dim 1.

    Returns:
        tuple: ``(inputs_ids, attention_mask)``.
            - The ids gain one leading column filled with
              ``gpt_model.shape_bos_id``.
            - The mask gains two leading columns of ones, built with the
              dtype of ``inputs_ids``.

    NOTE(review): the ids grow by one column but the mask grows by two. The
    commented-out variant in the original also appended an EOS/pad id, which
    kept the two aligned. Confirm that callers expect this length mismatch.
    """
    ones_col = torch.ones_like(inputs_ids[:, [0]])  # (batch, 1), dtype/device of the ids
    bos_col = ones_col * self.gpt_model.shape_bos_id
    padded_ids = torch.cat((bos_col, inputs_ids), dim=1)
    padded_mask = torch.cat((ones_col, ones_col, attention_mask), dim=1)
    return padded_ids, padded_mask
def precompute_freqs_cis_position(self, b, x_l, y_l, z_l, device):
    """
    Build rotary-embedding frequency tables for the x, y and z axes.

    For each axis length in ``(x_l, y_l, z_l)``:
      - position ids ``0 .. length-1`` are created and expanded to
        ``(b, length)``;
      - they are passed to ``precompute_freqs_cis`` with
        ``dim = n_embd // n_head * 4`` and ``theta = rope_theta``, both taken
        from ``self.gpt_model.cfg``.

    The three axes are processed identically; only their lengths differ.

    Args:
        b (int): Batch size the position ids are expanded to.
        x_l (int): Number of positions along x.
        y_l (int): Number of positions along y.
        z_l (int): Number of positions along z.
        device: Device on which the position ids are created.

    Returns:
        tuple: ``(x_freqs_cis, y_freqs_cis, z_freqs_cis)``, i.e. whatever
        ``precompute_freqs_cis`` returns for each axis.
    """
    cfg = self.gpt_model.cfg
    rope_dim = cfg.n_embd // cfg.n_head * 4

    def axis_freqs(length):
        """Frequency table for one axis with `length` positions."""
        pos_ids = torch.arange(length, dtype=torch.long, device=device)
        return precompute_freqs_cis(
            dim=rope_dim,
            t=pos_ids.unsqueeze(0).expand(b, -1),
            theta=cfg.rope_theta,
        )

    x_freqs_cis = axis_freqs(x_l)
    y_freqs_cis = axis_freqs(y_l)
    z_freqs_cis = axis_freqs(z_l)
    return x_freqs_cis, y_freqs_cis, z_freqs_cis
def fwd_gpt(
self,
prompts: list[str],
inputs_ids: torch.Tensor,
latent: list[torch.Tensor],
use_kv_cache: bool,
guidance_scale: float = 3.0,
top_p: Optional[float] = None,
bounding_box_xyz: Optional[Tuple[float]] = None,
strategy: Optional[int] = None,
mode: str = 'train'
):
"""
Teacher-forced forward pass of the GPT on structured token ids.

Nothing is sampled. The whole (masked) sequence is embedded and run through
``self.gpt_model`` in a single eager call, and the logits are returned.

Args:
    prompts: Not used by this method.
    inputs_ids (torch.Tensor): Integer ids indexed as ``[batch, seq, field]``.
        Fields read here: ``0`` rotation id, ``-6`` "dat" id, and
        ``-5``/``-4``/``-3`` x/y/z ids. A value of ``-1`` marks padding.
        Modified in place (padding positions are rewritten, see below).
    latent: Not used by this method.
    use_kv_cache: Not used (the model is called with ``kv_cache=None``).
    guidance_scale: Not used (the guidance code is commented out).
    top_p: Not used.
    bounding_box_xyz: Passed to ``prepare_conditions_with_bboxs``. That
        result is overwritten before use, so it does not affect the output.
    strategy (int | None): Masking strategy. Only honoured when
        ``mode == 'test'``; otherwise one of 0..3 is drawn with
        ``torch.randint``.
    mode (str): ``'test'`` to use the given ``strategy``.

Padding:
    ``cut_idx`` is, per row, the first position whose field ``-3`` is ``-1``.
    Positions at or after it get the ids ``dat_num``, ``rot_num``, ``x_num``,
    ``y_num`` and ``z_num``. These ids are written back into ``inputs_ids``.

Masking (applied only to positions before ``cut_idx``; mask id = ``*_num + 1``):
    0: x, y and z ids are masked.
    1: as 0, plus the rotation id.
    2: as 0, plus the rotation id at a random subset of positions.
    any other value (including ``None`` in test mode): x, y and z ids are
    masked at one shared random subset of positions.
    A random subset keeps the positions where ``torch.rand`` exceeds a
    threshold drawn once from U(0, 1).

Returns:
    tuple: ``(logits, inputs_ids, strategy, mask, cut_idx)``.
        - ``logits``: model outputs restricted to ``[min_id, max_id)`` on
          the last dim.
        - ``inputs_ids``: the padded but unmasked input.
        - ``mask``: the random boolean mask for strategy 2 and the fallback
          branch, or ``None`` for strategies 0 and 1.
        - ``cut_idx``: the per-row first padding position.
"""
#_, cond = self.prepare_inputs(prompts, guidance_scale, bounding_box_xyz)
#assert len(prompts) == 1, "batch size > 1 not support for EngineFast" #why?
#batch_size, input_seq_len, _ = embed.shape
# NOTE(review): this copy of the file has lost its indentation, so which of the
# statements below sit inside this no_grad block cannot be read from here. If it
# encloses the embedding lookups or the self.gpt_model(...) call, no gradients
# reach those modules during training — confirm the intended scope.
with torch.no_grad():
# -1 marks padding.
attention_mask = inputs_ids != -1
# Per row: first position whose field -3 (z id) is -1.
# NOTE(review): argmax returns 0 for a row that contains no -1, which would
# treat the whole row as padding — confirm every row is padded.
cut_idx = (attention_mask == False)[:, :, -3].int().argmax(dim=1)
#dat_id = inputs_ids[:,:,self.gpt_model.xyz:self.gpt_model.xyz+self.gpt_model.dat_num].argmax(-1)
dat_id = inputs_ids[:,:,-6].long()
# Positions >= cut_idx get the padding id dat_num.
dat_id = torch.where(torch.arange(dat_id.shape[1], device=dat_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.dat_num, dat_id)
inputs_embeds = self.gpt_model.dte(dat_id)
# x_id = inputs_ids[:,:,24:self.gpt_model.x+24].argmax(-1)
# y_id = inputs_ids[:,:,self.gpt_model.x:self.gpt_model.xy].argmax(-1)
# z_id = inputs_ids[:,:,self.gpt_model.xy:self.gpt_model.xyz].argmax(-1)
# coord_ids = torch.cat([x_id.unsqueeze(-1), y_id.unsqueeze(-1), z_id.unsqueeze(-1)], dim=-1)
# max_vals = torch.tensor([self.gpt_model.x_num - 1, self.gpt_model.y_num - 1, self.gpt_model.z_num - 1],
#                          dtype=torch.float32,
#                          device=coord_ids.device)
# normliz_coord = coord_ids.float() / max_vals.view(1, 1, 3) * 2 - 1 #
# pos_embeds = positional_encoding(normliz_coord, 128)
#embeds_from_id = self.gpt_model.encode_embed(inputs_ids[:, :, self.gpt_model.xyz:self.gpt_model.xyz + self.gpt_model.dat_num].float())
#embeds_from_id = self.gpt_model.encode_embed(inputs_ids[:, :, 24:self.gpt_model.xyz + self.gpt_model.dat_num].float())
#embeds_from_id = self.gpt_model.encode_embed(inputs_ids[:, :, 24:self.gpt_model.xyz].float())
# Rotation ids (field 0); padding positions get rot_num.
r_id = inputs_ids[:,:,0]
r_id = torch.where(torch.arange(r_id.shape[1], device=r_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.rot_num, r_id)
# Position ids (fields -5/-4/-3 = x/y/z); padding positions get x_num/y_num/z_num.
x_id = inputs_ids[:,:,-5]
y_id = inputs_ids[:,:,-4]
z_id = inputs_ids[:,:,-3]
x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.x_num, x_id)
y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.y_num, y_id)
z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] >= cut_idx[:,None], self.gpt_model.z_num, z_id)
# Write the padded (not yet masked) ids back into the caller's tensor.
inputs_ids[:, :, 0] = r_id.clone()
inputs_ids[:, :, -6] = dat_id.clone()
inputs_ids[:, :, -5] = x_id.clone()
inputs_ids[:, :, -4] = y_id.clone()
inputs_ids[:, :, -3] = z_id.clone()
# Masking: the mask id for each field is *_num + 1; only positions < cut_idx are masked.
# In training mode the strategy argument is ignored and one is drawn from 0..3.
strategy = strategy if mode=='test' else torch.randint(0, 4, (1,)).item()
if strategy == 0:
x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None], self.gpt_model.x_num+1, x_id)
y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] < cut_idx[:,None], self.gpt_model.y_num+1, y_id)
z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] < cut_idx[:,None], self.gpt_model.z_num+1, z_id)
mask = None
elif strategy == 1:
x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None], self.gpt_model.x_num+1, x_id)
y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] < cut_idx[:,None], self.gpt_model.y_num+1, y_id)
z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] < cut_idx[:,None], self.gpt_model.z_num+1, z_id)
r_id = torch.where(torch.arange(r_id.shape[1], device=r_id.device)[None,:] < cut_idx[:,None], self.gpt_model.rot_num+1, r_id)
mask = None
elif strategy == 2:
x_id = torch.where(torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None], self.gpt_model.x_num+1, x_id)
y_id = torch.where(torch.arange(y_id.shape[1], device=y_id.device)[None,:] < cut_idx[:,None], self.gpt_model.y_num+1, y_id)
z_id = torch.where(torch.arange(z_id.shape[1], device=z_id.device)[None,:] < cut_idx[:,None], self.gpt_model.z_num+1, z_id)
# Random subset: keep positions where rand exceeds a threshold drawn once from U(0, 1).
mask = (torch.arange(r_id.shape[1], device=r_id.device)[None,:] < cut_idx[:,None]) & (torch.rand(r_id.shape, device=r_id.device) > torch.empty(1, device=r_id.device).uniform_(0.0, 1.0).item())
r_id = torch.where(mask, self.gpt_model.rot_num+1, r_id)
else:
# Reached for strategy 3 and also for strategy=None in test mode.
mask = (torch.arange(x_id.shape[1], device=x_id.device)[None,:] < cut_idx[:,None]) & (torch.rand(x_id.shape, device=x_id.device) > torch.empty(1, device=r_id.device).uniform_(0.0, 1.).item())
x_id = torch.where(mask, self.gpt_model.x_num+1, x_id)
y_id = torch.where(mask, self.gpt_model.y_num+1, y_id)
z_id = torch.where(mask, self.gpt_model.z_num+1, z_id)
#print(strategy)
rembeds_from_id = self.gpt_model.rte(r_id)
xembeds_from_id = self.gpt_model.xte(x_id)
yembeds_from_id = self.gpt_model.yte(y_id)
zembeds_from_id = self.gpt_model.zte(z_id)
embeds_from_id = torch.stack([inputs_embeds.clone(), rembeds_from_id, yembeds_from_id, xembeds_from_id, zembeds_from_id], dim=2) # 5 embeddings per position, order dat, rot, y, x, z (presumably (b, seq, 5, dim))
#embeds_from_id = torch.stack([yembeds_from_id, xembeds_from_id, zembeds_from_id], dim=2)
embeds_from_id = embeds_from_id.view(xembeds_from_id.shape[0], xembeds_from_id.shape[1] * 5, xembeds_from_id.shape[2]) # interleaved into (b, seq * 5, dim)
#inputs_embeds = self.gpt_model.encode_token(latent)
#position embedding
#inputs_embeds = torch.cat([pos_embeds, inputs_embeds], dim=-1)
# NOTE(review): this result is overwritten a few lines below (inputs_embeds = bos_embed...)
# and never used. prepare_conditions_with_bboxs is also decorated with torch.inference_mode().
inputs_embeds = self.prepare_conditions_with_bboxs(inputs_embeds, bounding_box_xyz)
#add token number padding
#sequence_length = inputs_ids.shape[1]
#pad_sequence = torch.ones((inputs_ids.shape[0], sequence_length), dtype=torch.long, device=inputs_ids.device) * self.gpt_model.dat_num #self.gpt_model.padding_id
#pad_sequence_embed = self.gpt_model.encode_token(pad_sequence) #[b, 1536]
#!!!--------litte wrong
#embeds_from_id[~attention_mask[:,:,:inputs_embeds.shape[2]]] = pad_sequence_embed[~attention_mask[:,:,:inputs_embeds.shape[2]]]
# Prepend one BOS embedding per row.
place_holder = torch.ones_like(inputs_ids[:, 0, 0]).long() # shape (batch,)
bos_embed = self.gpt_model.encode_token(place_holder * self.gpt_model.shape_bos_id) # presumably (batch, dim)
embeds_from_id = torch.cat([bos_embed[:, None, :], embeds_from_id], dim=1)
# The condition passed to the model is just the BOS embedding.
inputs_embeds = bos_embed.unsqueeze(1)
#exchange
# ex = inputs_embeds.clone()
# inputs_embeds = self.prepare_conditions_with_bboxs(embeds_from_id, bounding_box_xyz)
# embeds_from_id = torch.cat([bos_embed[:, None, :], ex], dim=1)
# Single eager forward pass over the whole sequence (no KV cache).
prefill_logits = self.gpt_model(
embed=embeds_from_id, #_repeat,
cond=inputs_embeds, #_repeat,
kv_cache=None,
curr_pos_id=None,
decode=False,
)
logits = prefill_logits[..., self.min_id : self.max_id]
# if guidance_scale > 0.0:
#     logits, uncond_logits = logits.float().chunk(2, dim=0)
#     gamma = guidance_scale
#     # seq_len = logits.size(1)
#     # gamma_list = guidance_scale * (seq_len - torch.arange(seq_len)) / seq_len
#     # # shape: [seq_len]
#     logits = (1 + gamma) * logits - gamma * uncond_logits
return logits, inputs_ids, strategy, mask, cut_idx
def t2t(
    self,
    prompts: list[str],
    inputs_ids: list[torch.Tensor],
    latent: list[torch.Tensor],
    use_kv_cache: bool,
    guidance_scale: float = 3.0,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
    top_p: float = None,
    bounding_box_xyz: Optional[Tuple[float]] = None,
    strategy: int = None,
    mode: str = None
):
    """
    Run the GPT forward pass for a batch of prompts by delegating to ``self.fwd_gpt``.

    Args:
        prompts (list[str]): Text prompts guiding the generation.
        inputs_ids (list[torch.Tensor]): Token ids forwarded unchanged to ``fwd_gpt``.
        latent (list[torch.Tensor]): Latent inputs forwarded unchanged to ``fwd_gpt``.
        use_kv_cache (bool): Whether ``fwd_gpt`` should use key-value caching.
        guidance_scale (float, optional): Guidance scale forwarded to ``fwd_gpt``. Default 3.0.
        resolution_base (float, optional): Accepted for interface compatibility;
            not used by this method. Default 8.0.
        chunk_size (int, optional): Accepted for interface compatibility;
            not used by this method. Default 100_000.
        top_p (float, optional): Nucleus-sampling threshold forwarded to ``fwd_gpt``.
        bounding_box_xyz (Tuple[float] | None, optional): Bounding-box size forwarded
            to ``fwd_gpt``.
        strategy (int, optional): Forwarded unchanged to ``fwd_gpt``.
        mode (str, optional): Forwarded unchanged to ``fwd_gpt``.

    Returns:
        Whatever ``self.fwd_gpt`` returns, passed through untouched.
    """
    # fwd_gpt takes its arguments positionally; resolution_base and chunk_size
    # are intentionally not part of this list.
    forwarded_args = (
        prompts,
        inputs_ids,
        latent,
        use_kv_cache,
        guidance_scale,
        top_p,
        bounding_box_xyz,
        strategy,
        mode,
    )
    return self.fwd_gpt(*forwarded_args)
def configure_optimizers(
    self,
    train_config
):
    """
    Build an AdamW optimizer over the parameters of ``self.gpt_model``'s ``lm_head`` only.

    Only the submodule registered under the name ``lm_head`` is considered; every
    other parameter of the GPT model is left out of the optimizer (so it is not
    updated by it). The head's parameters are split into two groups:

    * weight-decayed (``train_config.weight_decay``): ``weight`` when the head is
      a ``torch.nn.Linear``;
    * not decayed (0.0): biases, ``LayerNorm``/``Embedding`` weights and any
      parameter whose name contains ``_norm.weight``.

    Parameters matching none of these rules are not added to the optimizer. If the
    model has no ``lm_head`` module, both groups are empty.

    Args:
        train_config: object exposing ``weight_decay``, ``learning_rate`` and ``betas``.

    Returns:
        torch.optim.AdamW: optimizer with exactly two parameter groups
        (decayed first, then non-decayed).
    """
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    head_name = 'lm_head'
    # Direct lookup by name instead of scanning every module and skipping all but one.
    head = dict(self.gpt_model.named_modules()).get(head_name)
    if head is not None:
        for pn, _ in head.named_parameters():
            fpn = f'{head_name}.{pn}'  # full param name as seen by gpt_model.named_parameters()
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(head, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(head, blacklist_weight_modules):
                no_decay.add(fpn)
            elif '_norm.weight' in pn:
                no_decay.add(fpn)
    # Map the collected names back to the actual tensors.
    param_dict = dict(self.gpt_model.named_parameters())
    optim_groups = [
        {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": train_config.weight_decay},
        {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
def configure_optimizers_lora(
    self,
    train_config
):
    """
    Build an AdamW optimizer over every trainable parameter of ``self.gpt_model``.

    Parameters are selected by ``requires_grad`` (e.g. only adapter weights when
    the base model is frozen) and placed in a single parameter group; no
    per-group weight decay is set, so AdamW's default weight decay applies.

    Args:
        train_config: object exposing ``learning_rate`` and ``betas``.

    Returns:
        torch.optim.AdamW: the configured optimizer.
    """
    trainable = [param for param in self.gpt_model.parameters() if param.requires_grad]
    return torch.optim.AdamW(
        trainable,
        lr=train_config.learning_rate,
        betas=train_config.betas,
    )
def configure_optimizers_lora_linear(
    self,
    train_config
):
    """
    Build an AdamW optimizer and cosine LR scheduler for adapter training plus a
    fixed set of directly-trained modules.

    Parameters of the modules named ``ldr_head``, ``ldr_proj``, ``dte``, ``xte``,
    ``yte``, ``zte`` and ``rte`` are split into two groups:

    * weight-decayed (``train_config.weight_decay``): ``weight`` of a
      ``torch.nn.Linear`` module;
    * not decayed (0.0): biases, ``LayerNorm``/``Embedding`` weights and any
      parameter whose name contains ``_norm.weight``.

    Every other parameter of ``self.gpt_model`` with ``requires_grad=True``
    (presumably the LoRA adapter weights) goes into a third group that uses
    AdamW's default weight decay. No tensor appears in more than one group.

    Args:
        train_config: object exposing ``weight_decay``, ``learning_rate``,
            ``betas`` and ``max_iters``.

    Returns:
        tuple: ``(optimizer, scheduler)`` -- a ``torch.optim.AdamW`` and a
        ``CosineAnnealingLR`` with ``T_max=train_config.max_iters`` annealing down
        to 1% of ``train_config.learning_rate``.
    """
    # Modules trained directly (outside of any adapter), matched exactly against
    # the names yielded by named_modules().
    # NOTE(review): if gpt_model is wrapped (e.g. by peft) these names may carry a
    # prefix and nothing will match -- confirm against the caller.
    # NOTE(review): their parameters are added regardless of requires_grad; they
    # only get updated if they are trainable.
    target_module_names = {'ldr_head', 'ldr_proj', 'dte', 'xte', 'yte', 'zte', 'rte'}
    decay = set()
    no_decay = set()
    whitelist_weight_modules = (torch.nn.Linear, )
    blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
    for mn, m in self.gpt_model.named_modules():
        # Bug fix: the previous `mn != 'ldr_head' or mn != 'ldr_proj' or ...` test
        # was always true (a name cannot equal every entry at once), so every
        # module was skipped and both decay groups were always empty.
        if mn not in target_module_names:
            continue
        for pn, _ in m.named_parameters():
            fpn = f'{mn}.{pn}'  # full param name; mn is never empty here
            if pn.endswith('bias'):
                no_decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                decay.add(fpn)
            elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                no_decay.add(fpn)
            elif '_norm.weight' in pn:
                no_decay.add(fpn)
    param_dict = dict(self.gpt_model.named_parameters())
    decay_params = [param_dict[pn] for pn in sorted(decay)]
    no_decay_params = [param_dict[pn] for pn in sorted(no_decay)]
    # torch.optim raises if a tensor appears in more than one parameter group,
    # so the catch-all trainable group skips anything already assigned above.
    assigned = {id(p) for p in decay_params + no_decay_params}
    other_trainable = [
        p for p in self.gpt_model.parameters()
        if p.requires_grad and id(p) not in assigned
    ]
    optim_groups = [
        {"params": decay_params, "weight_decay": train_config.weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
        {"params": other_trainable},
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=train_config.max_iters,
        eta_min=train_config.learning_rate * 0.01
    )
    return optimizer, scheduler
def configure_optimizers_scratch_linear(
    self,
    train_config
):
    """
    Build an AdamW optimizer and cosine LR scheduler over *all* parameters of
    ``self.gpt_model`` (training from scratch).

    Each parameter is classified by name and owning module:

    * names ending in ``bias`` -> no weight decay;
    * ``weight`` of a ``torch.nn.Linear`` -> ``train_config.weight_decay``;
    * ``weight`` of ``LayerNorm``/``Embedding`` -> no weight decay;
    * names containing ``_norm.weight`` -> no weight decay.

    Args:
        train_config: object exposing ``weight_decay``, ``learning_rate``,
            ``betas`` and ``max_iters``.

    Returns:
        tuple: ``(optimizer, scheduler)`` -- a ``torch.optim.AdamW`` with a decayed
        and a non-decayed group, and a ``CosineAnnealingLR`` with
        ``T_max=train_config.max_iters`` annealing to 1% of the learning rate.

    Raises:
        AssertionError: if a parameter is classified into both groups, or if some
            parameter of the model is not classified at all.
    """
    linear_types = (torch.nn.Linear, )
    undecayed_types = (torch.nn.LayerNorm, torch.nn.Embedding)
    decay, no_decay = set(), set()

    def bucket_for(owner, local_name):
        # Returns the set the parameter belongs in, or None if no rule matches.
        # Rule order matters: biases win over everything else.
        if local_name.endswith('bias'):
            return no_decay
        if local_name.endswith('weight'):
            if isinstance(owner, linear_types):
                return decay
            if isinstance(owner, undecayed_types):
                return no_decay
        if '_norm.weight' in local_name:
            return no_decay
        return None

    # named_parameters() recurses, so a tensor is seen once per ancestor module;
    # keying by the fully-qualified name lets each owner module classify it.
    for module_name, module in self.gpt_model.named_modules():
        for local_name, _ in module.named_parameters():
            target = bucket_for(module, local_name)
            if target is not None:
                target.add(f"{module_name}.{local_name}" if module_name else local_name)

    # Validate that every parameter was classified exactly once.
    by_name = dict(self.gpt_model.named_parameters())
    overlap = decay & no_decay
    missing = by_name.keys() - (decay | no_decay)
    assert not overlap, f"parameters {overlap} made it into both decay/no_decay sets!"
    assert not missing, f"parameters {missing} were not separated into either decay/no_decay set!"

    decayed_params = [by_name[name] for name in sorted(decay)]
    plain_params = [by_name[name] for name in sorted(no_decay)]
    optimizer = torch.optim.AdamW(
        [
            {"params": decayed_params, "weight_decay": train_config.weight_decay},
            {"params": plain_params, "weight_decay": 0.0},
        ],
        lr=train_config.learning_rate,
        betas=train_config.betas,
    )
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=train_config.max_iters,
        eta_min=train_config.learning_rate * 0.01,
    )
    return optimizer, scheduler