import time
import tempfile

import gradio as gr
import numpy as np
import spaces  # ZeroGPU: provides the @spaces.GPU decorator (assumed, since this Space runs on Zero)
import torch
from diffusers import AutoencoderKLWan
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPVisionModel

from chronoedit_diffusers.pipeline_chronoedit import ChronoEditPipeline
from chronoedit_diffusers.transformer_chronoedit import ChronoEditTransformer3DModel
from prompt_enhancer import load_model, enhance_prompt

# torch.set_grad_enabled(False)
# torch.backends.cuda.preferred_linalg_library(backend="magma")
start = time.time()

model_id = "nvidia/ChronoEdit-14B-Diffusers"

image_encoder = CLIPVisionModel.from_pretrained(
    model_id,
    subfolder="image_encoder",
    torch_dtype=torch.float32,
)
print("✓ Loaded image encoder")

vae = AutoencoderKLWan.from_pretrained(
    model_id,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
)
print("✓ Loaded VAE")

transformer = ChronoEditTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
print("✓ Loaded transformer")

pipe = ChronoEditPipeline.from_pretrained(
    model_id,
    image_encoder=image_encoder,
    transformer=transformer,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
print("✓ Created pipeline")
# Download and fuse the distilled LoRA; the "distill" weights are what make
# the few-step (8-step, guidance-free) sampling below viable.
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
print(f"Loading LoRA weights from {lora_path}...")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=1.0)
print("✓ Fused LoRA with scale 1.0")
# Set up the scheduler
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    flow_shift=2.0,
)
print("✓ Configured scheduler (flow_shift=2.0)")
# Models stay on CPU at startup; on ZeroGPU the transfer to "cuda" happens
# inside the @spaces.GPU-decorated run_inference once a GPU is attached.
# pipe.to("cuda")
print("✓ Models loaded (GPU transfer deferred to inference)")

end = time.time()
print(f"Models loaded in {end - start:.2f}s.")
start = time.time()
prompt_enhancer_model = "Qwen/Qwen3-VL-30B-A3B-Instruct"
prompt_model, processor = load_model(prompt_enhancer_model)
end = time.time()
print(f"Prompt enhancer loaded in {end - start:.2f}s.")
def calculate_dimensions(image, mod_value):
    """
    Calculate output dimensions that preserve the input aspect ratio.

    Args:
        image: PIL Image
        mod_value: Modulo value for dimension alignment

    Returns:
        Tuple of (width, height)
    """
    # Target output area (720p); both sides are rounded down to a multiple
    # of mod_value so they are valid for the VAE and transformer patching.
    target_area = 720 * 1280
    aspect_ratio = image.height / image.width
    calculated_height = round(np.sqrt(target_area * aspect_ratio)) // mod_value * mod_value
    calculated_width = round(np.sqrt(target_area / aspect_ratio)) // mod_value * mod_value
    return calculated_width, calculated_height
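
# Worked example (hypothetical input): with mod_value = 16 (e.g. a VAE
# spatial scale factor of 8 times a patch size of 2) and a 1024x768 image,
# aspect_ratio = 0.75 and target_area = 921600, so
#   height = round(sqrt(921600 * 0.75)) // 16 * 16 = 831 // 16 * 16 = 816
#   width  = round(sqrt(921600 / 0.75)) // 16 * 16 = 1109 // 16 * 16 = 1104
# giving 1104x816, just under the 720*1280 pixel budget.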
@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces (assumed from the Space's Zero hardware)
def run_inference(
    image_path: str,
    prompt: str,
    enable_temporal_reasoning: bool,
    num_inference_steps: int = 8,
    guidance_scale: float = 1.0,
    shift: float = 2.0,  # logged only; the scheduler's flow_shift is fixed at startup
    num_temporal_reasoning_steps: int = 8,
):
    # Prompt rewriting: expand the user's instruction into a detailed
    # chain-of-thought prompt using the vision-language model.
    prompt_model.to("cuda")
    final_prompt = prompt
    with torch.no_grad():
        cot_prompt = enhance_prompt(
            image_path,
            prompt,
            prompt_model,
            processor,
        )
    print("\n" + "=" * 80)
    print("Enhanced CoT Prompt:")
    print("=" * 80)
    print(cot_prompt)
    print("=" * 80 + "\n")
    final_prompt = cot_prompt
    prompt_model.to("cpu")
    # Inference
    pipe.to("cuda")
    print(f"Loading input image: {image_path}")
    image = load_image(image_path)
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    width, height = calculate_dimensions(image, mod_value)
    print(f"Output dimensions: {width}x{height}")
    image = image.resize((width, height))
    num_frames = 29 if enable_temporal_reasoning else 5
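    # Inferred from how the outputs are used below: temporal reasoning
    # denoises a longer 29-frame trajectory (exported as the visualization
    # video), while the default path generates only 5 frames; in both
    # cases the last frame is the edited image.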
    with torch.no_grad():
        start = time.time()
        output = pipe(
            image=image,
            prompt=final_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            enable_temporal_reasoning=enable_temporal_reasoning,
            num_temporal_reasoning_steps=num_temporal_reasoning_steps,
        ).frames[0]
        end = time.time()
    pipe.to("cpu")
    # Write outputs to temporary files: the full frame sequence as a video
    # and the final frame as the edited image.
    video_tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    output_path_video = video_tmp.name
    video_tmp.close()
    image_tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    output_path_image = image_tmp.name
    image_tmp.close()

    export_to_video(output, output_path_video, fps=10)
    Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save(output_path_image)

    log_text = (
        f"Final prompt: {final_prompt}\n"
        f"Guidance: {guidance_scale}, Shift: {shift}, Steps: {num_inference_steps}\n"
        f"Inference: {end - start:.2f}s\n"
    )
    if enable_temporal_reasoning:
        log_text += f"Temporal reasoning steps: {num_temporal_reasoning_steps}\n"
    return output_path_image, output_path_video  # , log_text
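
# Minimal usage sketch (hypothetical call; the image path and prompt are
# adapted from the Examples section below):
#
#   img_path, vid_path = run_inference(
#       "examples/2.png",
#       "Make the girl in the painting play a guitar.",
#       enable_temporal_reasoning=False,
#   )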
def build_ui() -> gr.Blocks:
    with gr.Blocks(title="ChronoEdit", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # ChronoEdit Demo
        [[Project Page]](https://research.nvidia.com/labs/toronto-ai/chronoedit/) |
        [[Code]](https://github.com/nv-tlabs/ChronoEdit) |
        [[Technical Report]](https://arxiv.org/abs/2510.04290)
        """)
        with gr.Row():
            image = gr.Image(type="filepath", label="Input Image")
            output_image = gr.Image(label="Generated Image")
        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(label="Prompt", lines=4, value="")
                enable_temporal_reasoning = gr.Checkbox(label="Enable temporal reasoning", value=False)
                run_btn = gr.Button("Start Generation", variant="primary")
            with gr.Column(scale=1):
                output_video = gr.Video(label="Temporal Reasoning Visualization", visible=False)
        # with gr.Row():
        #     num_inference_steps = gr.Slider(minimum=4, maximum=75, step=1, value=50, label="Num Inference Steps")
        #     guidance_scale = gr.Slider(minimum=1.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale")
        #     shift = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=5.0, label="Shift")
        #     num_temporal_reasoning_steps = gr.Slider(minimum=0, maximum=50, step=1, value=50, label="Number of temporal reasoning steps")
        # log_text = gr.Markdown("Logs will appear here.")
        def _on_run(image_path, prompt, enable_temporal_reasoning):
            image_out_path, video_out_path = run_inference(
                image_path=image_path,
                prompt=prompt,
                enable_temporal_reasoning=enable_temporal_reasoning,
            )
            # Show the reasoning video only when temporal reasoning was used.
            video_update = gr.update(
                visible=enable_temporal_reasoning,
                value=video_out_path if enable_temporal_reasoning else None,
            )
            return image_out_path, video_update

        run_btn.click(
            _on_run,
            inputs=[image, prompt, enable_temporal_reasoning],
            outputs=[output_image, output_video],  # , log_text],
        )
        gr.Examples(
            examples=[
                [
                    "examples/1.png",
                    "The user wants to change the provided illustration of an elegant woman in a flowing red kimono into a high-end Japanese anime PVC scale figure, rendered photorealistically as a pre-painted collectible. Preserve her long black hair styled with golden hair ornaments and delicate floral accessories, her slightly tilted head and confident gaze, and the detailed red kimono with golden and floral embroidery tied with a wide gold obi. Cherry blossom petals drift around. Maintain the pose and camera view point unchanged. The scene should look like a premium finished PVC figure on display, with realistic textures, fine paint detailing, and a polished collectible presentation. Place the figure on a simple round base on a computer desk, with blurred keyboard and monitor glow in the background. Emphasize a strong 3D sense of volume and depth, realistic shadows and lighting, and painted PVC figure textures. Professional studio photography style, shallow depth of field, focus on the figure as a physical collectible. The lighting on the figure is uniform and highlighted, emphasizing every sculpted detail and painted accent.",
                    True,
                ],
                [
                    "examples/2.png",
                    "The user wants to change the scene so that the girl in the traditional-style painting, wearing her ornate floral robe and headdress, is now playing a guitar.",
                    False,
                ],
            ],
            inputs=[image, prompt, enable_temporal_reasoning],
            outputs=[output_image, output_video],
            fn=_on_run,
            cache_examples="lazy",
        )
    return demo
if __name__ == "__main__":
    demo = build_ui()
    # demo.launch(server_name="0.0.0.0", server_port=7869)
    demo.queue().launch()