Image to video script: make deterministic via a random seed.
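In short, the script now seeds every random-number source the pipeline touches before sampling, so repeated runs with the same args.seed produce the same video. A minimal sketch of that pattern (the seed_everything helper name is ours; it assumes an argparse-provided args.seed and a CUDA device, as in the diff below):

import random
import numpy as np
import torch

def seed_everything(seed: int) -> torch.Generator:
    # Hypothetical helper mirroring what main() now does inline:
    # seed Python, NumPy, and PyTorch (CPU + current CUDA device).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Return a dedicated generator that can be handed to the sampler.
    return torch.Generator(device="cuda").manual_seed(seed)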
xora/examples/image_to_video.py CHANGED (+23 -10)
@@ -1,3 +1,4 @@
+import time
 import torch
 from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
 from xora.models.transformers.transformer3d import Transformer3DModel
@@ -14,6 +15,8 @@ import os
 import numpy as np
 import cv2
 from PIL import Image
+from tqdm import tqdm
+import random
 
 def load_vae(vae_dir):
     vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
@@ -65,9 +68,8 @@ def load_video_to_tensor_with_resize(video_path, target_height=512, target_width=768):
         frame_resized = center_crop_and_resize(frame_rgb, target_height, target_width)
         frames.append(frame_resized)
     cap.release()
-    video_np = np.array(frames)
+    video_np = (np.array(frames) / 127.5) - 1.0
     video_tensor = torch.tensor(video_np).permute(3, 0, 1, 2).float()
-    video_tensor = (video_tensor / 127.5) - 1.0
     return video_tensor
 
 def load_image_to_tensor_with_resize(image_path, target_height=512, target_width=768):
@@ -154,9 +156,13 @@ def main():
         'media_items': media_items,
     }
 
-
+    start_time = time.time()
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+    generator = torch.Generator(device="cuda").manual_seed(args.seed)
 
-    # Run the pipeline
     images = pipeline(
         num_inference_steps=args.num_inference_steps,
         num_images_per_prompt=args.num_images_per_prompt,
@@ -173,20 +179,27 @@ def main():
         vae_per_channel_normalize=True,
         conditioning_method=ConditioningMethod.FIRST_FRAME
     ).images
-
     # Save output video
+    def get_unique_filename(base, ext, dir='.', index_range=1000):
+        for i in range(index_range):
+            filename = os.path.join(dir, f"{base}_{i}{ext}")
+            if not os.path.exists(filename):
+                return filename
+        raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")
+
+
     for i in range(images.shape[0]):
         video_np = images.squeeze(0).permute(1, 2, 3, 0).cpu().float().numpy()
         video_np = (video_np * 255).astype(np.uint8)
         fps = args.frame_rate
        height, width = video_np.shape[1:3]
-
-
-
-
-                              (width, height))
+        output_filename = get_unique_filename(f"video_output_{i}", ".mp4", ".")
+
+        out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
+
        for frame in video_np[..., ::-1]:
            out.write(frame)
+
        out.release()
 
 
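For reference, the new saving path can be exercised on its own roughly as below. This is only a sketch: the noise frames and the 25 fps value are placeholders, while get_unique_filename and the channel flip ([..., ::-1]) mirror the code added in the diff.

import os
import cv2
import numpy as np

def get_unique_filename(base, ext, dir='.', index_range=1000):
    # Same helper as in the diff: pick the first base_<i><ext> that does not exist yet.
    for i in range(index_range):
        filename = os.path.join(dir, f"{base}_{i}{ext}")
        if not os.path.exists(filename):
            return filename
    raise FileExistsError(f"Could not find a unique filename after {index_range} attempts.")

# Placeholder clip: 16 random RGB frames of 512x768, shaped (frames, height, width, channels).
video_np = np.random.randint(0, 256, size=(16, 512, 768, 3), dtype=np.uint8)
height, width = video_np.shape[1:3]

out = cv2.VideoWriter(get_unique_filename("video_output_0", ".mp4", "."),
                      cv2.VideoWriter_fourcc(*"mp4v"), 25, (width, height))
for frame in video_np[..., ::-1]:  # OpenCV expects BGR, so flip the channel order
    out.write(frame)
out.release()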