# coding=utf-8
import logging
import os
from pathlib import Path

import soundfile as sf
import torch
import torchvision
from huggingface_hub import snapshot_download
from moviepy.editor import AudioFileClip, VideoFileClip
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from third_party.FoleyCrafter.foleycrafter.models.onset import torch_utils
from third_party.FoleyCrafter.foleycrafter.models.time_detector.model import VideoOnsetNet
from third_party.FoleyCrafter.foleycrafter.pipelines.auffusion_pipeline import Generator, denormalize_spectrogram
from third_party.FoleyCrafter.foleycrafter.utils.util import build_foleycrafter, read_frames_with_moviepy
# frame preprocessing for the onset (timestamp) detector
vision_transform_list = [
    torchvision.transforms.Resize((128, 128)),
    torchvision.transforms.CenterCrop((112, 112)),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
video_transform = torchvision.transforms.Compose(vision_transform_list)

model_base_dir = "pretrained/v2a/foleycrafter"

class V2A_FoleyCrafter:
    def __init__(self,
                 pretrained_model_name_or_path: str = f"{model_base_dir}/checkpoints/auffusion",
                 ckpt: str = f"{model_base_dir}/checkpoints"):
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        self.log.info("The V2A model uses FoleyCrafter, init...")

        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.log.warning('CUDA/MPS are not available, running on CPU')

        # download checkpoints if a local directory is not given
        if not os.path.isdir(pretrained_model_name_or_path):
            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)

        # checkpoint path of the temporal adapter
        temporal_ckpt_path = os.path.join(ckpt, "temporal_adapter.ckpt")

        # load vocoder
        self.vocoder = Generator.from_pretrained(ckpt, subfolder="vocoder").to(self.device)

        # load time_detector
        time_detector_ckpt = os.path.join(ckpt, "timestamp_detector.pth.tar")
        self.time_detector = VideoOnsetNet(False)
        self.time_detector, _ = torch_utils.load_model(
            time_detector_ckpt, self.time_detector, device=self.device, strict=True
        )

        # build the pipeline and load the adapters
        self.pipe = build_foleycrafter().to(self.device)

        # load temporal adapter into the ControlNet
        # (kept in a separate variable so the `ckpt` directory path stays usable below)
        state_dict = torch.load(temporal_ckpt_path, map_location=self.device)
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        load_gligen_ckpt = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                load_gligen_ckpt[key[len("module."):]] = value
            else:
                load_gligen_ckpt[key] = value
        m, u = self.pipe.controlnet.load_state_dict(load_gligen_ckpt, strict=False)
        self.log.info(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        # load semantic adapter
        self.pipe.load_ip_adapter(
            os.path.join(ckpt, "semantic"), subfolder="", weight_name="semantic_adapter.bin", image_encoder_folder=None
        )
        # the IP-Adapter scale (semantic_scale) and the seed are set per call in generate_audio()

        self.generator = torch.Generator(device=self.device)

        self.image_processor = CLIPImageProcessor()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter", subfolder="models/image_encoder"
        ).to(self.device)

    def generate_audio(self,
                       video_path,
                       output_dir,
                       prompt: str = '',
                       negative_prompt: str = '',
                       seed: int = 42,
                       temporal_scale: float = 0.2,
                       semantic_scale: float = 1.0,
                       is_postp=False):
        self.pipe.set_ip_adapter_scale(semantic_scale)
        self.generator.manual_seed(seed)

        video_path = Path(video_path).expanduser()
        output_dir = Path(output_dir).expanduser()
        self.log.info(f"Loading video: {video_path}")
        output_dir.mkdir(parents=True, exist_ok=True)

        frames, duration = read_frames_with_moviepy(video_path, max_frame_nums=150)

        # temporal condition: onset predictions on the sampled frames
        time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2)
        time_frames = video_transform(time_frames)
        time_frames = {"frames": time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
        preds = self.time_detector(time_frames)
        preds = torch.sigmoid(preds)

        # map onset probabilities onto the 1024-column spectrogram timeline
        time_condition = [
            -1 if preds[0][int(i / (1024 / 10 * duration) * 150)] < 0.5 else 1
            for i in range(int(1024 / 10 * duration))
        ]
        time_condition = time_condition + [-1] * (1024 - len(time_condition))
        # w -> b c h w
        time_condition = (
            torch.FloatTensor(time_condition)
            .unsqueeze(0)
            .unsqueeze(0)
            .unsqueeze(0)
            .repeat(1, 1, 256, 1)
            .to(self.device)
        )

        # semantic condition: mean CLIP image embedding over all frames
        images = self.image_processor(images=frames, return_tensors="pt").to(self.device)
        image_embeddings = self.image_encoder(**images).image_embeds
        image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
        neg_image_embeddings = torch.zeros_like(image_embeddings)
        image_embeddings = torch.cat([neg_image_embeddings, image_embeddings], dim=1)

        sample = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            ip_adapter_image_embeds=image_embeddings,
            image=time_condition,
            controlnet_conditioning_scale=temporal_scale,
            num_inference_steps=25,
            height=256,
            width=1024,
            output_type="pt",
            generator=self.generator,
        )

        # decode the generated spectrogram to a 16 kHz waveform
        audio_img = sample.images[0]
        audio = denormalize_spectrogram(audio_img)
        audio = self.vocoder.inference(audio, lengths=160000)[0]
        audio = audio[: int(duration * 16000)]

        if is_postp:
            audio_save_path = output_dir / f'{video_path.stem}.neg.wav'
            video_save_path = output_dir / f'{video_path.stem}.neg.mp4'
        else:
            audio_save_path = output_dir / f'{video_path.stem}.step1.wav'
            video_save_path = output_dir / f'{video_path.stem}.step1.mp4'

        self.log.info(f"Saving generated audio and video to {output_dir}")
        sf.write(audio_save_path, audio, 16000)

        # mux the generated audio back onto the input video
        audio = AudioFileClip(str(audio_save_path))
        video = VideoFileClip(str(video_path))
        duration = min(audio.duration, video.duration)
        audio = audio.subclip(0, duration)
        video = video.subclip(0, duration)
        video = video.set_audio(audio)
        video.write_videofile(str(video_save_path))
        self.log.info(f'Video saved to {video_save_path}')

        return audio_save_path, video_save_path
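

# --- Minimal usage sketch (not part of the original module) ---
# Assumptions: the FoleyCrafter checkpoints are already placed under
# pretrained/v2a/foleycrafter, and "input.mp4" / "outputs" are placeholder
# paths chosen purely for illustration.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    v2a = V2A_FoleyCrafter()
    wav_path, mp4_path = v2a.generate_audio(
        video_path="input.mp4",   # silent input clip (placeholder path)
        output_dir="outputs",     # where the .wav and the muxed .mp4 are written
        prompt="",                # optional text prompt for the audio
        seed=42,
        temporal_scale=0.2,       # ControlNet (onset) conditioning strength
        semantic_scale=1.0,       # IP-Adapter (CLIP image) conditioning strength
    )
    print(f"audio: {wav_path}\nvideo: {mp4_path}")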