import gc
import logging
import random

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.models import AutoencoderKL
from transformers import AutoProcessor, AutoModelForImageClassification

import utils
from config import (
    MODEL,
    MIN_IMAGE_SIZE,
    MAX_IMAGE_SIZE,
    DEFAULT_PROMPT,
    DEFAULT_NEGATIVE_PROMPT,
    scheduler_list,
)

MAX_SEED = np.iinfo(np.int32).max
# Enhanced logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# PyTorch settings for better performance and determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Model initialization
if torch.cuda.is_available():
    try:
        logger.info("Loading VAE and pipeline...")
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        pipe = utils.load_pipeline(MODEL, device, vae=vae)
        logger.info("Pipeline loaded successfully on GPU!")
    except Exception as e:
        logger.error(f"Error loading VAE, falling back to default: {e}")
        pipe = utils.load_pipeline(MODEL, device)
else:
    logger.warning("CUDA not available, running on CPU")
    pipe = None
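
# For orientation: utils.load_pipeline lives in utils.py (not shown here). A
# hypothetical sketch of what such a helper typically does, assuming an SDXL
# checkpoint; the real implementation may differ:
#
#     from diffusers import StableDiffusionXLPipeline
#
#     def load_pipeline(model, device, vae=None):
#         kwargs = {"torch_dtype": torch.float16, "use_safetensors": True}
#         if vae is not None:
#             kwargs["vae"] = vae
#         return StableDiffusionXLPipeline.from_pretrained(model, **kwargs).to(device)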
# -------------------- NSFW detection model loading --------------------
try:
    logger.info("Loading NSFW detector...")
    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    logger.info("NSFW detector loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load NSFW detector: {e}")
    nsfw_model = None
    nsfw_processor = None
# -----------------------------------------------------------------------

class GenerationError(Exception):
    """Custom exception for generation errors."""
    pass

def validate_prompt(prompt: str) -> str:
    """Validate and clean up the input prompt."""
    if not isinstance(prompt, str):
        raise GenerationError("Prompt must be a string")
    try:
        # Ensure proper UTF-8 encoding/decoding
        prompt = prompt.encode('utf-8').decode('utf-8')
        # Add a space between "!" and ","
        prompt = prompt.replace("!,", "! ,")
    except UnicodeError:
        raise GenerationError("Invalid characters in prompt")
    # Reject only prompts that are completely empty or whitespace
    if not prompt or prompt.isspace():
        raise GenerationError("Prompt cannot be empty")
    return prompt.strip()

def validate_dimensions(width: int, height: int) -> None:
    """Validate image dimensions."""
    if not MIN_IMAGE_SIZE <= width <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Width must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")
    if not MIN_IMAGE_SIZE <= height <= MAX_IMAGE_SIZE:
        raise GenerationError(f"Height must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}")

def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if the image is classified as NSFW."""
    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = nsfw_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    nsfw_score = probs[0][1].item()  # label 1 = NSFW
    return nsfw_score > threshold
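
# Note: probs[0][1] above assumes this checkpoint maps label index 1 to "nsfw"
# (true for Falconsai/nsfw_image_detection). A label-order-independent variant
# could look the index up from the model config instead:
#
#     nsfw_idx = nsfw_model.config.label2id.get("nsfw", 1)
#     nsfw_score = probs[0][nsfw_idx].item()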

# ZeroGPU: the spaces.GPU decorator requests a GPU for the duration of each call
@spaces.GPU
def _generate_on_gpu(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    scheduler: str,
    opt_strength: float,
    opt_scale: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress(),
):
    """Generate images based on the given parameters."""
    progress(0, desc="Starting")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    upscaler_pipe = None
    backup_scheduler = None

    # Base pass maps onto the 0.1-0.6 range of the progress bar
    def callback1(pipe, step, timestep, callback_kwargs):
        progress_value = 0.1 + ((step + 1.0) / num_inference_steps) * 0.5
        progress(progress_value, desc=f"Image generating, {step + 1}/{num_inference_steps} steps")
        return callback_kwargs

    optimizing_steps = int(num_inference_steps * opt_strength)

    # Refinement pass maps onto the 0.6-1.0 range of the progress bar
    def callback2(pipe, step, timestep, callback_kwargs):
        progress_value = 0.6 + ((step + 1.0) / optimizing_steps) * 0.4
        progress(progress_value, desc=f"Image optimizing, {step + 1}/{optimizing_steps} steps")
        return callback_kwargs
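
    # diffusers' img2img pipelines run roughly int(num_inference_steps * strength)
    # denoising steps, which is why callback2 normalizes progress by
    # optimizing_steps rather than num_inference_steps.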
    try:
        # Memory management
        torch.cuda.empty_cache()
        gc.collect()

        # Input validation
        prompt = validate_prompt(prompt)
        if negative_prompt:
            negative_prompt = negative_prompt.encode('utf-8').decode('utf-8')
        validate_dimensions(width, height)

        # Set up generation
        generator = utils.seed_everything(seed)
        width, height = utils.preprocess_image_dimensions(width, height)

        # Set up the pipeline
        backup_scheduler = pipe.scheduler
        pipe.scheduler = utils.get_scheduler(pipe.scheduler.config, scheduler)
        upscaler_pipe = StableDiffusionXLImg2ImgPipeline(**pipe.components)

        progress(0.1, desc="Image generating")
        latents = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            output_type="latent",
            callback_on_step_end=callback1,
        ).images
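
        # "Highres fix" pattern: the base pass stays in latent space, the latents
        # are upsampled, and an img2img pass re-denoises them at the larger size
        # to recover detail that plain interpolation would blur.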
        upscaled_latents = utils.upscale(latents, "nearest-exact", opt_scale)
        images = upscaler_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=upscaled_latents,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            strength=opt_strength,
            generator=generator,
            output_type="pil",
            callback_on_step_end=callback2,
        ).images
        out_img = images[0]

        # NSFW detection
        if nsfw_model and nsfw_processor:
            if detect_nsfw(out_img):
                raise GenerationError(
                    "Generated image contains NSFW content and cannot be displayed. "
                    "Please modify your prompt and try again."
                )

        path = utils.save_image(out_img, "./outputs")
        logger.info(f"Output path: {path}")
        progress(1, desc="Complete")
        info = {"status": "success"}
        return path, info
    except Exception as e:
        # GenerationError and unexpected errors are reported the same way
        logger.error(f"Generation failed: {e}")
        return None, {
            "error": str(e),
            "status": "failed",
        }
    finally:
        # Cleanup
        torch.cuda.empty_cache()
        gc.collect()
        if upscaler_pipe is not None:
            del upscaler_pipe
        if backup_scheduler is not None and pipe is not None:
            pipe.scheduler = backup_scheduler
        utils.free_memory()
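
# utils.upscale (used in the refinement step above) resizes the latent tensor
# before the img2img pass. A hypothetical sketch of such a helper; the actual
# code lives in utils.py and may differ:
#
#     import torch.nn.functional as F
#
#     def upscale(latents, mode, scale_factor):
#         # latents: (batch, channels, height, width) tensor from the base pass
#         return F.interpolate(latents, scale_factor=scale_factor, mode=mode)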

def generate(
    prompt: str,
    negative_prompt: str,
    width: int,
    height: int,
    scheduler: str,
    opt_strength: float,
    opt_scale: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
    progress: gr.Progress = gr.Progress(),
):
    # Call the GPU-backed generation function
    image_path, info = _generate_on_gpu(
        prompt, negative_prompt,
        width, height,
        scheduler,
        opt_strength, opt_scale,
        seed, randomize_seed,
        guidance_scale, num_inference_steps,
        progress=progress,
    )
    # If generation failed, surface the error in the UI
    if info["status"] == "failed":
        raise gr.Error(info["error"])
    # Return the image path
    return image_path

title = "# Anime AI Generator"
description = "Our AI-Powered Anime Generator turns your ideas into breathtaking AI anime art, perfect for storytelling or a personal AI anime wallpaper. Experience more at [Anime AI Generator](https://www.animeaigen.com)."

custom_css = """
"""

with gr.Blocks(css=custom_css).queue() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row(elem_id="row-container"):
        with gr.Column():
            gr.Markdown("### Input")
            with gr.Column():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=5,
                    placeholder="Enter your prompt",
                    value=DEFAULT_PROMPT,
                )
                negative_prompt = gr.Text(
                    label="Negative prompt",
                    max_lines=5,
                    placeholder="Enter a negative prompt",
                    value=DEFAULT_NEGATIVE_PROMPT,
                )
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=MIN_IMAGE_SIZE,
                        maximum=MAX_IMAGE_SIZE,
                        step=8,
                        value=832,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=MIN_IMAGE_SIZE,
                        maximum=MAX_IMAGE_SIZE,
                        step=8,
                        value=1216,
                    )
                with gr.Row():
                    optimization_strength = gr.Slider(
                        label="Optimization strength",
                        minimum=0,
                        maximum=1,
                        step=0.05,
                        value=0.55,
                    )
                    optimization_scale = gr.Slider(
                        label="Optimization scale ratio",
                        minimum=1,
                        maximum=1.5,
                        step=0.1,
                        value=1.5,
                    )
                with gr.Column():
                    scheduler = gr.Dropdown(
                        label="scheduler",
                        choices=scheduler_list,
                        interactive=True,
                        value="Euler a",
                    )
                with gr.Column():
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance scale",
                        minimum=1.0,
                        maximum=12.0,
                        step=0.1,
                        value=6.0,
                    )
                    num_inference_steps = gr.Slider(
                        label="Number of inference steps",
                        minimum=1,
                        maximum=50,
                        step=1,
                        value=25,
                    )
                run_button = gr.Button("Run", variant="primary")
        with gr.Column():
            gr.Markdown("### Output")
            result = gr.Image(
                type="filepath",
                label="Generated Image",
                elem_id="output-image",
            )
    run_button.click(
        fn=generate,
        inputs=[
            prompt, negative_prompt,
            width, height,
            scheduler,
            optimization_strength, optimization_scale,
            seed, randomize_seed,
            guidance_scale, num_inference_steps,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()
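
# Run locally with `python app.py`. On a Hugging Face Space with ZeroGPU, the
# @spaces.GPU decorator allocates a GPU only for the duration of each call.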