File size: 7,974 Bytes
2dddf31
6d5987b
 
 
 
 
2dddf31
 
3885620
 
2dddf31
cc0ae1f
61380fb
acd970e
 
 
a92aea4
2dddf31
ea17e03
2dddf31
 
 
 
ea17e03
4a72459
 
6d5987b
 
 
 
 
948869c
6d5987b
 
2dddf31
948869c
 
 
 
6d5987b
 
e01d167
ea17e03
6d5987b
 
 
 
 
 
 
948869c
 
3ae9ca7
6d5987b
 
 
 
 
 
948869c
6d5987b
 
ea17e03
6d5987b
 
2dddf31
ea17e03
 
3885620
 
 
aecb45b
65a7aea
 
2dddf31
 
3885620
 
 
 
ea17e03
 
8333ca9
3885620
61380fb
ea17e03
61380fb
3885620
76e0564
3885620
6d5987b
acd970e
4089031
 
 
 
 
 
 
94327de
4089031
 
 
acd970e
65a7aea
 
2dddf31
6d5987b
e01d167
4089031
e01d167
2dddf31
16dc50a
 
0117fa7
16dc50a
 
2e66212
3885620
 
40e76ba
66cd9e8
 
 
 
45890f7
 
66cd9e8
 
 
3885620
 
 
5e35e8b
 
3620e60
5e35e8b
3885620
fa327ca
5e35e8b
 
fa327ca
 
 
3885620
 
3620e60
5e35e8b
ea17e03
3620e60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3885620
3620e60
3885620
 
ea17e03
4a72459
ea17e03
2dddf31
6d5987b
 
e4dd0ff
 
6d5987b
 
3885620
6d5987b
76e0564
6d5987b
3885620
6d5987b
 
3885620
 
ea17e03
3885620
 
e4dd0ff
6d5987b
 
ea17e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import torch
import os
import cv2
import numpy as np
from config import Config

from diffusers import (
    ControlNetModel, 
    LCMScheduler,
    # AutoencoderKL # Removed as requested
)
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel

# Import the custom pipeline from your local file
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline

from huggingface_hub import snapshot_download, hf_hub_download
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector, LineartAnimeDetector

class ModelHandler:
    """Load and hold the models used for InstantID image-to-image generation.

    All attributes start out empty and are filled in by :meth:`load_models`.
    Repository ids, file names, dtype and device are read from ``Config``.
    """

    # Directory that single-file downloads (checkpoint, IP-Adapter) are stored in.
    _MODELS_DIR = "./models"

    def __init__(self):
        self.pipeline = None  # diffusion pipeline, set by load_models()
        self.app = None  # InsightFace FaceAnalysis, set by load_face_analysis()
        self.leres_detector = None
        self.lineart_anime_detector = None
        # Set by load_models() from the result of load_face_analysis();
        # get_face_info() returns None while this is False.
        self.face_analysis_loaded = False

    def load_face_analysis(self):
        """Load the InsightFace analysis model, downloading it first if needed.

        The model files are expected under
        ``<Config.ANTELOPEV2_ROOT>/models/<Config.ANTELOPEV2_NAME>``. When the
        detector file ``scrfd_10g_bnkps.onnx`` is missing there, the whole
        ``Config.ANTELOPEV2_REPO`` snapshot is downloaded into that directory.

        Returns:
            bool: ``True`` when ``self.app`` was created and prepared,
            ``False`` when the download or the initialisation failed. Errors
            are printed rather than raised so the caller can continue without
            face analysis.
        """
        print("Loading face analysis model...")

        model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
        detector_file = os.path.join(model_path, "scrfd_10g_bnkps.onnx")

        if not os.path.exists(detector_file):
            print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
            try:
                snapshot_download(
                    repo_id=Config.ANTELOPEV2_REPO,
                    local_dir=model_path,  # the directory FaceAnalysis is pointed at below
                )
            except Exception as e:
                print(f"  [ERROR] Failed to download AntelopeV2 models: {e}")
                return False

        try:
            self.app = FaceAnalysis(
                name=Config.ANTELOPEV2_NAME,
                root=Config.ANTELOPEV2_ROOT,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
            )
            self.app.prepare(ctx_id=0, det_size=(640, 640))
        except Exception as e:
            # Best-effort: a missing face model must not abort the rest of the load.
            print(f"  [WARNING] Face detection system failed to initialize: {e}")
            return False

        print("  [OK] Face analysis model loaded successfully.")
        return True

    def load_models(self):
        """Load every model in order and store the results on ``self``.

        Steps: face analysis, ControlNets, SDXL pipeline, scheduler, adapters
        (IP-Adapter and LCM LoRA), preprocessors. A face-analysis failure is
        recorded in ``self.face_analysis_loaded`` and does not stop the load;
        exceptions from the other steps propagate to the caller.
        """
        # 1. Load Face Analysis
        self.face_analysis_loaded = self.load_face_analysis()

        # 2. Load ControlNets
        controlnet = self._load_controlnets()

        # 3. Load SDXL Pipeline
        self._load_pipeline(controlnet)

        # 4. Configure the scheduler
        self._configure_scheduler()

        # 5. Load Adapters (IP-Adapter & LoRA)
        self._load_adapters()

        # 6. Load Preprocessors
        self._load_preprocessors()

        print("--- All models loaded successfully ---")

    def _ensure_local_file(self, repo_id, filename, label):
        """Return the path of ``filename`` in the models dir, downloading it if absent."""
        local_path = os.path.join(self._MODELS_DIR, filename)
        if not os.path.exists(local_path):
            print(f"Downloading {label} to {local_path}...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=self._MODELS_DIR,
                local_dir_use_symlinks=False,
            )
        return local_path

    def _load_controlnets(self):
        """Load the three ControlNets and return them wrapped in a MultiControlNetModel."""
        print("Loading ControlNets (InstantID, Zoe, LineArt)...")

        # The InstantID ControlNet lives in a subfolder of its repository.
        print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
        cn_instantid = ControlNetModel.from_pretrained(
            Config.INSTANTID_REPO,
            subfolder="ControlNetModel",
            torch_dtype=Config.DTYPE,
        )
        print("  [OK] Loaded InstantID ControlNet.")

        print("Loading Zoe and LineArt ControlNets...")
        cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
        cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)

        # The list order (InstantID, Zoe, LineArt) is fixed here.
        print("Wrapping ControlNets in MultiControlNetModel...")
        return MultiControlNetModel([cn_instantid, cn_zoe, cn_lineart])

    def _load_pipeline(self, controlnet):
        """Build the SDXL pipeline from the local checkpoint and move it to the device."""
        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")

        checkpoint_local_path = self._ensure_local_file(
            Config.REPO_ID, Config.CHECKPOINT_FILENAME, "checkpoint"
        )

        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            checkpoint_local_path,
            controlnet=controlnet,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )

        self.pipeline.to(Config.DEVICE)

        # xFormers is optional; keep going with default attention when unavailable.
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
            print("  [OK] xFormers memory efficient attention enabled.")
        except Exception as e:
            print(f"  [WARNING] Failed to enable xFormers: {e}")

    def _configure_scheduler(self):
        """Replace the pipeline's scheduler with an LCMScheduler derived from its config."""
        print("Configuring LCMScheduler...")

        # Overrides are passed as keyword arguments instead of being written
        # into the current scheduler's config object, which diffusers exposes
        # as a FrozenDict and which should not be mutated in place.
        self.pipeline.scheduler = LCMScheduler.from_config(
            self.pipeline.scheduler.config,
            clip_sample=False,
            timestep_spacing="trailing",
            beta_schedule="scaled_linear",
        )
        print("  [OK] LCMScheduler loaded (clip_sample=False, trailing spacing).")

    def _load_adapters(self):
        """Load the InstantID IP-Adapter, then fuse the LCM LoRA into the pipeline."""
        print("Loading Adapters (IP-Adapter & LoRA)...")

        ip_adapter_local_path = self._ensure_local_file(
            Config.INSTANTID_REPO, "ip-adapter.bin", "IP-Adapter"
        )

        print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
        # The InstantID adapter is loaded before the LoRA.
        self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)

        print("Loading LCM LoRA weights...")
        # A dedicated adapter name lets the fuse call below target only this LoRA.
        self.pipeline.load_lora_weights(
            Config.REPO_ID,
            weight_name=Config.LORA_FILENAME,
            adapter_name="lcm_lora",
        )

        # The LCM LoRA is always fused at full strength.
        fuse_scale = 1.0

        print(f"Fusing LoRA 'lcm_lora' with scale {fuse_scale}...")
        self.pipeline.fuse_lora(
            adapter_names=["lcm_lora"],
            lora_scale=fuse_scale,
        )

        # After fusing, the separately held LoRA weights are no longer needed.
        self.pipeline.unload_lora_weights()

        print("  [OK] LoRA fused and cleaned up.")

    def _load_preprocessors(self):
        """Load the LeReS and LineArt-anime detectors from the annotator repo."""
        print("Loading Preprocessors (LeReS, LineArtAnime)...")
        self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
        self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)

    @staticmethod
    def _bbox_area(face):
        """Return width * height of ``face['bbox']``, read as (x1, y1, x2, y2)."""
        x1, y1, x2, y2 = face['bbox'][0], face['bbox'][1], face['bbox'][2], face['bbox'][3]
        return (x2 - x1) * (y2 - y1)

    def get_face_info(self, image):
        """Return the InsightFace result for the largest face in ``image``.

        Args:
            image: Anything ``np.array`` can convert; the pixel data is
                treated as RGB and converted to BGR before detection.

        Returns:
            The detected face with the largest bounding-box area, or ``None``
            when face analysis is not loaded, no face is found, or detection
            raises (the error is printed).
        """
        # Without a prepared analyser there is nothing to run.
        if not self.face_analysis_loaded or self.app is None:
            return None

        try:
            cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.app.get(cv2_img)

            if not faces:
                return None

            # The biggest bounding box is taken to be the main subject; on a
            # tie the first such face wins, as with a stable descending sort.
            return max(faces, key=self._bbox_area)
        except Exception as e:
            print(f"Face embedding extraction failed: {e}")
            return None