Spaces: Running on Zero

Update

app.py CHANGED
@@ -10,9 +10,7 @@ import numpy as np
 import spaces
 import torch
 import torchvision
-from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import snapshot_download
-from packaging import version
 from PIL import Image
 from scipy.interpolate import PchipInterpolator
 
@@ -39,55 +37,40 @@ snapshot_download(
 )
 
 
-
-
+model_id = "checkpoints/framer_512x320"
+device = "cuda"
+dtype = torch.float16
 
-
+OUTPUT_DIR = "gradio_demo/outputs"
+HEIGHT = 320
+WIDTH = 512
+MODEL_LENGTH = 14
+USE_SIFT = False
 
-parser.add_argument("--min_guidance_scale", type=float, default=1.0)
-parser.add_argument("--max_guidance_scale", type=float, default=3.0)
-parser.add_argument("--middle_max_guidance", type=int, default=0, choices=[0, 1])
-parser.add_argument("--with_control", type=int, default=1, choices=[0, 1])
 
-
-
-
-
-
-
-
-
-parser.add_argument(
-    "--model",
-    type=str,
-    default="checkpoints/framer_512x320",
-    help="Path to model.",
-)
-
-parser.add_argument("--output_dir", type=str, default="gradio_demo/outputs", help="Path to the output video.")
-
-parser.add_argument("--seed", type=int, default=42, help="random seed.")
-
-parser.add_argument("--noise_aug", type=float, default=0.02)
-
-parser.add_argument("--num_frames", type=int, default=14)
-parser.add_argument("--frame_interval", type=int, default=2)
-
-parser.add_argument("--width", type=int, default=512)
-parser.add_argument("--height", type=int, default=320)
-
-parser.add_argument(
-    "--num_workers",
-    type=int,
-    default=0,
-    help=(
-        "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
-    ),
-)
-
-args = parser.parse_args()
+unet = UNetSpatioTemporalConditionModel.from_pretrained(
+    os.path.join(model_id, "unet"),
+    torch_dtype=torch.float16,
+    low_cpu_mem_usage=True,
+    custom_resume=True,
+)
+unet = unet.to(device, dtype)
 
-
+controlnet = ControlNetSVDModel.from_pretrained(
+    os.path.join(model_id, "controlnet"),
+)
+controlnet = controlnet.to(device, dtype)
+
+pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
+    "checkpoints/stable-video-diffusion-img2vid-xt",
+    unet=unet,
+    controlnet=controlnet,
+    low_cpu_mem_usage=False,
+    torch_dtype=torch.float16,
+    variant="fp16",
+    local_files_only=True,
+)
+pipe.to(device)
 
 
 def interpolate_trajectory(points, n_points):
@@ -164,7 +147,7 @@ def get_vis_image(
         vis_img = new_img.copy()
         # ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
 
-        if idxx >=
+        if idxx >= num_frames:
            break
 
        # for cc, (mask, trajectory, radius) in enumerate(zip(mask_list, trajectory_list, radius_list)):
@@ -363,187 +346,6 @@ def validate_and_convert_image(image, target_size=(512, 512)):
     return image
 
 
-class Drag:
-
-    @spaces.GPU
-    def __init__(self, device, args, height, width, model_length, dtype=torch.float16, use_sift=False):
-        self.device = device
-        self.dtype = dtype
-
-        unet = UNetSpatioTemporalConditionModel.from_pretrained(
-            os.path.join(args.model, "unet"),
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True,
-            custom_resume=True,
-        )
-        unet = unet.to(device, dtype)
-
-        controlnet = ControlNetSVDModel.from_pretrained(
-            os.path.join(args.model, "controlnet"),
-        )
-        controlnet = controlnet.to(device, dtype)
-
-        if is_xformers_available():
-            import xformers
-
-            xformers_version = version.parse(xformers.__version__)
-            unet.enable_xformers_memory_efficient_attention()
-            # controlnet.enable_xformers_memory_efficient_attention()
-        else:
-            raise ValueError("xformers is not available. Make sure it is installed correctly")
-
-        pipe = StableVideoDiffusionInterpControlPipeline.from_pretrained(
-            "checkpoints/stable-video-diffusion-img2vid-xt",
-            unet=unet,
-            controlnet=controlnet,
-            low_cpu_mem_usage=False,
-            torch_dtype=torch.float16,
-            variant="fp16",
-            local_files_only=True,
-        )
-        pipe.to(device)
-
-        self.pipeline = pipe
-        # self.pipeline.enable_model_cpu_offload()
-
-        self.height = height
-        self.width = width
-        self.args = args
-        self.model_length = model_length
-        self.use_sift = use_sift
-
-    @spaces.GPU
-    def run(self, first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
-        original_width, original_height = 512, 320  # TODO
-
-        # load_image
-        image = Image.open(first_frame_path).convert("RGB")
-        width, height = image.size
-        image = image.resize((self.width, self.height))
-
-        image_end = Image.open(last_frame_path).convert("RGB")
-        image_end = image_end.resize((self.width, self.height))
-
-        input_all_points = tracking_points
-
-        sift_track_update = False
-        anchor_points_flag = None
-
-        if (len(input_all_points) == 0) and self.use_sift:
-            sift_track_update = True
-            controlnet_cond_scale = 0.5
-
-            from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
-            from models_diffusers.sift_match import sift_match
-
-            output_file_sift = os.path.join(args.output_dir, "sift.png")
-
-            # (f, topk, 2), f=2 (before interpolation)
-            pred_tracks = sift_match(
-                image,
-                image_end,
-                thr=0.5,
-                topk=5,
-                method="random",
-                output_path=output_file_sift,
-            )
-
-            if pred_tracks is not None:
-                # interpolate the tracks, following draganything gradio demo
-                pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=self.model_length)
-
-                anchor_points_flag = torch.zeros((self.model_length, pred_tracks.shape[1])).to(pred_tracks.device)
-                anchor_points_flag[0] = 1
-                anchor_points_flag[-1] = 1
-
-            pred_tracks = pred_tracks.permute(1, 0, 2)  # (num_points, num_frames, 2)
-
-        else:
-
-            resized_all_points = [
-                tuple(
-                    [
-                        tuple([int(e1[0] * self.width / original_width), int(e1[1] * self.height / original_height)])
-                        for e1 in e
-                    ]
-                )
-                for e in input_all_points
-            ]
-
-            # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
-            # in image w & h scale
-
-            for idx, splited_track in enumerate(resized_all_points):
-                if len(splited_track) == 0:
-                    warnings.warn("running without point trajectory control")
-                    continue
-
-                if len(splited_track) == 1:  # stationary point
-                    displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
-                    splited_track = tuple([splited_track[0], displacement_point])
-                # interpolate the track
-                splited_track = interpolate_trajectory(splited_track, self.model_length)
-                splited_track = splited_track[: self.model_length]
-                resized_all_points[idx] = splited_track
-
-            pred_tracks = torch.tensor(resized_all_points)  # (num_points, num_frames, 2)
-
-        vis_images = get_vis_image(
-            target_size=(self.args.height, self.args.width),
-            points=pred_tracks,
-            num_frames=self.model_length,
-        )
-
-        if len(pred_tracks.shape) != 3:
-            print("pred_tracks.shape", pred_tracks.shape)
-            with_control = False
-            controlnet_cond_scale = 0.0
-        else:
-            with_control = True
-            pred_tracks = pred_tracks.permute(1, 0, 2).to(self.device, self.dtype)  # (num_frames, num_points, 2)
-
-        point_embedding = None
-        video_frames = self.pipeline(
-            image,
-            image_end,
-            # trajectory control
-            with_control=with_control,
-            point_tracks=pred_tracks,
-            point_embedding=point_embedding,
-            with_id_feature=False,
-            controlnet_cond_scale=controlnet_cond_scale,
-            # others
-            num_frames=14,
-            width=width,
-            height=height,
-            # decode_chunk_size=8,
-            # generator=generator,
-            motion_bucket_id=motion_bucket_id,
-            fps=7,
-            num_inference_steps=30,
-            # track
-            sift_track_update=sift_track_update,
-            anchor_points_flag=anchor_points_flag,
-        ).frames[0]
-
-        vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
-        vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
-        vis_images = [Image.fromarray(img) for img in vis_images]
-
-        # video_frames = [img for sublist in video_frames for img in sublist]
-        val_save_dir = os.path.join(args.output_dir, "vis_gif.gif")
-        save_gifs_side_by_side(
-            video_frames,
-            vis_images[: self.model_length],
-            val_save_dir,
-            target_size=(self.width, self.height),
-            duration=110,
-            point_tracks=pred_tracks,
-        )
-
-        return val_save_dir
-
-
 def reset_states(first_frame_path, last_frame_path, tracking_points):
     first_frame_path = None
     last_frame_path = None
@@ -561,7 +363,7 @@ def preprocess_image(image):
     # image_pil = transforms.CenterCrop((320, 512))(image_pil.convert('RGB'))
     image_pil = image_pil.resize((512, 320), Image.BILINEAR)
 
-    first_frame_path = os.path.join(
+    first_frame_path = os.path.join(OUTPUT_DIR, f"first_frame_{str(uuid.uuid4())[:4]}.png")
 
     image_pil.save(first_frame_path)
 
@@ -578,7 +380,7 @@ def preprocess_image_end(image_end):
     # image_end_pil = transforms.CenterCrop((320, 512))(image_end_pil.convert('RGB'))
     image_end_pil = image_end_pil.resize((512, 320), Image.BILINEAR)
 
-    last_frame_path = os.path.join(
+    last_frame_path = os.path.join(OUTPUT_DIR, f"last_frame_{str(uuid.uuid4())[:4]}.png")
 
     image_end_pil.save(last_frame_path)
 
@@ -692,7 +494,7 @@ def add_tracking_points(
     transparent_layer = 0
     for idx, track in enumerate(tracking_points):
        # mask = cv2.imread(
-       #     os.path.join(
+       #     os.path.join(OUTPUT_DIR, f"mask_{idx+1}.jpg")
        # )
        mask = np.zeros((320, 512, 3))
        color = color_list[idx + 1]
@@ -737,10 +539,136 @@ def add_tracking_points(
     return tracking_points, trajectory_map, trajectory_map_end
 
 
+@spaces.GPU
+def run(first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id):
+    original_width, original_height = 512, 320  # TODO
+
+    # load_image
+    image = Image.open(first_frame_path).convert("RGB")
+    width, height = image.size
+    image = image.resize((WIDTH, HEIGHT))
+
+    image_end = Image.open(last_frame_path).convert("RGB")
+    image_end = image_end.resize((WIDTH, HEIGHT))
+
+    input_all_points = tracking_points
+
+    sift_track_update = False
+    anchor_points_flag = None
+
+    if (len(input_all_points) == 0) and USE_SIFT:
+        sift_track_update = True
+        controlnet_cond_scale = 0.5
+
+        from models_diffusers.sift_match import interpolate_trajectory as sift_interpolate_trajectory
+        from models_diffusers.sift_match import sift_match
+
+        output_file_sift = os.path.join(OUTPUT_DIR, "sift.png")
+
+        # (f, topk, 2), f=2 (before interpolation)
+        pred_tracks = sift_match(
+            image,
+            image_end,
+            thr=0.5,
+            topk=5,
+            method="random",
+            output_path=output_file_sift,
+        )
+
+        if pred_tracks is not None:
+            # interpolate the tracks, following draganything gradio demo
+            pred_tracks = sift_interpolate_trajectory(pred_tracks, num_frames=MODEL_LENGTH)
+
+            anchor_points_flag = torch.zeros((MODEL_LENGTH, pred_tracks.shape[1])).to(pred_tracks.device)
+            anchor_points_flag[0] = 1
+            anchor_points_flag[-1] = 1
+
+        pred_tracks = pred_tracks.permute(1, 0, 2)  # (num_points, num_frames, 2)
+
+    else:
+
+        resized_all_points = [
+            tuple([tuple([int(e1[0] * WIDTH / original_width), int(e1[1] * HEIGHT / original_height)]) for e1 in e])
+            for e in input_all_points
+        ]
+
+        # a list of num_tracks tuples, each tuple contains a track with several points, represented as (x, y)
+        # in image w & h scale
+
+        for idx, splited_track in enumerate(resized_all_points):
+            if len(splited_track) == 0:
+                warnings.warn("running without point trajectory control")
+                continue
+
+            if len(splited_track) == 1:  # stationary point
+                displacement_point = tuple([splited_track[0][0] + 1, splited_track[0][1] + 1])
+                splited_track = tuple([splited_track[0], displacement_point])
+            # interpolate the track
+            splited_track = interpolate_trajectory(splited_track, MODEL_LENGTH)
+            splited_track = splited_track[:MODEL_LENGTH]
+            resized_all_points[idx] = splited_track
+
+        pred_tracks = torch.tensor(resized_all_points)  # (num_points, num_frames, 2)
+
+    vis_images = get_vis_image(
+        target_size=(HEIGHT, WIDTH),
+        points=pred_tracks,
+        num_frames=MODEL_LENGTH,
+    )
+
+    if len(pred_tracks.shape) != 3:
+        print("pred_tracks.shape", pred_tracks.shape)
+        with_control = False
+        controlnet_cond_scale = 0.0
+    else:
+        with_control = True
+        pred_tracks = pred_tracks.permute(1, 0, 2).to(device, dtype)  # (num_frames, num_points, 2)
+
+    point_embedding = None
+    video_frames = pipe(
+        image,
+        image_end,
+        # trajectory control
+        with_control=with_control,
+        point_tracks=pred_tracks,
+        point_embedding=point_embedding,
+        with_id_feature=False,
+        controlnet_cond_scale=controlnet_cond_scale,
+        # others
+        num_frames=14,
+        width=width,
+        height=height,
+        # decode_chunk_size=8,
+        # generator=generator,
+        motion_bucket_id=motion_bucket_id,
+        fps=7,
+        num_inference_steps=30,
+        # track
+        sift_track_update=sift_track_update,
+        anchor_points_flag=anchor_points_flag,
+    ).frames[0]
+
+    vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
+    vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
+    vis_images = [Image.fromarray(img) for img in vis_images]
+
+    # video_frames = [img for sublist in video_frames for img in sublist]
+    val_save_dir = os.path.join(OUTPUT_DIR, "vis_gif.gif")
+    save_gifs_side_by_side(
+        video_frames,
+        vis_images[:MODEL_LENGTH],
+        val_save_dir,
+        target_size=(WIDTH, HEIGHT),
+        duration=110,
+        point_tracks=pred_tracks,
+    )
+
+    return val_save_dir
+
+
 if __name__ == "__main__":
 
-
-    ensure_dirname(args.output_dir)
+    ensure_dirname(OUTPUT_DIR)
 
     color_list = []
     for i in range(20):
@@ -771,8 +699,6 @@ if __name__ == "__main__":
        3. Interpolate the images (according the path) with a click on "Run" button. <br>"""
        )
 
-        # device, args, height, width, model_length
-        Framer = Drag("cuda", args, 320, 512, 14)
        first_frame_path = gr.State()
        last_frame_path = gr.State()
        tracking_points = gr.State([])
@@ -898,7 +824,7 @@ if __name__ == "__main__":
        )
 
        run_button.click(
-            fn=
+            fn=run,
            inputs=[first_frame_path, last_frame_path, tracking_points, controlnet_cond_scale, motion_bucket_id],
            outputs=output_video,
        )
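
The substance of this update is structural: the old Drag class (with @spaces.GPU on its methods and argparse-driven configuration) is replaced by module-level constants, module-level pipeline loading, and a standalone @spaces.GPU-decorated run() that is passed directly to run_button.click. A stripped-down sketch of that ZeroGPU/Gradio pattern is shown below; the toy linear model stands in for the Framer pipeline, and every name other than spaces.GPU and the Gradio calls is illustrative rather than taken from the Space.

import gradio as gr
import spaces
import torch

model = torch.nn.Linear(1, 1).to("cuda")  # loaded once at import time, like `pipe` in the diff

@spaces.GPU  # on a ZeroGPU Space, a GPU is attached only while this function runs
def run(x):
    # toy stand-in for the real run(): move the input to CUDA, run the model, return a plain float
    with torch.no_grad():
        y = model(torch.tensor([[float(x)]], device="cuda"))
    return float(y)

with gr.Blocks() as demo:
    inp = gr.Number(value=1.0, label="input")
    out = gr.Number(label="output")
    gr.Button("Run").click(fn=run, inputs=[inp], outputs=out)

if __name__ == "__main__":
    demo.launch()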
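The diff shows interpolate_trajectory(points, n_points) only as unchanged context, while the PchipInterpolator import it presumably relies on survives the change. For orientation, a minimal sketch of what such a helper could look like; the body below is an assumption for illustration, not the Space's actual implementation.

import numpy as np
from scipy.interpolate import PchipInterpolator

def interpolate_trajectory(points, n_points):
    # Resample a clicked (x, y) track to n_points frames with monotone cubic (PCHIP) interpolation.
    points = np.asarray(points, dtype=float)   # (k, 2) clicked points, k >= 2
    t = np.linspace(0.0, 1.0, len(points))     # parameterise the original clicks
    t_new = np.linspace(0.0, 1.0, n_points)    # one sample per output frame
    track = PchipInterpolator(t, points, axis=0)(t_new)
    return [tuple(p) for p in track]           # list of (x, y) pairs, length n_points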