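# Gradio demo for AIdeaLab-VideoJP, a lightweight rectified-flow text-to-video
# model conditioned on a Japanese LLM text encoder. The script loads the text
# encoder, the video transformer, and the CogVideoX VAE, then samples latents
# with an Euler integrator and decodes them to an MP4.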
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from diffusers import (
    AutoencoderKLCogVideoX,
    CogVideoXTransformer3DModel,
)
from diffusers.utils import export_to_video
import tqdm
from torchvision.transforms import ToPILImage
import os
import spaces
# from torchao.quantization import autoquant
device = "cuda"
# Latent shape (B, F, C, H, W): 48 frames compressed 4x in time, 256x256 pixels
# compressed 8x in space, with 16 latent channels.
shape = (1, 48 // 4, 16, 256 // 8, 256 // 8)
sample_N = 25  # number of Euler integration steps
torch_dtype = torch.bfloat16
eps = 1  # smallest timestep on the model's 0-1000 scale
cfg = 2.5  # classifier-free guidance scale
tokenizer = AutoTokenizer.from_pretrained(
    "llm-jp/llm-jp-3-1.8b"
)
text_encoder = AutoModelForCausalLM.from_pretrained(
    "llm-jp/llm-jp-3-1.8b",
    torch_dtype=torch_dtype
)
text_encoder = text_encoder.to(device)
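# The prompt conditioning comes from the last hidden states of the
# llm-jp-3-1.8b causal LM (see hidden_states[-1] in text_to_video below)
# rather than from a dedicated text encoder such as T5.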
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "aidealab/AIdeaLab-VideoJP",
    torch_dtype=torch_dtype,
    token=os.environ['TOKEN']  # HF access token supplied via environment variable
)
# Optional speed/memory optimizations, currently disabled:
# transformer = autoquant(transformer, error_on_unseen=False)
# transformer.to(memory_format=torch.channels_last)
# transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
transformer = transformer.to(device)
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b",
    subfolder="vae"
)
vae = vae.to(dtype=torch_dtype, device=device)
vae.enable_slicing()
vae.enable_tiling()
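# Slicing and tiling make the VAE decode in chunks, trading a little speed
# for a much lower peak-VRAM footprint.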
@spaces.GPU  # allocate a GPU for the duration of this call on ZeroGPU Spaces
def text_to_video(prompt, cfg=cfg):
    # Tokenize and embed the user prompt
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=512,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    prompt_embeds = text_encoder(
        text_input_ids.to(device),
        output_hidden_states=True,
        attention_mask=text_inputs.attention_mask.to(device),
    ).hidden_states[-1]
    prompt_embeds = prompt_embeds.to(dtype=torch_dtype, device=device)
    # Embed the empty string as the unconditional branch for CFG
    null_text_inputs = tokenizer(
        "",
        padding="max_length",
        max_length=512,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    null_text_input_ids = null_text_inputs.input_ids
    null_prompt_embeds = text_encoder(
        null_text_input_ids.to(device),
        output_hidden_states=True,
        attention_mask=null_text_inputs.attention_mask.to(device),
    ).hidden_states[-1]
    null_prompt_embeds = null_prompt_embeds.to(dtype=torch_dtype, device=device)
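    # Rectified-flow sampling: starting from Gaussian noise z_0, integrate the
    # learned velocity field with plain Euler steps, z_{t+dt} = z_t + dt * v(z_t, t),
    # from t = 0 to t = 1. The transformer expects timesteps on a 0-1000 scale
    # with pure noise near 1000, hence the (1000 - eps) * (1 - t) + eps remapping.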
    # Euler discrete sampler with classifier-free guidance
    z0 = torch.randn(shape, device=device)
    latents = z0.detach().clone().to(torch_dtype)
    dt = 1.0 / sample_N
    with torch.no_grad():
        for i in tqdm.tqdm(range(sample_N)):
            num_t = i / sample_N
            t = torch.ones(shape[0], device=device) * num_t
            pseudo_t = (1000 - eps) * (1 - t) + eps
            positive_conditional = transformer(hidden_states=latents, timestep=pseudo_t, encoder_hidden_states=prompt_embeds, image_rotary_emb=None)
            null_conditional = transformer(hidden_states=latents, timestep=pseudo_t, encoder_hidden_states=null_prompt_embeds, image_rotary_emb=None)
            # CFG: extrapolate from the unconditional toward the conditional prediction
            pred = null_conditional.sample + cfg * (positive_conditional.sample - null_conditional.sample)
            latents = latents.detach().clone() + dt * pred.detach().clone()
        # Decode the latents back to pixel space (still under no_grad)
        latents = latents / vae.config.scaling_factor
        latents = latents.permute(0, 2, 1, 3, 4)  # [B, F, C, H, W] -> [B, C, F, H, W] for the VAE
        x = vae.decode(latents).sample
    # Map from [-1, 1] to [0, 1], convert frames to PIL images, and write the MP4
    x = x / 2 + 0.5
    x = x.clamp(0, 1)
    x = x.permute(0, 2, 1, 3, 4).to(torch.float32)  # back to [B, F, C, H, W]
    print(x.shape)
    x = [ToPILImage()(frame) for frame in x[0]]
    export_to_video(x, "output.mp4", fps=24)
    return "output.mp4"
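# Hypothetical local smoke test (requires a CUDA GPU and the TOKEN environment
# variable for the gated transformer weights); the prompt is only an example:
# if __name__ == "__main__":
#     print(text_to_video("静かな森の中を流れる小川"))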
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""
# Define the Gradio application layout
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# AIdeaLab VideoJP Demo
AIdeaLab VideoJP is a lightweight video generation model built on a Rectified Flow Transformer ([details](https://note.com/aidealab/n/n677018ea1953), [model](https://huggingface.co/aidealab/AIdeaLab-VideoJP)). It can generate a video in ten to twenty seconds. AIdeaLab VideoJP was created based on the results of "GENIAC (Generative AI Accelerator Challenge)", a project run by the Ministry of Economy, Trade and Industry and the New Energy and Industrial Technology Development Organization (NEDO) to strengthen domestic generative-AI development capabilities.""")
        # Free-form text input for the prompt
        text_input = gr.Textbox(
            label="Enter a prompt for video generation",
            placeholder="Example: Soft morning light streams into a quiet forest. Small fish swim in a sunlit stream, and birdsong drifts from deep in the woods.",
            lines=5
        )
        generate_button = gr.Button("Generate")
        output_video = gr.Video(label="Generated video")
        # Run text_to_video when the button is clicked
        generate_button.click(
            fn=text_to_video,
            inputs=text_input,
            outputs=output_video
        )
demo.launch()
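# On Hugging Face Spaces this file runs automatically as app.py; locally,
# `python app.py` starts the Gradio server (assuming a CUDA GPU is available).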