Spaces:
Runtime error
Runtime error
| import spaces | |
| import torch | |
| import gradio as gr | |
| from diffusers import CogVideoXPipeline | |
| from diffusers.utils import export_to_video | |
| from PIL import Image | |
| # ------------------------------------------------------------ | |
| # 1. Load & optimize the CogVideoX pipeline with CPU offload | |
| # ------------------------------------------------------------ | |
# Load the CogVideoX 1.5 (5B) text-to-video weights in bfloat16.
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX1.5-5B",
    torch_dtype=torch.bfloat16,
)

# VRAM savers: shuttle sub-models between CPU and GPU on demand, and enable
# sliced attention and sliced VAE processing.
pipe.enable_model_cpu_offload()
pipe.enable_attention_slicing()
pipe.vae.enable_slicing()
| # ------------------------------------------------------------ | |
| # 2. Resolution parsing & sanitization | |
| # ------------------------------------------------------------ | |
def make_divisible_by_8(x: int) -> int:
    """Round ``x`` down to the nearest multiple of 8.

    Values that are already multiples of 8 are returned unchanged.
    """
    remainder = x % 8
    return x - remainder
def parse_resolution(res_str: str):
    """Turn a label such as ``"480p"`` into a ``(height, width)`` pair.

    The numeric part of the label is taken as the height, the width is
    derived from a 16:9 aspect ratio, and both values are rounded down to a
    multiple of 8 (so ``"480p"`` yields ``(480, 848)``).
    """
    height = int(res_str.rstrip("p"))
    width = int(height * 16 / 9)
    return tuple(make_divisible_by_8(side) for side in (height, width))
| # ------------------------------------------------------------ | |
| # 3. GPU-decorated video generation function | |
| # ------------------------------------------------------------ | |
# Allow up to 180 s of GPU time per call (ZeroGPU allocation via `spaces`).
@spaces.GPU(duration=180)
def generate_video(
    prompt: str,
    steps: int,
    frames: int,
    fps: int,
    resolution: str
) -> str:
    """Generate a video from a text prompt and return the path to the MP4.

    Args:
        prompt: Text description of the video to generate.
        steps: Number of diffusion inference steps.
        frames: Number of frames requested from the pipeline.
        fps: Playback rate used when encoding the MP4.
        resolution: Resolution label such as ``"480p"``; the generated frames
            are resized to the size returned by ``parse_resolution``.

    Returns:
        Path of the exported video file.

    Raises:
        gr.Error: If ``prompt`` is empty or only whitespace.
    """
    # Reject empty prompts up front instead of spending GPU time on them.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt before generating a video.")

    # UI sliders may deliver numeric values as floats; the pipeline and the
    # exporter are called with integers.
    steps, frames, fps = int(steps), int(frames), int(fps)

    # 3.1 Determine the target (output) resolution
    target_h, target_w = parse_resolution(resolution)

    try:
        # 3.2 Run the diffusion pipeline at its native resolution
        output = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            num_frames=frames,
        )
        video_frames = output.frames[0]  # frames of the first (only) video

        # 3.3 Resize frames to the user-specified resolution
        resized_frames = [
            frame.resize((target_w, target_h), Image.LANCZOS)
            for frame in video_frames
        ]

        # 3.4 Export to MP4 with the chosen FPS
        return export_to_video(resized_frames, "generated.mp4", fps=fps)
    finally:
        # Release cached GPU memory even when generation or export fails.
        torch.cuda.empty_cache()
| # ------------------------------------------------------------ | |
| # 4. Build the Gradio interface with interactive controls | |
| # ------------------------------------------------------------ | |
with gr.Blocks(title="Textual Imagination: A text to video synthesis") as demo:
    # Page heading and short usage hint.
    gr.Markdown(
        """
        # ποΈ Textual Imagination: A text to video synthesis
        Generate videos from text prompts.
        Adjust inference steps, frame count, fps, and resolution below.
        """
    )

    with gr.Row():
        # Left column: prompt and generation controls.
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", lines=2)
            steps_slider = gr.Slider(
                label="Inference Steps",
                minimum=1,
                maximum=100,
                step=1,
                value=20,
            )
            frames_slider = gr.Slider(
                label="Total Frames",
                minimum=16,
                maximum=320,
                step=1,
                value=70,
            )
            fps_slider = gr.Slider(
                label="Frames per Second (FPS)",
                minimum=1,
                maximum=60,
                step=1,
                value=16,
            )
            res_dropdown = gr.Dropdown(
                label="Resolution",
                choices=["360p", "480p", "720p", "1080p"],
                value="480p",
            )
            gen_button = gr.Button("Generate Video")

        # Right column: the rendered result.
        with gr.Column():
            video_output = gr.Video(label="Generated Video", format="mp4")

    # Clicking the button feeds the controls, in order, to generate_video.
    gen_button.click(
        fn=generate_video,
        inputs=[
            prompt_input,
            steps_slider,
            frames_slider,
            fps_slider,
            res_dropdown,
        ],
        outputs=video_output,
    )
| # ------------------------------------------------------------ | |
| # 5. Launch: disable SSR so Gradio blocks and stays alive | |
| # ------------------------------------------------------------ | |
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 with server-side rendering off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)