import torch
import os
import cv2
import numpy as np
from config import Config
from diffusers import (
    ControlNetModel,
    LCMScheduler,
    # AutoencoderKL  # Removed as requested
)
from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
# Import the custom pipeline from your local file
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline
from huggingface_hub import snapshot_download, hf_hub_download
from insightface.app import FaceAnalysis
from controlnet_aux import LeresDetector, LineartAnimeDetector
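
# Config (imported above) is expected to expose at least the attributes used below:
#   ANTELOPEV2_REPO / ANTELOPEV2_ROOT / ANTELOPEV2_NAME  -> insightface (AntelopeV2) model location
#   INSTANTID_REPO, CN_ZOE_REPO, CN_LINEART_REPO         -> ControlNet repositories
#   REPO_ID, CHECKPOINT_FILENAME, LORA_FILENAME          -> SDXL checkpoint and LCM LoRA
#   ANNOTATOR_REPO                                       -> controlnet_aux annotator weights
#   DEVICE, DTYPE                                        -> torch device and precision
# (This list is inferred from the calls in this file, not from config.py itself.)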


class ModelHandler:
    def __init__(self):
        self.pipeline = None
        self.app = None  # InsightFace
        self.leres_detector = None
        self.lineart_anime_detector = None
        self.face_analysis_loaded = False

    def load_face_analysis(self):
        """
        Load the insightface face analysis model.
        Downloads the AntelopeV2 weights from the HF Hub into the directory
        layout insightface expects (<root>/models/<name>/).
        """
        print("Loading face analysis model...")
        model_path = os.path.join(Config.ANTELOPEV2_ROOT, "models", Config.ANTELOPEV2_NAME)
        if not os.path.exists(os.path.join(model_path, "scrfd_10g_bnkps.onnx")):
            print(f"Downloading AntelopeV2 models from {Config.ANTELOPEV2_REPO} to {model_path}...")
            try:
                snapshot_download(
                    repo_id=Config.ANTELOPEV2_REPO,
                    local_dir=model_path,  # Download to the path FaceAnalysis expects
                )
            except Exception as e:
                print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
                return False
        try:
            self.app = FaceAnalysis(
                name=Config.ANTELOPEV2_NAME,
                root=Config.ANTELOPEV2_ROOT,
                providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
            )
            self.app.prepare(ctx_id=0, det_size=(640, 640))
            print(" [OK] Face analysis model loaded successfully.")
            return True
        except Exception as e:
            print(f" [WARNING] Face detection system failed to initialize: {e}")
            return False

    def load_models(self):
        # 1. Load Face Analysis
        self.face_analysis_loaded = self.load_face_analysis()

        # 2. Load ControlNets
        print("Loading ControlNets (InstantID, Zoe, LineArt)...")

        # Load the InstantID ControlNet from the correct subfolder
        print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
        cn_instantid = ControlNetModel.from_pretrained(
            Config.INSTANTID_REPO,
            subfolder="ControlNetModel",
            torch_dtype=Config.DTYPE
        )
        print(" [OK] Loaded InstantID ControlNet.")

        # Load other ControlNets normally
        print("Loading Zoe and LineArt ControlNets...")
        cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
        cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)

        # --- Manually wrap the list of models in a MultiControlNetModel ---
        print("Wrapping ControlNets in MultiControlNetModel...")
        controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
        controlnet = MultiControlNetModel(controlnet_list)
        # --- End wrapping ---
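        # NOTE: the ControlNet order fixed here ([InstantID, Zoe depth, LineArt])
        # must match the order of the conditioning images passed to the pipeline
        # at inference time.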

        # 3. Load SDXL Pipeline
        print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
        checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
        if not os.path.exists(checkpoint_local_path):
            print(f"Downloading checkpoint to {checkpoint_local_path}...")
            hf_hub_download(
                repo_id=Config.REPO_ID,
                filename=Config.CHECKPOINT_FILENAME,
                local_dir="./models",
                local_dir_use_symlinks=False
            )
        print(f"Loading pipeline from local file: {checkpoint_local_path}")
        self.pipeline = StableDiffusionXLInstantIDImg2ImgPipeline.from_single_file(
            checkpoint_local_path,
            controlnet=controlnet,
            torch_dtype=Config.DTYPE,
            use_safetensors=True
        )
        self.pipeline.to(Config.DEVICE)

        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
            print(" [OK] xFormers memory efficient attention enabled.")
        except Exception as e:
            print(f" [WARNING] Failed to enable xFormers: {e}")
print("Configuring LCMScheduler...")
scheduler_config = self.pipeline.scheduler.config
scheduler_config['clip_sample'] = False
# --- MODIFIED: optimize for sharp pixel art style ---
self.pipeline.scheduler = LCMScheduler.from_config(
scheduler_config,
timestep_spacing="trailing",
beta_schedule="scaled_linear"
)
print(" [OK] LCMScheduler loaded (clip_sample=False, trailing spacing).")

        # 5. Load Adapters (IP-Adapter & LoRA)
        print("Loading Adapters (IP-Adapter & LoRA)...")
        ip_adapter_filename = "ip-adapter.bin"
        ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
        if not os.path.exists(ip_adapter_local_path):
            print(f"Downloading IP-Adapter to {ip_adapter_local_path}...")
            hf_hub_download(
                repo_id=Config.INSTANTID_REPO,
                filename=ip_adapter_filename,
                local_dir="./models",
                local_dir_use_symlinks=False
            )
        print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
        # Load InstantID adapter first
        self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
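        # load_ip_adapter_instantid() is provided by the custom InstantID pipeline
        # imported above rather than by stock diffusers; it loads the face-embedding
        # projection and the identity IP-Adapter attention weights into the UNet.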
print("Loading LCM LoRA weights...")
# KEY CHANGE 1: Assign an adapter_name so Diffusers distinguishes it from InstantID
self.pipeline.load_lora_weights(
Config.REPO_ID,
weight_name=Config.LORA_FILENAME,
adapter_name="lcm_lora"
)
# KEY CHANGE 2: Hardcode scale to 1.0 for LCM to remove trigger word dependency
# (Or ensure Config.LORA_STRENGTH is set to 1.0)
fuse_scale = 1.0
print(f"Fusing LoRA 'lcm_lora' with scale {fuse_scale}...")
# KEY CHANGE 3: Fuse ONLY the named adapter
self.pipeline.fuse_lora(
adapter_names=["lcm_lora"],
lora_scale=fuse_scale
)
# KEY CHANGE 4: Unload the side-car weights to free VRAM (since they are now inside the UNet)
self.pipeline.unload_lora_weights()
print(" [OK] LoRA fused and cleaned up.")

        # 6. Load Preprocessors
        print("Loading Preprocessors (LeReS, LineArtAnime)...")
        # LeReS estimates a depth map and LineartAnimeDetector extracts anime-style
        # line art; these become the conditioning images for the depth and lineart
        # ControlNets loaded above.
        self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
        self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
        print("--- All models loaded successfully ---")

    def get_face_info(self, image):
        """Extracts the largest face; returns the insightface result object, or None."""
        if not self.face_analysis_loaded:
            return None
        try:
            cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            faces = self.app.get(cv2_img)
            if len(faces) == 0:
                return None
            # Sort by bounding-box area (width * height) to find the main character
            faces = sorted(faces, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]), reverse=True)
            # Return the largest face info
            return faces[0]
        except Exception as e:
            print(f"Face embedding extraction failed: {e}")
            return None
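
# --- Usage sketch (hypothetical; the real entrypoint lives in the Space's app code) ---
# handler = ModelHandler()
# handler.load_models()
# face_info = handler.get_face_info(pil_image)   # insightface Face object or None
# if face_info is not None:
#     embedding = face_info['embedding']         # identity embedding used by InstantID
#     kps = face_info['kps']                     # facial keypoints for the ID ControlNet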