face-to-pixel-art / generator.py
primerz's picture
Update generator.py
0d4d25b verified
import torch
from config import Config
from utils import resize_image_to_1mp, get_caption, draw_kps
from PIL import Image
class Generator:
    """Drive the stylisation pipeline held by a model handler.

    The model handler is expected to expose ``leres_detector``,
    ``lineart_anime_detector``, ``get_face_info`` and ``pipeline``; those are
    the only attributes this class touches.
    """

    # The face ControlNet is never kept active past this fraction of the
    # denoising steps, so the remaining steps are left to the style LoRA.
    _FACE_GUIDANCE_END_CAP = 0.6

    # Lower bound for any guidance-end fraction.
    # NOTE(review): diffusers ControlNet pipelines reject an end fraction that
    # is <= the start fraction (0.0 by default) or > 1.0 -- confirm that the
    # pipeline loaded by the model handler performs the same validation.
    _MIN_GUIDANCE_END = 0.01

    def __init__(self, model_handler):
        """Store the model handler that owns the detectors and the pipeline."""
        self.mh = model_handler

    @classmethod
    def _guidance_end_schedule(cls, face_strength, depth_strength, lineart_strength):
        """Return ``control_guidance_end`` fractions for [face, depth, lineart].

        Each fraction follows its strength value, clamped into
        ``[_MIN_GUIDANCE_END, 1.0]``; the face entry is additionally capped at
        ``_FACE_GUIDANCE_END_CAP``. Clamping keeps a strength of 0 (or a
        strength above 1) from producing an out-of-range fraction. A
        ControlNet whose strength is 0 is still ineffective because its
        conditioning scale is 0.
        """

        def clamp(value, upper=1.0):
            return max(cls._MIN_GUIDANCE_END, min(upper, float(value)))

        return [
            clamp(face_strength, cls._FACE_GUIDANCE_END_CAP),  # InstantID: stop early for style
            clamp(depth_strength),  # Depth: follows its strength
            clamp(lineart_strength),  # Lineart: follows its strength
        ]

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).

        Args:
            image: Input image handed to both detectors.
            width: Target width in pixels.
            height: Target height in pixels.

        Returns:
            A ``(depth_map, lineart_map)`` tuple, both resized with LANCZOS
            resampling to ``(width, height)``.
        """
        print(f"Generating control maps for {width}x{height}...")

        # Raw detector outputs; their size is not assumed to match the target.
        depth_map_raw = self.mh.leres_detector(image)
        lineart_map_raw = self.mh.lineart_anime_detector(image)

        # Manually resize maps to match the exact output resolution.
        target_size = (width, height)
        depth_map = depth_map_raw.resize(target_size, Image.LANCZOS)
        lineart_map = lineart_map_raw.resize(target_size, Image.LANCZOS)
        return depth_map, lineart_map

    def predict(
        self,
        input_image,
        user_prompt="",
        negative_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        face_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        """Stylise ``input_image`` and return the first generated image.

        Args:
            input_image: Source image; it is resized by ``resize_image_to_1mp``
                before any further processing.
            user_prompt: Optional prompt. When empty, whitespace-only or
                ``None``, a caption is generated from the image instead
                (falling back to a fixed default prompt if captioning fails).
            negative_prompt: Forwarded to the pipeline unchanged.
            guidance_scale: Forwarded to the pipeline unchanged.
            num_inference_steps: Forwarded to the pipeline unchanged.
            img2img_strength: Forwarded to the pipeline as ``strength``.
            face_strength: Conditioning scale of the face-keypoint ControlNet
                (forced to 0.0 when no face is found).
            depth_strength: Conditioning scale of the depth ControlNet.
            lineart_strength: Conditioning scale of the lineart ControlNet.
            seed: Seed for the torch generator; ``-1`` or ``None`` draws a
                fresh random seed.

        Returns:
            ``images[0]`` of the pipeline result.
        """
        # 1. Pre-process Inputs
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size

        # 2. Get Face Info (None means no face was found)
        face_info = self.mh.get_face_info(processed_image)

        # 3. Generate Prompt (a None prompt is treated like an empty one)
        if not (user_prompt or "").strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                # Deliberate best-effort: captioning problems must not abort generation.
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate OTHER Control Maps (Structure)
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(
            processed_image, target_width, target_height
        )

        # 5. Logic for Face vs No-Face (includes keypoints)
        # ControlNet order: [InstantID_KPS, Depth (leres_detector), LineArt]
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            # We use face_info['embedding'] (raw) instead of normed_embedding.
            # Raw embedding has higher magnitude (~20-30) required for the adapter.
            face_emb = torch.tensor(
                face_info['embedding'],
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)
            # Create keypoint image
            face_kps = draw_kps(processed_image, face_info['kps'])
            controlnet_conditioning_scale = [face_strength, depth_strength, lineart_strength]
            # Reduced IP Adapter scale: lowered from 0.8 to 0.7 to allow the
            # LoRA style (pixel art) to override realistic skin textures
            # while keeping identity.
            self.mh.pipeline.set_ip_adapter_scale(0.7)
        else:
            print("No face detected: Disabling InstantID.")
            # Dummy embedding and an all-black keypoint image keep the
            # pipeline inputs well-formed while contributing nothing.
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength]
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # Control Guidance End Strategy: the face ControlNet duration is
        # capped so the final steps are left to the Pixel Art LoRA to
        # "pixelize" the face without the ControlNet pulling it back towards
        # a photo. All fractions are clamped into a valid range.
        control_guidance_end = self._guidance_end_schedule(
            face_strength, depth_strength, lineart_strength
        )

        # --- Seed/Generator Logic ---
        if seed is None or seed == -1:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,  # Base img2img image
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,  # Face identity embedding
            generator=generator,
            # --- Parameters from UI ---
            strength=img2img_strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            # --- End Parameters from UI ---
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            clip_skip=Config.CLIP_SKIP,
        ).images[0]
        return result