Spaces:
Paused
Paused
added cuda as optional
Browse files- inference.py +16 -6
- xora/pipelines/pipeline_xora_video.py +1 -1
inference.py
CHANGED
|
@@ -55,7 +55,9 @@ def load_vae(vae_dir):
|
|
| 55 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
| 56 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
| 57 |
vae.load_state_dict(vae_state_dict)
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
def load_unet(unet_dir):
|
|
@@ -65,7 +67,9 @@ def load_unet(unet_dir):
|
|
| 65 |
transformer = Transformer3DModel.from_config(transformer_config)
|
| 66 |
unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
|
| 67 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
| 70 |
|
| 71 |
def load_scheduler(scheduler_dir):
|
|
@@ -254,7 +258,9 @@ def main():
|
|
| 254 |
patchifier = SymmetricPatchifier(patch_size=1)
|
| 255 |
text_encoder = T5EncoderModel.from_pretrained(
|
| 256 |
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
|
| 257 |
-
)
|
|
|
|
|
|
|
| 258 |
tokenizer = T5Tokenizer.from_pretrained(
|
| 259 |
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
|
| 260 |
)
|
|
@@ -272,7 +278,9 @@ def main():
|
|
| 272 |
"vae": vae,
|
| 273 |
}
|
| 274 |
|
| 275 |
-
pipeline = XoraVideoPipeline(**submodel_dict)
|
|
|
|
|
|
|
| 276 |
|
| 277 |
# Prepare input for the pipeline
|
| 278 |
sample = {
|
|
@@ -286,8 +294,10 @@ def main():
|
|
| 286 |
random.seed(args.seed)
|
| 287 |
np.random.seed(args.seed)
|
| 288 |
torch.manual_seed(args.seed)
|
| 289 |
-
torch.cuda.
|
| 290 |
-
|
|
|
|
|
|
|
| 291 |
|
| 292 |
images = pipeline(
|
| 293 |
num_inference_steps=args.num_inference_steps,
|
|
|
|
| 55 |
vae = CausalVideoAutoencoder.from_config(vae_config)
|
| 56 |
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
|
| 57 |
vae.load_state_dict(vae_state_dict)
|
| 58 |
+
if torch.cuda.is_available():
|
| 59 |
+
vae = vae.cuda()
|
| 60 |
+
return vae.to(torch.bfloat16)
|
| 61 |
|
| 62 |
|
| 63 |
def load_unet(unet_dir):
|
|
|
|
| 67 |
transformer = Transformer3DModel.from_config(transformer_config)
|
| 68 |
unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
|
| 69 |
transformer.load_state_dict(unet_state_dict, strict=True)
|
| 70 |
+
if torch.cuda.is_available():
|
| 71 |
+
transformer = transformer.cuda()
|
| 72 |
+
return transformer
|
| 73 |
|
| 74 |
|
| 75 |
def load_scheduler(scheduler_dir):
|
|
|
|
| 258 |
patchifier = SymmetricPatchifier(patch_size=1)
|
| 259 |
text_encoder = T5EncoderModel.from_pretrained(
|
| 260 |
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
|
| 261 |
+
)
|
| 262 |
+
if torch.cuda.is_available():
|
| 263 |
+
text_encoder = text_encoder.to("cuda")
|
| 264 |
tokenizer = T5Tokenizer.from_pretrained(
|
| 265 |
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
|
| 266 |
)
|
|
|
|
| 278 |
"vae": vae,
|
| 279 |
}
|
| 280 |
|
| 281 |
+
pipeline = XoraVideoPipeline(**submodel_dict)
|
| 282 |
+
if torch.cuda.is_available():
|
| 283 |
+
pipeline = pipeline.to("cuda")
|
| 284 |
|
| 285 |
# Prepare input for the pipeline
|
| 286 |
sample = {
|
|
|
|
| 294 |
random.seed(args.seed)
|
| 295 |
np.random.seed(args.seed)
|
| 296 |
torch.manual_seed(args.seed)
|
| 297 |
+
if torch.cuda.is_available():
|
| 298 |
+
torch.cuda.manual_seed(args.seed)
|
| 299 |
+
|
| 300 |
+
generator = torch.Generator(device="cuda" if torch.cuda.is_available() else 'cpu').manual_seed(args.seed)
|
| 301 |
|
| 302 |
images = pipeline(
|
| 303 |
num_inference_steps=args.num_inference_steps,
|
xora/pipelines/pipeline_xora_video.py
CHANGED
|
@@ -1010,7 +1010,7 @@ class XoraVideoPipeline(DiffusionPipeline):
|
|
| 1010 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
| 1011 |
# Choose the appropriate context manager based on `mixed_precision`
|
| 1012 |
if mixed_precision:
|
| 1013 |
-
context_manager = torch.autocast("cuda", dtype=torch.bfloat16)
|
| 1014 |
else:
|
| 1015 |
context_manager = nullcontext() # Dummy context manager
|
| 1016 |
|
|
|
|
| 1010 |
current_timestep = current_timestep * (1 - conditioning_mask)
|
| 1011 |
# Choose the appropriate context manager based on `mixed_precision`
|
| 1012 |
if mixed_precision:
|
| 1013 |
+
context_manager = torch.autocast("cuda" if torch.cuda.is_available() else 'cpu', dtype=torch.bfloat16)
|
| 1014 |
else:
|
| 1015 |
context_manager = nullcontext() # Dummy context manager
|
| 1016 |
|