import os
import spaces
import torch
import gradio as gr
import tempfile
import subprocess
import sys
from pathlib import Path
import datetime
import math
import random
import gc
import json
import numpy as np
from PIL import Image
from moviepy import *
import librosa
from omegaconf import OmegaConf
from transformers import AutoTokenizer, Wav2Vec2Model, Wav2Vec2Processor
from diffusers import FlowMatchEulerDiscreteScheduler
from huggingface_hub import hf_hub_download, snapshot_download
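
# EchoMimicV3 Gradio demo for Hugging Face Spaces. On startup the script clones the
# upstream repository, downloads the model weights, builds the audio-conditioned
# Wan inpainting pipeline, and exposes image + audio -> talking-video generation
# through a Gradio UI (see initialize_models() and generate_video() below).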


def setup_repository():
    """Clone the EchoMimicV3 repository if needed and add it to the Python path."""
    if not os.path.exists("echomimic_v3"):
        print("Cloning EchoMimicV3 repository...")
        subprocess.run([
            "git", "clone",
            "https://github.com/antgroup/echomimic_v3.git"
        ], check=True)
        print("Repository cloned successfully")

    sys.path.insert(0, "echomimic_v3")
    print("Repository added to Python path")


def download_models():
    """Download the base Wan2.1 model, the EchoMimicV3 transformer, and the Wav2Vec2 model."""
    print("Downloading models...")
    os.makedirs("models", exist_ok=True)
    try:
        print("Downloading base model...")
        base_model_path = snapshot_download(
            repo_id="alibaba-pai/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir="models/Wan2.1-Fun-V1.1-1.3B-InP",
            local_dir_use_symlinks=False
        )
        print(f"Base model downloaded to: {base_model_path}")

        print("Downloading EchoMimicV3 transformer...")
        os.makedirs("models/transformer", exist_ok=True)
        transformer_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/diffusion_pytorch_model.safetensors",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"Transformer downloaded to: {transformer_file}")

        config_file = hf_hub_download(
            repo_id="BadToBest/EchoMimicV3",
            filename="transformer/config.json",
            local_dir="models",
            local_dir_use_symlinks=False
        )
        print(f"Config downloaded to: {config_file}")

        print("Downloading Wav2Vec model...")
        wav2vec_path = snapshot_download(
            repo_id="facebook/wav2vec2-base-960h",
            local_dir="models/wav2vec2-base-960h",
            local_dir_use_symlinks=False
        )
        print(f"Wav2Vec model downloaded to: {wav2vec_path}")

        print("All models downloaded successfully!")
        return True

    except Exception as e:
        print(f"Error downloading models: {e}")
        return False
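
# After a successful download_models() call the local layout should look roughly like
# this (paths follow the local_dir arguments above):
#
#   models/
#       Wan2.1-Fun-V1.1-1.3B-InP/    <- base Wan2.1-Fun image-to-video weights
#       transformer/
#           diffusion_pytorch_model.safetensors    <- EchoMimicV3 audio transformer
#           config.json
#       wav2vec2-base-960h/          <- audio feature extractor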


def download_examples():
    """Fetch the demo images, audio clips, and prompts bundled with the upstream repo."""
    print("Downloading example files...")
    os.makedirs("examples", exist_ok=True)
    try:
        example_files = [
            "datasets/echomimicv3_demos/imgs/demo_ch_woman_04.png",
            "datasets/echomimicv3_demos/audios/demo_ch_woman_04.WAV",
            "datasets/echomimicv3_demos/prompts/demo_ch_woman_04.txt",
            "datasets/echomimicv3_demos/imgs/guitar_woman_01.png",
            "datasets/echomimicv3_demos/audios/guitar_woman_01.WAV",
            "datasets/echomimicv3_demos/prompts/guitar_woman_01.txt"
        ]
        repo_url = "https://github.com/antgroup/echomimic_v3/raw/main/"
        import urllib.request

        for file_path in example_files:
            try:
                filename = os.path.basename(file_path)
                local_path = f"examples/{filename}"
                if not os.path.exists(local_path):
                    print(f"Downloading {filename}...")
                    urllib.request.urlretrieve(f"{repo_url}{file_path}", local_path)
                    print(f"Downloaded {filename}")
                else:
                    print(f"{filename} already exists")
            except Exception as e:
                print(f"Warning: could not download {filename}: {e}")

        print("Example files downloaded!")
        return True
    except Exception as e:
        print(f"Error downloading examples: {e}")
        return False


# These modules live inside the cloned echomimic_v3 repository, so the clone has to
# happen before they can be imported.
setup_repository()

from src.dist import set_multi_gpus_devices
from src.wan_vae import AutoencoderKLWan
from src.wan_image_encoder import CLIPModel
from src.wan_text_encoder import WanT5EncoderModel
from src.wan_transformer3d_audio import WanTransformerAudioMask3DModel
from src.pipeline_wan_fun_inpaint_audio import WanFunInpaintAudioPipeline
from src.utils import filter_kwargs, get_image_to_video_latent3, save_videos_grid
from src.fm_solvers import FlowDPMSolverMultistepScheduler
from src.fm_solvers_unipc import FlowUniPCMultistepScheduler
from src.cache_utils import get_teacache_coefficients
from src.face_detect import get_mask_coord


class ComprehensiveConfig:
    """Static configuration: model paths, precision, and default sampling options."""

    def __init__(self):
        self.ulysses_degree = 1
        self.ring_degree = 1
        self.fsdp_dit = False
        self.config_path = "echomimic_v3/config/config.yaml"
        self.model_name = "models/Wan2.1-Fun-V1.1-1.3B-InP"
        self.transformer_path = "models/transformer/diffusion_pytorch_model.safetensors"
        self.wav2vec_model_dir = "models/wav2vec2-base-960h"
        self.weight_dtype = torch.bfloat16
        self.sample_size = [768, 768]
        self.sampler_name = "Flow_DPM++"
        self.lora_weight = 1.0


config = ComprehensiveConfig()

# Populated once by initialize_models() and then reused by every generate_video() call.
pipeline = None
wav2vec_processor = None
wav2vec_model = None
# Default to False so the UI can still be built (in a disabled state) if this module
# is imported without running the __main__ initialization block.
models_ready = False


def load_wav2vec_models(wav2vec_model_dir):
    """Load the Wav2Vec2 processor and encoder used for audio feature extraction."""
    print(f"Loading Wav2Vec models from {wav2vec_model_dir}...")
    try:
        processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
        model = Wav2Vec2Model.from_pretrained(wav2vec_model_dir).eval()
        model.requires_grad_(False)
        print("Wav2Vec models loaded successfully")
        return processor, model
    except Exception as e:
        print(f"Error loading Wav2Vec models: {e}")
        raise


def extract_audio_features(audio_path, processor, model):
    """Return Wav2Vec2 hidden states for the full audio clip, shape [num_frames, hidden_dim]."""
    try:
        sr = 16000
        audio_segment, sample_rate = librosa.load(audio_path, sr=sr)
        input_values = processor(audio_segment, sampling_rate=sample_rate, return_tensors="pt").input_values
        input_values = input_values.to(model.device)
        with torch.no_grad():
            features = model(input_values).last_hidden_state
        return features.squeeze(0)
    except Exception as e:
        print(f"Error extracting audio features: {e}")
        raise
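
# Timing note: wav2vec2-base produces one feature vector per 20 ms of 16 kHz audio,
# i.e. roughly 50 feature frames per second. generate_video() below slices these
# features with `init_frames * 2`, which corresponds to two audio feature frames per
# video frame at the default 25 fps output.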


def get_sample_size(image, default_size):
    """Pick an output (height, width): keep the image's aspect ratio, cap the pixel area
    at default_size[0] * default_size[1], and round both sides down to multiples of 16."""
    width, height = image.size
    original_area = width * height
    default_area = default_size[0] * default_size[1]
    if default_area < original_area:
        ratio = math.sqrt(original_area / default_area)
        width = width / ratio // 16 * 16
        height = height / ratio // 16 * 16
    else:
        width = width // 16 * 16
        height = height // 16 * 16
    return int(height), int(width)
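
# Worked example: a 1080x1920 portrait with the default [768, 768] sample size has
# area 2,073,600 > 589,824, so ratio = sqrt(2,073,600 / 589,824) = 1.875 and the
# function returns (height, width) = (1024, 576): the same pixel budget as 768x768,
# at the original aspect ratio, with both sides divisible by 16.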


def get_ip_mask(coords):
    """Build a flattened float mask over an h x w grid that is 1 inside the face box
    [y1, y2) x [x1, x2) and 0 elsewhere."""
    y1, y2, x1, x2, h, w = coords
    Y, X = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    mask = (Y.unsqueeze(-1) >= y1) & (Y.unsqueeze(-1) < y2) & (X.unsqueeze(-1) >= x1) & (X.unsqueeze(-1) < x2)
    mask = mask.reshape(-1)
    return mask.float()
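
# Example: coords = (2, 6, 3, 9, 48, 36) marks rows 2-5 and columns 3-8 of a 48x36
# grid, so the returned vector has length 48 * 36 = 1728 and contains 4 * 6 = 24 ones.
# generate_video() builds coords by scaling the detected face box onto the
# 16-pixel-per-cell grid of the sampled resolution.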


def initialize_models():
    """Download weights, build the WanFunInpaintAudioPipeline, and load the Wav2Vec2 models."""
    global pipeline, wav2vec_processor, wav2vec_model, config
    print("Initializing EchoMimicV3 models...")
    try:
        if not download_models():
            raise Exception("Failed to download required models")
        download_examples()

        device = set_multi_gpus_devices(config.ulysses_degree, config.ring_degree)
        print(f"Device set to: {device}")

        cfg = OmegaConf.load(config.config_path)
        print(f"Config loaded from {config.config_path}")

        print("Loading transformer...")
        transformer = WanTransformerAudioMask3DModel.from_pretrained(
            os.path.join(config.model_name, cfg['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
            transformer_additional_kwargs=OmegaConf.to_container(cfg['transformer_additional_kwargs']),
            torch_dtype=config.weight_dtype,
        )
        if config.transformer_path is not None and os.path.exists(config.transformer_path):
            print(f"Loading custom transformer weights from {config.transformer_path}...")
            from safetensors.torch import load_file
            state_dict = load_file(config.transformer_path)
            state_dict = state_dict.get("state_dict", state_dict)
            missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
            print(f"Custom transformer weights loaded - Missing: {len(missing)}, Unexpected: {len(unexpected)}")

        print("Loading VAE...")
        vae = AutoencoderKLWan.from_pretrained(
            os.path.join(config.model_name, cfg['vae_kwargs'].get('vae_subpath', 'vae')),
            additional_kwargs=OmegaConf.to_container(cfg['vae_kwargs']),
        ).to(config.weight_dtype)
        print("VAE loaded")

        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
        )
        print("Tokenizer loaded")

        print("Loading text encoder...")
        text_encoder = WanT5EncoderModel.from_pretrained(
            os.path.join(config.model_name, cfg['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
            additional_kwargs=OmegaConf.to_container(cfg['text_encoder_kwargs']),
            torch_dtype=config.weight_dtype,
        ).eval()
        print("Text encoder loaded")

        print("Loading CLIP image encoder...")
        clip_image_encoder = CLIPModel.from_pretrained(
            os.path.join(config.model_name, cfg['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')),
        ).to(config.weight_dtype).eval()
        print("CLIP image encoder loaded")

        print("Loading scheduler...")
        scheduler_cls_map = {
            "Flow": FlowMatchEulerDiscreteScheduler,
            "Flow_Unipc": FlowUniPCMultistepScheduler,
            "Flow_DPM++": FlowDPMSolverMultistepScheduler,
        }
        scheduler_cls = scheduler_cls_map.get(config.sampler_name, FlowDPMSolverMultistepScheduler)
        scheduler = scheduler_cls(**filter_kwargs(scheduler_cls, OmegaConf.to_container(cfg['scheduler_kwargs'])))
        print("Scheduler loaded")

        print("Creating pipeline...")
        pipeline = WanFunInpaintAudioPipeline(
            transformer=transformer,
            vae=vae,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            scheduler=scheduler,
            clip_image_encoder=clip_image_encoder,
        )
        pipeline.to(device=device)

        if torch.__version__ >= "2.0":
            print("Compiling the pipeline transformer with torch.compile()...")
            pipeline.transformer = torch.compile(pipeline.transformer, mode="reduce-overhead", fullgraph=True)
            print("Pipeline transformer compiled!")

        print("Pipeline created and moved to device")

        print("Loading Wav2Vec models...")
        wav2vec_processor, wav2vec_model = load_wav2vec_models(config.wav2vec_model_dir)
        wav2vec_model.to(device)
        print("Wav2Vec models loaded")

        print("All models initialized successfully!")
        return True
    except Exception as e:
        print(f"Model initialization failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False


@spaces.GPU(duration=120)
def generate_video(
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    seed_param,
    num_inference_steps,
    guidance_scale,
    audio_guidance_scale,
    fps,
    partial_video_length,
    overlap_video_length,
    neg_scale,
    neg_steps,
    use_dynamic_cfg,
    use_dynamic_acfg,
    sampler_name,
    shift,
    audio_scale,
    use_un_ip_mask,
    enable_teacache,
    teacache_threshold,
    teacache_offload,
    num_skip_start_steps,
    enable_riflex,
    riflex_k,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a talking-head video from a portrait image and an audio clip.

    Note: sampler_name is accepted from the UI, but the scheduler itself is configured
    once in initialize_models()."""
    global pipeline, wav2vec_processor, wav2vec_model, config

    progress(0, desc="Starting video generation...")

    if image_path is None:
        raise gr.Error("Please upload an image")
    if audio_path is None:
        raise gr.Error("Please upload an audio file")
    if not models_ready or pipeline is None:
        raise gr.Error("Models not initialized. Please restart the space.")

    device = pipeline.device

    if seed_param < 0:
        seed = random.randint(0, np.iinfo(np.int32).max)
    else:
        seed = int(seed_param)
    print(f"Using seed: {seed}")

    try:
        generator = torch.Generator(device=device).manual_seed(seed)
        ref_img_pil = Image.open(image_path).convert("RGB")
        print(f"Image loaded: {ref_img_pil.size}")

        progress(0.1, desc="Detecting face...")
        try:
            y1, y2, x1, x2, h_, w_ = get_mask_coord(image_path)
            print("Face detection successful")
        except Exception as e:
            # Fall back to a centered box covering the middle half of the image.
            print(f"Face detection failed: {e}, using center crop")
            h_, w_ = ref_img_pil.size[1], ref_img_pil.size[0]
            y1, y2 = h_ // 4, 3 * h_ // 4
            x1, x2 = w_ // 4, 3 * w_ // 4

        progress(0.2, desc="Processing audio...")
        audio_clip = AudioFileClip(audio_path)
        audio_features = extract_audio_features(audio_path, wav2vec_processor, wav2vec_model)
        audio_embeds = audio_features.unsqueeze(0).to(device=device, dtype=config.weight_dtype)

        progress(0.25, desc="Encoding prompts...")
        prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
            prompt,
            device=device,
            do_classifier_free_guidance=(guidance_scale > 1.0),
            negative_prompt=negative_prompt
        )

        # Round the frame count so it is compatible with the VAE's temporal compression.
        video_length = int(audio_clip.duration * fps)
        video_length = (
            int((video_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
            if video_length != 1 else 1
        )
        print(f"Total video length: {video_length} frames")

        sample_height, sample_width = get_sample_size(ref_img_pil, config.sample_size)
        print(f"Sample size: {sample_width}x{sample_height}")

        # Scale the face box from image pixels onto the 16-pixel grid used by the IP mask.
        downratio = math.sqrt(sample_height * sample_width / h_ / w_)
        coords = (
            y1 * downratio // 16, y2 * downratio // 16,
            x1 * downratio // 16, x2 * downratio // 16,
            sample_height // 16, sample_width // 16,
        )
        ip_mask = get_ip_mask(coords).unsqueeze(0)
        ip_mask = torch.cat([ip_mask] * 3).to(device=device, dtype=config.weight_dtype)

        if enable_riflex:
            latent_frames = (video_length - 1) // pipeline.vae.config.temporal_compression_ratio + 1
            pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames)

        if enable_teacache:
            try:
                coefficients = get_teacache_coefficients(config.model_name)
                if coefficients:
                    pipeline.transformer.enable_teacache(
                        coefficients, num_inference_steps, teacache_threshold,
                        num_skip_start_steps=num_skip_start_steps,
                        offload=teacache_offload
                    )
                    print("TeaCache enabled for this run")
            except Exception as e:
                print(f"Warning: could not enable TeaCache: {e}")

        init_frames = 0
        new_sample = None
        ref_img_for_loop = ref_img_pil
        total_chunks = math.ceil(video_length / (partial_video_length - overlap_video_length)) if video_length > partial_video_length else 1
        chunk_num = 0
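
        # Long clips are generated as a sequence of overlapping windows: each pass
        # produces up to partial_video_length frames, the last overlap_video_length
        # frames of the running result are linearly cross-faded with the start of the
        # new window, and those same overlap frames are fed back as the reference
        # images for the next pass to keep identity and motion consistent. For
        # example, a 10-second clip at 25 fps (~250 frames) with the default
        # 113-frame chunks and 8-frame overlap takes ceil(250 / 105) = 3 passes.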
        while init_frames < video_length:
            chunk_num += 1
            progress(0.3 + (0.6 * (chunk_num / total_chunks)), desc=f"Generating chunk {chunk_num}/{total_chunks}...")

            current_partial_length = min(partial_video_length, video_length - init_frames)
            current_partial_length = (
                int((current_partial_length - 1) // pipeline.vae.config.temporal_compression_ratio * pipeline.vae.config.temporal_compression_ratio) + 1
                if current_partial_length > 1 else 1
            )
            if current_partial_length <= 0:
                break

            input_video, input_video_mask, clip_image = get_image_to_video_latent3(
                ref_img_for_loop, None, video_length=current_partial_length,
                sample_size=[sample_height, sample_width]
            )

            # Map video frames to Wav2Vec2 feature frames (two feature frames per video frame).
            audio_start_frame = init_frames * 2
            audio_end_frame = (init_frames + current_partial_length) * 2

            # If the audio is shorter than the requested window, loop it.
            if audio_embeds.shape[1] < audio_end_frame:
                repeat_times = (audio_end_frame // audio_embeds.shape[1]) + 1
                audio_embeds = audio_embeds.repeat(1, repeat_times, 1)
            partial_audio_embeds = audio_embeds[:, audio_start_frame:audio_end_frame]

            with torch.no_grad():
                sample = pipeline(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    num_frames=current_partial_length,
                    audio_embeds=partial_audio_embeds,
                    audio_scale=audio_scale,
                    ip_mask=ip_mask,
                    use_un_ip_mask=use_un_ip_mask,
                    height=sample_height,
                    width=sample_width,
                    generator=generator,
                    neg_scale=neg_scale,
                    neg_steps=neg_steps,
                    use_dynamic_cfg=use_dynamic_cfg,
                    use_dynamic_acfg=use_dynamic_acfg,
                    guidance_scale=guidance_scale,
                    audio_guidance_scale=audio_guidance_scale,
                    num_inference_steps=num_inference_steps,
                    video=input_video,
                    mask_video=input_video_mask,
                    clip_image=clip_image,
                    shift=shift,
                ).videos

            if new_sample is None:
                new_sample = sample
            else:
                # Linearly cross-fade the overlapping frames, then append the rest of the new window.
                mix_ratio = torch.linspace(0, 1, steps=overlap_video_length, device=device).view(1, 1, -1, 1, 1).to(new_sample.dtype)
                new_sample[:, :, -overlap_video_length:] = (
                    new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) +
                    sample[:, :, :overlap_video_length] * mix_ratio
                )
                new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim=2)

            if new_sample.shape[2] >= video_length:
                break

            # Use the last overlap frames as the reference images for the next window.
            ref_img_for_loop = [
                Image.fromarray(
                    (new_sample[0, :, i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
                ) for i in range(-overlap_video_length, 0)
            ]

            init_frames += current_partial_length - overlap_video_length

        progress(0.9, desc="Stitching video and audio...")
        final_sample = new_sample[:, :, :video_length]

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            video_path = tmp_file.name
        with tempfile.NamedTemporaryFile(suffix="_audio.mp4", delete=False) as tmp_file:
            video_audio_path = tmp_file.name

        save_videos_grid(final_sample, video_path, fps=fps)

        video_clip_final = VideoFileClip(video_path)
        # moviepy 2.x (already used here via `from moviepy import *` and with_audio)
        # renamed Clip.subclip() to subclipped().
        audio_clip_trimmed = audio_clip.subclipped(0, final_sample.shape[2] / fps)
        final_video = video_clip_final.with_audio(audio_clip_trimmed)
        final_video.write_videofile(video_audio_path, codec="libx264", audio_codec="aac", threads=4, logger=None)

        video_clip_final.close()
        audio_clip.close()
        audio_clip_trimmed.close()
        final_video.close()

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        progress(1.0, desc="Generation complete!")
        return video_audio_path, seed

    except Exception as e:
        print(f"Generation error: {str(e)}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Generation failed: {str(e)}")


def create_demo():
    """Build the Gradio UI."""
    with gr.Blocks(theme=gr.themes.Soft(), title="EchoMimicV3 Demo") as demo:
        gr.Markdown("""
        # EchoMimicV3: Audio-Driven Human Animation

        Transform a portrait photo into a talking video! Upload an image and an audio file to create lifelike, expressive animations. This demo showcases the EchoMimicV3 model.

        **Key Features:**
        - **High-Quality Lip Sync:** Accurate mouth movements that match the input audio.
        - **Natural Facial Expressions:** Generates subtle and natural facial emotions.
        - **Speech & Singing:** Works with both spoken word and singing.
        - **Efficient:** Powered by a compact 1.3B-parameter model.
        """)

        if not models_ready:
            gr.Warning("Models are still loading. The UI is disabled. Please wait and refresh the page if necessary.")

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    label="Upload Portrait Image",
                    type="filepath",
                    sources=["upload"],
                    height=400,
                )
                audio_input = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                    sources=["upload"],
                )

                with gr.Accordion("Text Prompts", open=True):
                    prompt = gr.Textbox(
                        label="Prompt",
                        value="A person talking naturally with clear expressions.",
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="Gesture is bad, unclear. Strange, twisted, bad, blurry hands and fingers.",
                        lines=2,
                    )

            with gr.Column(scale=1):
                video_output = gr.Video(
                    label="Generated Video",
                    interactive=False,
                    height=400
                )
                seed_output = gr.Number(
                    label="Used Seed",
                    interactive=False,
                    precision=0
                )

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Core Generation Parameters")
                    seed_param = gr.Number(label="Seed", value=-1, precision=0, info="-1 for a random seed.")
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=5, maximum=50, value=20, step=1, info="More steps can improve quality but take longer. 15-25 is a good range.")
                    fps = gr.Slider(label="Frames Per Second (FPS)", minimum=10, maximum=30, value=25, step=1, info="Controls the smoothness of the output video.")
                with gr.Column():
                    gr.Markdown("### Classifier-Free Guidance (CFG)")
                    guidance_scale = gr.Slider(label="Text Guidance Scale (CFG)", minimum=1.0, maximum=10.0, value=4.5, step=0.1, info="How strongly to follow the text prompt. Recommended: 3.0-6.0.")
                    audio_guidance_scale = gr.Slider(label="Audio Guidance Scale (aCFG)", minimum=1.0, maximum=10.0, value=2.5, step=0.1, info="How strongly to follow the audio for lip sync. Recommended: 2.0-3.0.")
                    use_dynamic_cfg = gr.Checkbox(label="Use Dynamic Text CFG", value=True, info="Gradually adjusts CFG during generation; can improve quality.")
                    use_dynamic_acfg = gr.Checkbox(label="Use Dynamic Audio aCFG", value=True, info="Gradually adjusts aCFG during generation; can improve quality.")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Performance & VRAM (Chunking)")
                    partial_video_length = gr.Slider(label="Partial Video Length (Chunk Size)", minimum=49, maximum=161, value=113, step=16, info="Key for VRAM usage. 24 GB VRAM: ~113, 16 GB: ~81, 12 GB: ~49. Lower values use less memory but may affect consistency.")
                    overlap_video_length = gr.Slider(label="Overlap Length", minimum=4, maximum=16, value=8, step=1, info="How many frames to overlap between chunks for smooth transitions.")
                with gr.Column():
                    gr.Markdown("### Sampler & Scheduler")
                    sampler_name = gr.Dropdown(label="Sampler", choices=["Flow", "Flow_Unipc", "Flow_DPM++"], value="Flow_DPM++", info="Algorithm for the diffusion process.")
                    shift = gr.Slider(label="Scheduler Shift", minimum=1.0, maximum=10.0, value=5.0, step=0.1, info="Adjusts the noise schedule. The optimal range depends on the sampler.")
                    audio_scale = gr.Slider(label="Audio Scale", minimum=0.5, maximum=2.0, value=1.0, step=0.1, info="Global scale for audio feature influence.")
                    use_un_ip_mask = gr.Checkbox(label="Use Un-IP Mask", value=False, info="Inverts the inpainting mask.")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Negative Guidance (Advanced CFG)")
                    neg_scale = gr.Slider(label="Negative Scale", minimum=1.0, maximum=5.0, value=1.5, step=0.1, info="Strength of the negative prompt in early steps.")
                    neg_steps = gr.Slider(label="Negative Steps", minimum=0, maximum=10, value=2, step=1, info="How many initial steps apply the negative scale.")

        with gr.Accordion("Experimental Settings", open=False):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### TeaCache (Performance Boost)")
                    enable_teacache = gr.Checkbox(label="Enable TeaCache", value=True)
                    teacache_threshold = gr.Slider(label="TeaCache Threshold", minimum=0.0, maximum=0.2, value=0.1, step=0.01)
                    teacache_offload = gr.Checkbox(label="TeaCache Offload", value=True)
                with gr.Column():
                    gr.Markdown("### Riflex (Consistency)")
                    enable_riflex = gr.Checkbox(label="Enable Riflex", value=False)
                    riflex_k = gr.Slider(label="Riflex K", minimum=1, maximum=10, value=6, step=1)
                with gr.Column():
                    gr.Markdown("### Other")
                    num_skip_start_steps = gr.Slider(label="Num Skip Start Steps", minimum=0, maximum=10, value=5, step=1)

        generate_button = gr.Button(
            "Generate Video",
            variant='primary',
            size="lg",
            interactive=models_ready
        )

        all_inputs = [
            image_input, audio_input, prompt, negative_prompt, seed_param,
            num_inference_steps, guidance_scale, audio_guidance_scale, fps,
            partial_video_length, overlap_video_length, neg_scale, neg_steps,
            use_dynamic_cfg, use_dynamic_acfg, sampler_name, shift, audio_scale,
            use_un_ip_mask, enable_teacache, teacache_threshold, teacache_offload,
            num_skip_start_steps, enable_riflex, riflex_k
        ]
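
        # The order of all_inputs must match generate_video()'s positional parameters
        # exactly; the example rows below list their values in the same order.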

        if models_ready:
            generate_button.click(
                fn=generate_video,
                inputs=all_inputs,
                outputs=[video_output, seed_output]
            )

        gr.Markdown("---")
        gr.Markdown("### Click to Try Examples")

        gr.Examples(
            examples=[
                [
                    "examples/demo_ch_woman_04.png",
                    "examples/demo_ch_woman_04.WAV",
                    "A Chinese woman is talking naturally.",
                    "bad gestures, blurry, distorted face",
                    42, 20, 4.5, 2.5, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6
                ],
                [
                    "examples/guitar_woman_01.png",
                    "examples/guitar_woman_01.WAV",
                    "A woman with glasses is singing and playing the guitar.",
                    "blurry, distorted face, bad hands",
                    123, 25, 5.0, 2.8, 25, 113, 8, 1.5, 2, True, True, "Flow_DPM++", 5.0, 1.0, False, True, 0.1, True, 5, False, 6
                ],
            ],
            inputs=all_inputs,
            outputs=[video_output, seed_output],
            fn=generate_video,
            cache_examples=True,  # runs generate_video on the example rows at startup and caches the results
            label=None,
        )

        gr.Markdown("---")
        gr.Markdown("""
        ### How to Use
        1. **Upload Image:** Choose a clear portrait photo (front-facing works best).
        2. **Upload Audio:** Add an audio file with clear speech or singing.
        3. **Adjust Settings (Optional):** Fine-tune parameters in the advanced sections for different results. If you run out of memory, lower the "Partial Video Length".
        4. **Generate:** Click the button and wait for your talking video!

        **Note:** Generation time depends on the settings and the audio length; it can take a few minutes.

        This demo is based on the [EchoMimicV3 repository](https://github.com/antgroup/echomimic_v3).
        """)

    return demo


if __name__ == "__main__":
    print("Starting model initialization...")
    models_ready = initialize_models()

    demo = create_demo()
    demo.launch(share=True)