Spaces:
Paused
Paused
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import os | |
| import tempfile | |
| import shutil | |
| import imageio | |
| import pandas as pd | |
| import numpy as np | |
| from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video | |
| import json | |
| from torchvision.transforms import v2 | |
| from einops import rearrange | |
| import torchvision | |
| from PIL import Image | |
| import logging | |
| from pathlib import Path | |
| from huggingface_hub import hf_hub_download | |
# Module-wide logging: INFO level, with one logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Human-readable camera movement labels, in trajectory order. The position in
# this tuple (starting at 1) is the trajectory id.
_CAMERA_MOVE_LABELS = (
    "Pan Right",
    "Pan Left",
    "Tilt Up",
    "Tilt Down",
    "Zoom In",
    "Zoom Out",
    "Translate Up (with rotation)",
    "Translate Down (with rotation)",
    "Arc Left (with rotation)",
    "Arc Right (with rotation)",
)

# Camera transformation types: maps the trajectory id (as a string, "1".."10")
# to its display label.
CAMERA_TRANSFORMATIONS = {
    str(position): label
    for position, label in enumerate(_CAMERA_MOVE_LABELS, start=1)
}

# Global model state; all three start out unset until the models are loaded.
model_manager, pipe, is_model_loaded = None, None, False
def download_recammaster_checkpoint():
    """Make sure the ReCamMaster checkpoint exists locally and return its path.

    The file ``step20000.ckpt`` from the ``KwaiVGI/ReCamMaster-Wan2.1`` Hub
    repository is stored under ``models/ReCamMaster/checkpoints``. When the
    file is already there, nothing is downloaded.

    Returns:
        pathlib.Path: Location of the checkpoint file. Both the cached and the
        freshly-downloaded branch return a ``Path``.

    Raises:
        Exception: Any error raised by ``hf_hub_download`` is logged and
        re-raised unchanged.
    """
    # Define paths
    repo_id = "KwaiVGI/ReCamMaster-Wan2.1"
    filename = "step20000.ckpt"
    checkpoint_dir = Path("models/ReCamMaster/checkpoints")
    checkpoint_path = checkpoint_dir / filename

    # Check if already exists
    if checkpoint_path.exists():
        logger.info(f"✓ ReCamMaster checkpoint already exists at {checkpoint_path}")
        return checkpoint_path

    # Create directory if it doesn't exist
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Download the checkpoint
    logger.info("Downloading ReCamMaster checkpoint from HuggingFace...")
    logger.info(f"Repository: {repo_id}")
    logger.info(f"File: {filename}")
    logger.info(f"Destination: {checkpoint_path}")
    try:
        # NOTE(review): `local_dir_use_symlinks` is reported as deprecated in
        # recent huggingface_hub releases; kept for older versions. Confirm
        # against the pinned huggingface_hub version before removing it.
        downloaded_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_dir=checkpoint_dir,
            local_dir_use_symlinks=False,
        )
    except Exception as e:
        logger.error(f"✗ Error downloading checkpoint: {e}")
        raise

    # Normalise to Path so the return type matches the "already exists" branch.
    downloaded_path = Path(downloaded_path)
    logger.info(f"✓ Successfully downloaded ReCamMaster checkpoint to {downloaded_path}!")
    return downloaded_path
class Camera:
    """Camera pose holder.

    Attributes are set in ``__init__``:
      * ``c2w_mat`` - the input reshaped to a 4x4 array (camera-to-world).
      * ``w2c_mat`` - the matrix inverse of ``c2w_mat`` (world-to-camera).
    """

    def __init__(self, c2w):
        # Accept anything with 16 values (flat or already 4x4); np.array copies.
        pose = np.array(c2w).reshape(4, 4)
        self.c2w_mat = pose
        self.w2c_mat = np.linalg.inv(pose)
def parse_matrix(matrix_str):
    """Parse a bracketed camera matrix string into a 2-D NumPy array.

    Each ``[...]`` group is one row of whitespace-separated numbers, e.g.
    ``"[1 0 0 0] [0 1 0 0] [0 0 1 0] [0 0 0 1]"``. Rows may be separated by
    any amount of whitespace (spaces, newlines) or by nothing at all; a
    surrounding outer pair of brackets is tolerated.

    Args:
        matrix_str: Text holding the matrix rows.

    Returns:
        numpy.ndarray: One row per bracket group, values parsed as floats.

    Raises:
        ValueError: If a token cannot be converted to ``float``.
    """
    rows = []
    # Every ']' ends a row. Turning '[' into whitespace first means the
    # separator between rows no longer has to be exactly "] [".
    for chunk in matrix_str.replace('[', ' ').split(']'):
        tokens = chunk.split()
        if tokens:
            rows.append([float(token) for token in tokens])
    if not rows:
        # Keep the (1, 0) shape returned for empty input before this rewrite.
        rows = [[]]
    return np.array(rows)
def get_relative_pose(cam_params):
    """Express camera poses relative to the first camera in ``cam_params``.

    Args:
        cam_params: Sequence of objects exposing ``c2w_mat`` and ``w2c_mat``.
            The first entry is the reference camera.

    Returns:
        numpy.ndarray: ``float32`` array with one pose per input camera. Entry
        0 is the canonical reference pose; each later entry is that camera's
        ``c2w_mat`` mapped through the reference camera's ``w2c_mat``.
    """
    cam_to_origin = 0

    # Canonical pose for the reference camera: identity, with -cam_to_origin
    # in row 1 of the translation column.
    canonical_c2w = np.eye(4)
    canonical_c2w[1, 3] = -cam_to_origin

    # Maps absolute poses into the frame of the reference (first) camera.
    to_relative = canonical_c2w @ cam_params[0].w2c_mat

    poses = [canonical_c2w]
    poses.extend(to_relative @ cam.c2w_mat for cam in cam_params[1:])
    return np.array(poses, dtype=np.float32)
def load_models(progress_callback=None):
    """Load the Wan2.1 base models and the ReCamMaster checkpoint.

    On success the module-level ``model_manager`` and ``pipe`` are set and
    ``is_model_loaded`` becomes ``True``. On failure the globals are left
    untouched, so a half-built pipeline is never exposed to other functions.

    Args:
        progress_callback: Optional callable invoked as
            ``progress_callback(fraction, desc=...)`` at each loading stage.

    Returns:
        str: A status message. Failure messages always contain the word
        "Error"; this function reports failures through the return value
        rather than by raising.
    """
    global model_manager, pipe, is_model_loaded
    if is_model_loaded:
        return "Models already loaded!"

    def report(fraction, desc):
        # Progress reporting is optional; skip it when no callback was given.
        if progress_callback:
            progress_callback(fraction, desc=desc)

    try:
        logger.info("Starting model loading...")

        # First ensure the checkpoint is downloaded
        report(0.05, "Checking for ReCamMaster checkpoint...")
        try:
            ckpt_path = download_recammaster_checkpoint()
            logger.info(f"Using checkpoint at {ckpt_path}")
        except Exception as e:
            error_msg = f"Error downloading ReCamMaster checkpoint: {str(e)}"
            logger.error(error_msg)
            return error_msg

        # Fail fast: verify the checkpoint before loading the base models.
        if not os.path.exists(ckpt_path):
            error_msg = f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
            logger.error(error_msg)
            return error_msg

        # Load Wan2.1 pre-trained models
        report(0.1, "Loading model manager...")
        manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        report(0.3, "Loading Wan2.1 models...")
        manager.load_models([
            "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
            "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
            "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
        ])

        report(0.5, "Creating pipeline...")
        pipeline = WanVideoReCamMasterPipeline.from_model_manager(manager, device="cuda")

        # Initialize additional modules introduced in ReCamMaster. They must
        # exist before load_state_dict(strict=True) is called below.
        report(0.7, "Initializing ReCamMaster modules...")
        dim = pipeline.dit.blocks[0].self_attn.q.weight.shape[0]
        for block in pipeline.dit.blocks:
            # The camera encoder takes 12 values per pose and starts at zero.
            block.cam_encoder = nn.Linear(12, dim)
            block.projector = nn.Linear(dim, dim)
            block.cam_encoder.weight.data.zero_()
            block.cam_encoder.bias.data.zero_()
            # The projector starts as an identity mapping.
            block.projector.weight = nn.Parameter(torch.eye(dim))
            block.projector.bias = nn.Parameter(torch.zeros(dim))

        # Load ReCamMaster checkpoint
        report(0.9, "Loading ReCamMaster checkpoint...")
        # NOTE(review): torch.load unpickles the file. The checkpoint comes
        # from the Hub repository used by download_recammaster_checkpoint();
        # consider weights_only=True if it is a plain state dict (TODO confirm).
        state_dict = torch.load(ckpt_path, map_location="cpu")
        pipeline.dit.load_state_dict(state_dict, strict=True)
        pipeline.to("cuda")
        pipeline.to(dtype=torch.bfloat16)

        # Publish the globals only once everything above has succeeded.
        model_manager = manager
        pipe = pipeline
        is_model_loaded = True

        report(1.0, "Models loaded successfully!")
        logger.info("Models loaded successfully!")
        return "Models loaded successfully!"
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"Error loading models: {str(e)}")
        return f"Error loading models: {str(e)}"
def extract_frames_from_video(video_path, output_dir, max_frames=81):
    """Write exactly ``max_frames`` frames of a video to ``output_dir`` as PNGs.

    Only the first ``max_frames`` frames are decoded. If the video is shorter,
    the last decoded frame is repeated until ``max_frames`` frames exist.
    Files are named ``frame_0000.png``, ``frame_0001.png``, and so on.

    Args:
        video_path: Path of the video to read.
        output_dir: Directory for the PNG files; created if missing.
        max_frames: Number of frames to write.

    Returns:
        tuple: ``(number_of_frames_written, fps)``, where ``fps`` comes from
        the reader's metadata.

    Raises:
        ValueError: If frames were requested but none could be read.
    """
    from itertools import islice

    os.makedirs(output_dir, exist_ok=True)
    limit = max(max_frames, 0)

    reader = imageio.get_reader(video_path)
    try:
        fps = reader.get_meta_data()['fps']
        # Decode only what is needed instead of holding the whole video in
        # memory.
        frames = list(islice(reader, limit))
    finally:
        # Release the reader even when decoding fails.
        reader.close()

    if limit and not frames:
        raise ValueError(f"No frames could be read from {video_path}")

    # If we have fewer than required frames, repeat the last frame
    if len(frames) < limit:
        logger.info(f"Video has {len(frames)} frames, padding to {max_frames} frames")
        frames.extend([frames[-1]] * (limit - len(frames)))

    # Save frames
    for i, frame in enumerate(frames):
        frame_path = os.path.join(output_dir, f"frame_{i:04d}.png")
        imageio.imwrite(frame_path, frame)
    return len(frames), fps
def _load_source_video(video_path, height, width, num_frames):
    """Decode, resize and normalise the first ``num_frames`` frames of a video.

    Returns a tensor laid out as (1, C, T, H, W). Short videos are padded by
    repeating the last decoded frame.
    """
    from itertools import islice

    # Create frame processor
    frame_process = v2.Compose([
        v2.CenterCrop(size=(height, width)),
        v2.Resize(size=(height, width), antialias=True),
        v2.ToTensor(),
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    def crop_and_resize(image):
        # Scale so the frame covers the target size in both dimensions; the
        # CenterCrop in frame_process then trims the excess.
        width_img, height_img = image.size
        scale = max(width / width_img, height / height_img)
        return torchvision.transforms.functional.resize(
            image,
            (round(height_img * scale), round(width_img * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )

    reader = imageio.get_reader(video_path)
    try:
        # Iterating ends cleanly at the end of the video, so no exception
        # handling is needed to detect a short clip, and real preprocessing
        # errors are no longer swallowed.
        frames = [
            frame_process(crop_and_resize(Image.fromarray(raw_frame)))
            for raw_frame in islice(reader, num_frames)
        ]
    finally:
        reader.close()

    if not frames:
        raise ValueError("Video is too short!")
    # If we run out of frames, repeat the last one
    frames.extend([frames[-1]] * (num_frames - len(frames)))

    video = rearrange(torch.stack(frames, dim=0), "T C H W -> C T H W")
    return video.unsqueeze(0)  # Add batch dimension


def _build_camera_embedding(camera_path, cam_type, num_frames):
    """Build the relative-pose camera tensor for trajectory ``cam_type``.

    Every 4th frame's extrinsics are read from the JSON file and expressed
    relative to the first sampled frame. Each pose is flattened from 3x4 to
    12 values. Returns a bfloat16 tensor with a leading batch dimension.
    """
    with open(camera_path, 'r') as file:
        cam_data = json.load(file)

    # Get camera trajectory for selected type
    cam_key = f"cam{int(cam_type):02d}"
    traj = [
        parse_matrix(cam_data[f"frame{idx}"][cam_key])
        for idx in range(0, num_frames, 4)  # Sample every 4 frames
    ]
    traj = np.stack(traj).transpose(0, 2, 1)

    c2ws = []
    for c2w in traj:
        # Reorder columns, flip the sign of column 1 and rescale translation.
        # NOTE(review): presumably an axis-convention change plus a cm -> m
        # conversion of the stored extrinsics - confirm against the data spec.
        c2w = c2w[:, [1, 2, 0, 3]]
        c2w[:3, 1] *= -1.
        c2w[:3, 3] /= 100
        c2ws.append(c2w)

    tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
    # get_relative_pose returns [reference, relative]; keep the top 3x4 of
    # the relative pose.
    relative_poses = [
        torch.as_tensor(get_relative_pose([tgt_cam_params[0], cam]))[:, :3, :][1]
        for cam in tgt_cam_params
    ]
    pose_embedding = torch.stack(relative_poses, dim=0)  # 21x3x4 for 81 frames
    pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
    return pose_embedding.to(torch.bfloat16).unsqueeze(0)  # Add batch dimension


def process_video_for_recammaster(video_path, text_prompt, cam_type, height=480, width=832):
    """Re-render a video along a preset camera trajectory with ReCamMaster.

    Args:
        video_path: Path of the source video.
        text_prompt: Text description passed to the pipeline as the prompt.
        cam_type: Trajectory id (anything ``int()`` accepts); selects the
            ``camNN`` entry in the camera extrinsics file.
        height: Frame height fed to the model.
        width: Frame width fed to the model.

    Returns:
        The value returned by the pipeline call (the generated video).

    Raises:
        RuntimeError: If the global pipeline has not been loaded.
        ValueError: If no frame can be read from ``video_path``.
    """
    if pipe is None:
        raise RuntimeError("ReCamMaster pipeline is not loaded")

    num_frames = 81  # ReCamMaster needs exactly 81 frames
    tgt_camera_path = "./example_test_data/cameras/camera_extrinsics.json"

    # Load video frames
    video_tensor = _load_source_video(video_path, height, width, num_frames)
    # Load camera trajectory
    camera_tensor = _build_camera_embedding(tgt_camera_path, cam_type, num_frames)

    # Generate video with ReCamMaster
    video = pipe(
        prompt=[text_prompt],
        negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
        source_video=video_tensor,
        target_camera=camera_tensor,
        cfg_scale=5.0,
        num_inference_steps=50,
        seed=0,
        tiled=True
    )
    return video
def generate_recammaster_video(
    video_file,
    text_prompt,
    camera_type,
    progress=gr.Progress()
):
    """Gradio handler: run ReCamMaster on an uploaded video.

    Args:
        video_file: The uploaded video, either as a file path (``str`` or
            path-like) or as an object exposing the path as ``.name``.
        text_prompt: Text description of the video.
        camera_type: Key into ``CAMERA_TRANSFORMATIONS`` selecting the
            camera trajectory.
        progress: Gradio progress tracker.

    Returns:
        tuple: ``(output_video_path, status_message)``. The path is ``None``
        when generation did not run or failed; errors are reported through
        the status message rather than raised.
    """
    if not is_model_loaded:
        return None, "Error: Models not loaded! Please load models first."
    if video_file is None:
        return None, "Please upload a video file."
    try:
        # gr.Video passes the handler a plain file path, while file-style
        # upload objects expose it as ``.name``; accept both.
        if isinstance(video_file, (str, os.PathLike)):
            source_path = video_file
        else:
            source_path = video_file.name

        # Create temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            progress(0.1, desc="Processing input video...")
            # Copy uploaded video to temp directory
            input_video_path = os.path.join(temp_dir, "input.mp4")
            shutil.copy(source_path, input_video_path)

            # Extract frames
            progress(0.2, desc="Extracting video frames...")
            num_frames, fps = extract_frames_from_video(input_video_path, os.path.join(temp_dir, "frames"))
            logger.info(f"Extracted {num_frames} frames at {fps} fps")

            # Process with ReCamMaster
            progress(0.3, desc="Processing with ReCamMaster...")
            output_video = process_video_for_recammaster(
                input_video_path,
                text_prompt,
                camera_type
            )

            # Save output video
            progress(0.9, desc="Saving output video...")
            output_path = os.path.join(temp_dir, "output.mp4")
            save_video(output_video, output_path, fps=30, quality=5)

            # Copy to a persistent location, since temp_dir is deleted on
            # exit. Close the placeholder handle before copying so no file
            # descriptor is leaked.
            with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as persistent:
                final_output_path = persistent.name
            shutil.copy(output_path, final_output_path)

        progress(1.0, desc="Done!")
        transformation_name = CAMERA_TRANSFORMATIONS.get(str(camera_type), "Unknown")
        status_msg = f"Successfully generated video with '{transformation_name}' camera movement!"
        return final_output_path, status_msg
    except Exception as e:
        # Top-level UI boundary: log with traceback and report as a message.
        logger.exception(f"Error generating video: {str(e)}")
        return None, f"Error: {str(e)}"
# Build the Gradio UI: a status box, an intro, an input column (video, prompt,
# camera choice, button), an output column, and example rows.
with gr.Blocks(title="ReCamMaster Demo") as demo:
    # Status box shown while the models load
    loading_status = gr.Textbox(
        label="Model Loading Status",
        value="Loading models, please wait...",
        interactive=False,
        visible=True,
    )
    gr.Markdown("""
    # 🎥 ReCamMaster Demo
    ReCamMaster allows you to re-capture videos with novel camera trajectories.
    Upload a video and select a camera transformation to see the magic!
    **Note:** The ReCamMaster checkpoint will be automatically downloaded from HuggingFace when you start the app.
    You still need to download Wan2.1 models using `python download_wan2.1.py` before running this demo.
    """)

    with gr.Row():
        # Left column: inputs
        with gr.Column():
            with gr.Group():
                gr.Markdown("### Step 1: Upload Video")
                video_input = gr.Video(label="Input Video")
                text_prompt = gr.Textbox(
                    label="Text Prompt (describe your video)",
                    placeholder="A person walking in the street",
                    value="A dynamic scene",
                )
            with gr.Group():
                gr.Markdown("### Step 2: Select Camera Movement")
                # Radio choices are (display label, value) pairs, so the
                # handler receives the trajectory id string.
                camera_type = gr.Radio(
                    choices=[(label, key) for key, label in CAMERA_TRANSFORMATIONS.items()],
                    label="Camera Transformation",
                    value="1",
                )
            generate_btn = gr.Button("Generate Video", variant="primary")

        # Right column: outputs
        with gr.Column():
            output_video = gr.Video(label="Output Video")
            status_output = gr.Textbox(label="Generation Status", interactive=False)

    # Example videos
    gr.Markdown("### Example Videos")
    gr.Examples(
        examples=[
            ["example_test_data/videos/case0.mp4", "A person dancing", "1"],
            ["example_test_data/videos/case1.mp4", "A scenic view", "5"],
        ],
        inputs=[video_input, text_prompt, camera_type],
    )

    def on_load():
        """Load the models on page load; keep the status box visible only on error."""
        status = load_models()
        return gr.update(value=status, visible="Error" in status)

    # Load models automatically when the interface loads
    demo.load(on_load, outputs=[loading_status])

    # Event handlers
    generate_btn.click(
        fn=generate_recammaster_video,
        inputs=[video_input, text_prompt, camera_type],
        outputs=[output_video, status_output],
    )

if __name__ == "__main__":
    demo.launch(share=True)