Spaces: Running on Zero
| import torch | |
| from config import Config | |
| from utils import resize_image_to_1mp, get_caption, draw_kps | |
| from PIL import Image | |
class Generator:
    """Turn an input photo into a stylized (pixel-art) image via the handler's pipeline.

    The pipeline is driven by three ControlNets in this fixed order:
    ``[InstantID keypoints, depth, lineart]``. Every per-ControlNet list
    built here (control images, conditioning scales, guidance ends) follows
    that order.
    """

    # IP-Adapter (identity) scale applied when a face is detected. Lowered from
    # 0.8 to 0.7 so the pixel-art LoRA can override realistic skin textures
    # while identity is kept.
    FACE_IP_ADAPTER_SCALE = 0.7
    # Latest point (fraction of denoising steps) at which the InstantID
    # ControlNet may still act. The remaining steps are left to the LoRA so it
    # can "pixelize" the face without the ControlNet pulling it back to a photo.
    FACE_GUIDANCE_END_CAP = 0.6
    # Guidance-end fractions are clamped into [MIN_GUIDANCE_END, 1.0].
    # diffusers' ControlNet pipelines reject end <= start (start defaults to
    # 0.0) and end > 1.0 with a ValueError, so a slider strength of 0 or above
    # 1 would otherwise crash inference.
    # NOTE(review): assumed to apply to this custom InstantID pipeline too.
    MIN_GUIDANCE_END = 0.01

    def __init__(self, model_handler):
        """Store the handler that owns the detectors and the diffusion pipeline."""
        self.mh = model_handler

    def prepare_control_images(self, image, width, height):
        """Build depth and lineart conditioning maps for ``image``.

        Both maps are resized to exactly ``(width, height)`` so they line up
        with the img2img input image.

        Returns:
            tuple: ``(depth_map, lineart_map)``, resized with LANCZOS.
        """
        print(f"Generating control maps for {width}x{height}...")
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        # The detector output size may differ from the target; resize
        # explicitly so every control image matches the output resolution.
        # (Assumes the detectors return PIL images - `.resize` is the PIL API.)
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        return depth_map, lineart_map

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        face_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1,
    ):
        """Generate one stylized image from ``input_image``.

        Args:
            input_image: Source image; normalized with ``resize_image_to_1mp``.
            user_prompt: Text appended after ``Config.STYLE_TRIGGER``. When it is
                empty, blank or None, an automatic caption is used instead.
            negative_prompt: Passed to the pipeline unchanged.
            guidance_scale: Classifier-free guidance scale.
            num_inference_steps: Number of denoising steps.
            img2img_strength: img2img ``strength`` for the base image.
            face_strength: InstantID ControlNet scale (forced to 0 when no face
                is found). It also sets that ControlNet's guidance end, capped at
                ``FACE_GUIDANCE_END_CAP``.
            depth_strength: Depth ControlNet scale, also used as its guidance end.
            lineart_strength: Lineart ControlNet scale, also used as its guidance end.
            seed: RNG seed; -1 or None draws a fresh random seed.

        Returns:
            The first image in the pipeline output's ``images``.
        """
        # 1. Pre-process inputs.
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size

        # 2. Face info (None when no face is detected).
        face_info = self.mh.get_face_info(processed_image)

        # 3. Prompt.
        final_prompt = self._build_prompt(processed_image, user_prompt)
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Structural control maps.
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Identity conditioning: InstantID when a face exists, otherwise
        #    neutral placeholders with zero scale.
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            # The raw embedding is used instead of normed_embedding on purpose:
            # its larger magnitude (~20-30) is what the adapter expects.
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE,
            ).unsqueeze(0)
            face_kps = draw_kps(processed_image, face_info['kps'])
            face_scale = face_strength
            ip_adapter_scale = self.FACE_IP_ADAPTER_SCALE
        else:
            print("No face detected: Disabling InstantID.")
            # Placeholders keep one control image and one embedding per
            # ControlNet. Their influence is disabled via zero scales below.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            face_scale = 0.0
            ip_adapter_scale = 0.0
        self.mh.pipeline.set_ip_adapter_scale(ip_adapter_scale)
        controlnet_conditioning_scale = [face_scale, depth_strength, lineart_strength]

        control_guidance_end = self._control_guidance_end(
            face_strength, depth_strength, lineart_strength
        )

        # Seed / RNG.
        seed = self._resolve_seed(seed)
        generator = torch.Generator(device=Config.DEVICE).manual_seed(seed)
        print(f"Using seed: {seed}")

        # 6. Inference.
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,  # Base img2img image
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,  # Face identity embedding
            generator=generator,
            strength=img2img_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=Config.CLIP_SKIP,
        ).images[0]
        return result

    def _build_prompt(self, processed_image, user_prompt):
        """Return the style trigger followed by the user prompt or an auto caption."""
        # `or ""` tolerates a None prompt coming from the UI.
        if not (user_prompt or "").strip():
            try:
                generated_caption = get_caption(processed_image)
                return f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                # Captioning is best-effort: log and fall back to a generic prompt.
                print(f"Captioning failed: {e}, using default prompt.")
                return f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"
        return f"{Config.STYLE_TRIGGER}, {user_prompt}"

    @classmethod
    def _control_guidance_end(cls, face_strength, depth_strength, lineart_strength):
        """Return per-ControlNet guidance-end fractions clamped to a valid range.

        The InstantID end is first capped at ``FACE_GUIDANCE_END_CAP``. Every
        value is then clamped into ``[MIN_GUIDANCE_END, 1.0]``.

        NOTE(review): depth/lineart reuse their conditioning strengths as end
        fractions (e.g. 0.3 stops them after 30% of the steps) - confirm intended.
        """
        def clamp(value):
            return min(1.0, max(cls.MIN_GUIDANCE_END, float(value)))

        return [
            clamp(min(cls.FACE_GUIDANCE_END_CAP, face_strength)),  # InstantID: stop early for style
            clamp(depth_strength),  # Depth
            clamp(lineart_strength),  # Lineart
        ]

    @staticmethod
    def _resolve_seed(seed):
        """Return ``seed`` as an int, drawing a fresh random seed for -1 or None."""
        if seed is None or seed == -1:
            return torch.Generator().seed()
        return int(seed)