# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
import base64
import gc
import hashlib
import io
import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import torch
import torch.distributed as dist
from fastapi import FastAPI, HTTPException
from PIL import Image

from .api import download_from_url, encode_file_to_base64

try:
    import ray
except ImportError:
    print("Ray is not installed. If you want to use the multi-GPU API, please install it with 'pip install ray'.")
    ray = None
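

# Persist a base64-encoded video to the system temp directory, named by the MD5 hash of its
# contents. Under torch.distributed, only rank 0 writes the file and a barrier keeps the
# other ranks from continuing before the write completes.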
def save_base64_video_dist(base64_string):
    video_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(video_data).hexdigest()
    filename = f"{md5_hash}.mp4"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    if dist.is_initialized():
        if dist.get_rank() == 0:
            with open(file_path, 'wb') as video_file:
                video_file.write(video_data)
        dist.barrier()
    else:
        with open(file_path, 'wb') as video_file:
            video_file.write(video_data)
    return file_path
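

# Same as save_base64_video_dist, but for a base64-encoded image saved as a .jpg file.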
def save_base64_image_dist(base64_string):
    image_data = base64.b64decode(base64_string)

    md5_hash = hashlib.md5(image_data).hexdigest()
    filename = f"{md5_hash}.jpg"

    temp_dir = tempfile.gettempdir()
    file_path = os.path.join(temp_dir, filename)

    if dist.is_initialized():
        if dist.get_rank() == 0:
            with open(file_path, 'wb') as image_file:
                image_file.write(image_data)
        dist.barrier()
    else:
        with open(file_path, 'wb') as image_file:
            image_file.write(image_data)
    return file_path
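

# Download media from a URL and reuse the base64 helpers above to write it to the temp
# directory; return None if the download fails.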
def save_url_video_dist(url):
    video_data = download_from_url(url)
    if video_data:
        return save_base64_video_dist(base64.b64encode(video_data))
    return None


def save_url_image_dist(url):
    image_data = download_from_url(url)
    if image_data:
        return save_base64_image_dist(base64.b64encode(image_data))
    return None
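

# When Ray is available, expose one generation worker per rank plus an engine that fans
# each API request out to all workers; otherwise the multi-node entry points are disabled.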
if ray is not None:
    # NOTE: MultiNodesEngine calls MultiNodesGenerator.remote(...), so this class must be a
    # Ray actor. The num_gpus=1 reservation (one GPU per worker) is an assumption here.
    @ray.remote(num_gpus=1)
    class MultiNodesGenerator:
        def __init__(
            self, rank: int, world_size: int, Controller,
            GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
            config_path=None, ulysses_degree=1, ring_degree=1,
            fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
            weight_dtype=None, savedir_sample=None,
        ):
            # Set PyTorch distributed environment variables
            os.environ["RANK"] = str(rank)
            os.environ["WORLD_SIZE"] = str(world_size)
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "29500"

            self.rank = rank
            self.controller = Controller(
                GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
                ulysses_degree=ulysses_degree, ring_degree=ring_degree,
                fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
                weight_dtype=weight_dtype, savedir_sample=savedir_sample,
            )
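
        # Run a single generation request on this worker. `datas` mirrors the single-node
        # API payload; media fields may be URLs or base64 strings.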
        def generate(self, datas):
            try:
                base_model_path = datas.get('base_model_path', 'none')
                base_model_2_path = datas.get('base_model_2_path', 'none')
                lora_model_path = datas.get('lora_model_path', 'none')
                lora_model_2_path = datas.get('lora_model_2_path', 'none')
                lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
                prompt_textbox = datas.get('prompt_textbox', None)
                negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
                sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
                sample_step_slider = datas.get('sample_step_slider', 30)
                resize_method = datas.get('resize_method', "Generate by")
                width_slider = datas.get('width_slider', 672)
                height_slider = datas.get('height_slider', 384)
                base_resolution = datas.get('base_resolution', 512)
                is_image = datas.get('is_image', False)
                generation_method = datas.get('generation_method', False)
                length_slider = datas.get('length_slider', 49)
                overlap_video_length = datas.get('overlap_video_length', 4)
                partial_video_length = datas.get('partial_video_length', 72)
                cfg_scale_slider = datas.get('cfg_scale_slider', 6)
                start_image = datas.get('start_image', None)
                end_image = datas.get('end_image', None)
                validation_video = datas.get('validation_video', None)
                validation_video_mask = datas.get('validation_video_mask', None)
                control_video = datas.get('control_video', None)
                denoise_strength = datas.get('denoise_strength', 0.70)
                seed_textbox = datas.get("seed_textbox", 43)
                ref_image = datas.get('ref_image', None)

                enable_teacache = datas.get('enable_teacache', True)
                teacache_threshold = datas.get('teacache_threshold', 0.10)
                num_skip_start_steps = datas.get('num_skip_start_steps', 1)
                teacache_offload = datas.get('teacache_offload', False)
                cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)

                enable_riflex = datas.get('enable_riflex', False)
                riflex_k = datas.get('riflex_k', 6)

                fps = datas.get('fps', None)

                generation_method = "Image Generation" if is_image else generation_method
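
                # Decode media inputs: values starting with 'http' are downloaded, everything
                # else is treated as a base64 payload and written to the temp directory.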
                if start_image is not None:
                    if start_image.startswith('http'):
                        start_image = save_url_image_dist(start_image)
                        start_image = [Image.open(start_image).convert("RGB")]
                    else:
                        start_image = base64.b64decode(start_image)
                        start_image = [Image.open(BytesIO(start_image)).convert("RGB")]

                if end_image is not None:
                    if end_image.startswith('http'):
                        end_image = save_url_image_dist(end_image)
                        end_image = [Image.open(end_image).convert("RGB")]
                    else:
                        end_image = base64.b64decode(end_image)
                        end_image = [Image.open(BytesIO(end_image)).convert("RGB")]

                if validation_video is not None:
                    if validation_video.startswith('http'):
                        validation_video = save_url_video_dist(validation_video)
                    else:
                        validation_video = save_base64_video_dist(validation_video)

                if validation_video_mask is not None:
                    if validation_video_mask.startswith('http'):
                        validation_video_mask = save_url_image_dist(validation_video_mask)
                    else:
                        validation_video_mask = save_base64_image_dist(validation_video_mask)

                if control_video is not None:
                    if control_video.startswith('http'):
                        control_video = save_url_video_dist(control_video)
                    else:
                        control_video = save_base64_video_dist(control_video)

                if ref_image is not None:
                    if ref_image.startswith('http'):
                        ref_image = save_url_image_dist(ref_image)
                        ref_image = [Image.open(ref_image).convert("RGB")]
                    else:
                        ref_image = base64.b64decode(ref_image)
                        ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
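
                # Run the shared controller; on failure, free GPU memory and return the error,
                # with only rank 0 producing a response when torch.distributed is initialized.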
                try:
                    save_sample_path, comment = self.controller.generate(
                        "",
                        base_model_path,
                        lora_model_path,
                        lora_alpha_slider,
                        prompt_textbox,
                        negative_prompt_textbox,
                        sampler_dropdown,
                        sample_step_slider,
                        resize_method,
                        width_slider,
                        height_slider,
                        base_resolution,
                        generation_method,
                        length_slider,
                        overlap_video_length,
                        partial_video_length,
                        cfg_scale_slider,
                        start_image,
                        end_image,
                        validation_video,
                        validation_video_mask,
                        control_video,
                        denoise_strength,
                        seed_textbox,
                        ref_image = ref_image,
                        enable_teacache = enable_teacache,
                        teacache_threshold = teacache_threshold,
                        num_skip_start_steps = num_skip_start_steps,
                        teacache_offload = teacache_offload,
                        cfg_skip_ratio = cfg_skip_ratio,
                        enable_riflex = enable_riflex,
                        riflex_k = riflex_k,
                        base_model_2_dropdown = base_model_2_path,
                        lora_model_2_dropdown = lora_model_2_path,
                        fps = fps,
                        is_api = True,
                    )
                except Exception as e:
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch.cuda.ipc_collect()
                    save_sample_path = ""
                    comment = f"Error. Error information is {str(e)}"

                    if dist.is_initialized():
                        if dist.get_rank() == 0:
                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                        else:
                            return None
                    else:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}

                if dist.is_initialized():
                    if dist.get_rank() == 0:
                        if save_sample_path != "":
                            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
                        else:
                            return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                    else:
                        return None
                else:
                    if save_sample_path != "":
                        return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
                    else:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}
            except Exception as e:
                print(f"Error generating: {str(e)}")
                comment = f"Error generating: {str(e)}"
                if dist.is_initialized():
                    if dist.get_rank() == 0:
                        return {"message": comment, "save_sample_path": None, "base64_encoding": None}
                    else:
                        return None
                else:
                    return {"message": comment, "save_sample_path": None, "base64_encoding": None}
    class MultiNodesEngine:
        def __init__(
            self,
            world_size,
            Controller,
            GPU_memory_mode,
            scheduler_dict,
            model_name,
            model_type,
            config_path,
            ulysses_degree=1,
            ring_degree=1,
            fsdp_dit=False,
            fsdp_text_encoder=False,
            compile_dit=False,
            weight_dtype=torch.bfloat16,
            savedir_sample="samples"
        ):
            # Ensure Ray is initialized
            if not ray.is_initialized():
                ray.init()

            num_workers = world_size
            self.workers = [
                MultiNodesGenerator.remote(
                    rank, world_size, Controller,
                    GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
                    ulysses_degree=ulysses_degree, ring_degree=ring_degree,
                    fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
                    weight_dtype=weight_dtype, savedir_sample=savedir_sample,
                )
                for rank in range(num_workers)
            ]
            print("Update workers done")

        async def generate(self, data):
            results = ray.get([
                worker.generate.remote(data)
                for worker in self.workers
            ])
            return next(path for path in results if path is not None)
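
    # FastAPI glue: forward the request payload to the Ray engine and convert unexpected
    # errors into HTTP 500 responses.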
    def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
        # The handler below is expected to be registered on `app` (e.g. with an
        # @app.post(...) decorator); the route registration is not shown in this file.
        async def _multi_nodes_infer_forward_api(
            datas: dict,
        ):
            try:
                result = await engine.generate(datas)
                return result
            except Exception as e:
                if isinstance(e, HTTPException):
                    raise e
                raise HTTPException(status_code=500, detail=str(e))
else:
    MultiNodesEngine = None
    MultiNodesGenerator = None
    multi_nodes_infer_forward_api = None