import torch
from config import Config
from utils import resize_image_to_1mp, get_caption, draw_kps
from PIL import Image

class Generator:
    def __init__(self, model_handler):
        self.mh = model_handler

    def prepare_control_images(self, image, width, height):
        """
        Generates conditioning maps, ensuring they are resized
        to the exact target dimensions (width, height).
        """
        print(f"Generating control maps for {width}x{height}...")
        
        # Generate depth map
        depth_map_raw = self.mh.leres_detector(image) 
        
        # Generate lineart map
        lineart_map_raw = self.mh.lineart_anime_detector(image)
        
        # Manually resize maps to match the exact output resolution
        depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
        lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
        
        return depth_map, lineart_map

    def predict(
        self, 
        input_image, 
        user_prompt="",
        negative_prompt="",
        guidance_scale=1.5,
        num_inference_steps=6,
        img2img_strength=0.3,
        face_strength=0.3,
        depth_strength=0.3,
        lineart_strength=0.3,
        seed=-1
    ):
        # 1. Pre-process Inputs
        print("Processing Input...")
        processed_image = resize_image_to_1mp(input_image)
        target_width, target_height = processed_image.size
        
        # 2. Get Face Info (replaces get_face_embedding)
        face_info = self.mh.get_face_info(processed_image)
        
        # 3. Generate Prompt
        if not user_prompt.strip():
            try:
                generated_caption = get_caption(processed_image)
                final_prompt = f"{Config.STYLE_TRIGGER}, {generated_caption}"
            except Exception as e:
                print(f"Captioning failed: {e}, using default prompt.")
                final_prompt = f"{Config.STYLE_TRIGGER}, a beautiful pixel art image"
        else:
            final_prompt = f"{Config.STYLE_TRIGGER}, {user_prompt}"
            
        print(f"Prompt: {final_prompt}")
        print(f"Negative Prompt: {negative_prompt}")

        # 4. Generate OTHER Control Maps (Structure)
        print("Generating Control Maps (Depth, LineArt)...")
        depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
        
        # 5. Logic for Face vs No-Face (NOW INCLUDES KPS)
        # ControlNet order: [InstantID_KPS, Depth (maps from LeReS), LineArt]
        
        if face_info is not None:
            print("Face detected: Applying InstantID with keypoints.")
            
            # Use face_info['embedding'] (raw) rather than normed_embedding:
            # the raw embedding's higher magnitude (~20-30) is what the
            # adapter requires.
            face_emb = torch.tensor(
                face_info['embedding'], 
                dtype=Config.DTYPE,
                device=Config.DEVICE
            ).unsqueeze(0)
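            # Resulting shape: (1, 512), matching the dummy embedding used
            # in the no-face branch below.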

            # Create keypoint image
            face_kps = draw_kps(processed_image, face_info['kps'])
            
            # Set strengths
            controlnet_conditioning_scale = [face_strength, depth_strength, lineart_strength] 
            
            # --- UPDATED: Reduced IP Adapter Scale ---
            # Lowered from 0.8 to 0.7 to allow LoRA style (pixel art) to 
            # override realistic skin textures while keeping identity.
            self.mh.pipeline.set_ip_adapter_scale(0.7)
        else:
            print("No face detected: Disabling InstantID.")
            # Create dummy embedding
            face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
            # Create dummy keypoint image (black)
            face_kps = Image.new('RGB', (target_width, target_height), (0, 0, 0))
            
            # Set strengths
            controlnet_conditioning_scale = [0.0, depth_strength, lineart_strength] 
            self.mh.pipeline.set_ip_adapter_scale(0.0)

        # --- UPDATED: Control Guidance End Strategy ---
        # Cap the Face ControlNet's duration: even at a strength of 1.0,
        # it stops at 0.6 (60%) of the steps. This leaves the final 40% of
        # the steps free for the Pixel Art LoRA to "pixelize" the face,
        # without the ControlNet pulling it back toward a photo.
        
        face_end_step = min(0.6, face_strength)
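        # e.g. face_strength=1.0 -> face ControlNet stops at 60% of the steps;
        # face_strength=0.3 (the default) -> it stops at 30%.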
        
        control_guidance_end = [
            face_end_step,      # InstantID: Stop early for style
            depth_strength,     # Depth: Keep structure longer
            lineart_strength    # Lineart: Keep outlines longer
        ] 

        # --- Seed/Generator Logic ---
        if seed == -1 or seed is None:
            seed = torch.Generator().seed()
        generator = torch.Generator(device=Config.DEVICE).manual_seed(int(seed))
        print(f"Using seed: {seed}")
        # --- END ---

        # 6. Run Inference
        print("Running pipeline...")
        result = self.mh.pipeline(
            prompt=final_prompt,
            negative_prompt=negative_prompt,
            image=processed_image,  # Base img2img image
            control_image=[face_kps, depth_map, lineart_map],
            image_embeds=face_emb,  # Face identity embedding
            generator=generator,
            
            # --- Parameters from UI ---
            strength=img2img_strength,
            num_inference_steps=num_inference_steps, 
            guidance_scale=guidance_scale,
            # --- End Parameters from UI ---
            
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            control_guidance_end=control_guidance_end,
            
            clip_skip=Config.CLIP_SKIP,
            
        ).images[0]
        
        return result
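
# --- Usage sketch ---
# A minimal example of driving Generator end to end. The ModelHandler
# import path and class name are assumptions (this file only receives the
# handler via __init__); adjust both to match the actual module.
if __name__ == "__main__":
    from model_handler import ModelHandler  # assumed module/class name

    gen = Generator(ModelHandler())
    source = Image.open("input.png").convert("RGB")
    output = gen.predict(
        source,
        user_prompt="a portrait of a knight",
        num_inference_steps=6,
        seed=42,
    )
    output.save("output.png")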