Spaces: Running on Zero
Update model.py
model.py CHANGED
@@ -65,28 +65,19 @@ class ModelHandler:

         # 2. Load ControlNets
         print("Loading ControlNets (InstantID, Zoe, LineArt)...")
-
-        # Load the InstantID ControlNet from the correct subfolder
-        print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
         cn_instantid = ControlNetModel.from_pretrained(
             Config.INSTANTID_REPO,
             subfolder="ControlNetModel",
             torch_dtype=Config.DTYPE
         )
-        print(" [OK] Loaded InstantID ControlNet.")
-
-        # Load other ControlNets normally
-        print("Loading Zoe and LineArt ControlNets...")
         cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
         cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)

-        # --- Manually wrap the list of models in a MultiControlNetModel ---
         print("Wrapping ControlNets in MultiControlNetModel...")
         controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
         controlnet = MultiControlNetModel(controlnet_list)
-        # --- End wrapping ---

-        # 3. Load SDXL Pipeline
+        # 3. Load SDXL Pipeline (Now from 'reality.safetensors')
         print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")

         checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
@@ -109,21 +100,17 @@ class ModelHandler:

         self.pipeline.to(Config.DEVICE)

-        # Enable xFormers
         try:
             self.pipeline.enable_xformers_memory_efficient_attention()
             print(" [OK] xFormers memory efficient attention enabled.")
         except Exception as e:
             print(f" [WARNING] Failed to enable xFormers: {e}")

-        # 4. Set TCD Scheduler
+        # 4. Set TCD Scheduler (Sanitized Config)
         print("Configuring TCDScheduler...")

-        #
-        # Convert FrozenDict to a mutable standard Python dict
+        # Force standard SDXL config to prevent noise artifacts
         tcd_config = dict(self.pipeline.scheduler.config)
-
-        # Now we can update it safely
         tcd_config.update({
             "beta_start": 0.00085,
             "beta_end": 0.012,
@@ -137,34 +124,28 @@ class ModelHandler:
             use_karras_sigmas=True,
             timestep_spacing="trailing"
         )
-        # --- FIX ENDS HERE ---
-
         print(" [OK] TCDScheduler loaded (Forced SDXL Defaults + Karras + Trailing).")

-        # 5. Load Adapters
+        # 5. Load Adapters
         print("Loading Adapters...")

-        # 5a. IP-Adapter
+        # 5a. IP-Adapter (for InstantID)
         ip_adapter_filename = "ip-adapter.bin"
         ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
-
         if not os.path.exists(ip_adapter_local_path):
-            print(f"Downloading IP-Adapter to {ip_adapter_local_path}...")
             hf_hub_download(
                 repo_id=Config.INSTANTID_REPO,
                 filename=ip_adapter_filename,
                 local_dir="./models",
                 local_dir_use_symlinks=False
             )
-
-        print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
         self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
+        print(" [OK] IP-Adapter loaded.")

-        # 5b.
+        # 5b. TCD LoRA (for speed)
         print("Loading TCD-SDXL-LoRA...")
         tcd_lora_filename = "pytorch_lora_weights.safetensors"
         tcd_lora_path = os.path.join("./models", tcd_lora_filename)
-
         if not os.path.exists(tcd_lora_path):
             hf_hub_download(
                 repo_id="h1t/TCD-SDXL-LoRA",
@@ -172,19 +153,28 @@ class ModelHandler:
                 local_dir="./models",
                 local_dir_use_symlinks=False
             )
-        self.pipeline.load_lora_weights("./models", weight_name=tcd_lora_filename)
-
-        print(" [OK] TCD LoRA fused.")
+        self.pipeline.load_lora_weights("./models", weight_name=tcd_lora_filename, adapter_name="tcd")
+        print(" [OK] TCD LoRA loaded.")

-        # 5c.
-        print("Loading Style LoRA
-
-
-
-
-
+        # 5c. Style LoRA (lucasart)
+        print(f"Loading Style LoRA ({Config.LORA_FILENAME})...")
+        style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
+        if not os.path.exists(style_lora_path):
+            hf_hub_download(
+                repo_id=Config.REPO_ID,
+                filename=Config.LORA_FILENAME,
+                local_dir="./models",
+                local_dir_use_symlinks=False
+            )
+        self.pipeline.load_lora_weights("./models", weight_name=Config.LORA_FILENAME, adapter_name="style")
+        print(" [OK] Style LoRA loaded.")

-        # 6.
+        # 6. Set Adapter Weights (TCD + Style)
+        # We set both adapters to run simultaneously
+        print(f"Setting adapter weights: TCD (1.0), Style ({Config.LORA_STRENGTH})")
+        self.pipeline.set_adapters(["tcd", "style"], adapter_weights=[1.0, Config.LORA_STRENGTH])
+
+        # 7. Load Preprocessors
         print("Loading Preprocessors (LeReS, LineArtAnime)...")
         self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
         self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
@@ -195,18 +185,12 @@ class ModelHandler:
         """Extracts the largest face, returns insightface result object."""
         if not self.face_analysis_loaded:
             return None
-
         try:
             cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
             faces = self.app.get(cv2_img)
-
             if len(faces) == 0:
                 return None
-
-            # Sort by size (width * height) to find the main character
             faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
-
-            # Return the largest face info
             return faces[0]
         except Exception as e:
             print(f"Face embedding extraction failed: {e}")
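The hunks above reference a Config object that is defined elsewhere in model.py and never appears in this diff. For orientation, here is a minimal sketch of the fields the visible code would need; every value below is a placeholder or an assumption, and repo ids or filenames the diff never shows are left elided:

    import torch

    class Config:
        # Placeholder values; only the field names are implied by the diff.
        DEVICE = "cuda"
        DTYPE = torch.float16
        INSTANTID_REPO = "InstantX/InstantID"        # assumed: ships ControlNetModel/ and ip-adapter.bin
        CN_ZOE_REPO = "..."                          # Zoe depth ControlNet repo id (not shown)
        CN_LINEART_REPO = "..."                      # LineArt ControlNet repo id (not shown)
        ANNOTATOR_REPO = "lllyasviel/Annotators"     # assumed: usual home of the LeReS/LineArtAnime weights
        CHECKPOINT_FILENAME = "reality.safetensors"  # named in the new '# 3' comment
        REPO_ID = "..."                              # repo hosting the style LoRA (not shown)
        LORA_FILENAME = "..."                        # style LoRA weight file (not shown)
        LORA_STRENGTH = 0.8                          # placeholder adapter weight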
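The first hunk wraps the three ControlNets in a MultiControlNetModel by hand instead of passing a bare list to the pipeline constructor. With the wrapper in place, the pipeline expects per-net lists at call time, ordered exactly like controlnet_list. A hypothetical call shape following the standard diffusers multi-ControlNet convention; the variable names and scale values are illustrative, not taken from this Space:

    # Control images must follow the order of controlnet_list:
    # [InstantID face keypoints, Zoe depth map, LineArt edges]
    result = pipeline(
        prompt="a portrait",                             # placeholder prompt
        image=[kps_image, depth_image, lineart_image],   # one control image per ControlNet
        controlnet_conditioning_scale=[0.8, 0.5, 0.5],   # one scale per ControlNet (placeholders)
    )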
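The step-4 hunks copy the scheduler's config into a plain dict before updating it, because diffusers exposes scheduler.config as a FrozenDict that rejects in-place mutation. The actual scheduler swap falls in lines this diff does not show; below is a standalone sketch of the full pattern, where the final assignment is assumed from the usual diffusers idiom rather than taken from this commit:

    from diffusers import TCDScheduler

    # scheduler.config is a FrozenDict; copy it into a regular dict so it can be updated
    tcd_config = dict(pipeline.scheduler.config)
    tcd_config.update({
        "beta_start": 0.00085,  # standard SDXL beta schedule
        "beta_end": 0.012,
    })
    # Rebuild the scheduler; from_config also accepts keyword overrides,
    # matching the use_karras_sigmas / timestep_spacing arguments visible in the hunk
    pipeline.scheduler = TCDScheduler.from_config(
        tcd_config,
        use_karras_sigmas=True,
        timestep_spacing="trailing",
    )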
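Steps 5b through 6 replace the old fuse-style LoRA loading with the diffusers PEFT adapter API: each LoRA is registered under an explicit name ("tcd", "style") and both are activated together with per-adapter weights. One practical benefit is that the mix can be changed later without reloading any weights; the weights below are placeholders:

    # Rebalance: keep TCD at full strength, halve the style influence
    pipeline.set_adapters(["tcd", "style"], adapter_weights=[1.0, 0.4])

    # Or run with the sampling LoRA alone
    pipeline.set_adapters(["tcd"], adapter_weights=[1.0])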
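The get_face hunk at the bottom relies on self.app and self.face_analysis_loaded, both set up outside the lines shown. A plausible initializer for the loader method, assuming the insightface FaceAnalysis app that InstantID is usually paired with; the model pack name and detection size are assumptions:

    from insightface.app import FaceAnalysis

    # Assumed setup for self.app; InstantID examples commonly use the antelopev2 pack
    self.app = FaceAnalysis(
        name="antelopev2",
        root="./",
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    self.app.prepare(ctx_id=0, det_size=(640, 640))
    self.face_analysis_loaded = True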