primerz committed on
Commit
ea17e03
·
verified ·
1 Parent(s): 8d98c0c

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +28 -92
model.py CHANGED
@@ -6,7 +6,7 @@ from config import Config
6
 
7
  from diffusers import (
8
  ControlNetModel,
9
- TCDScheduler,
10
  )
11
  from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
12
 
@@ -15,17 +15,15 @@ from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInst
15
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
- from controlnet_aux import LeresDetector, LineartAnimeDetector, CannyDetector
19
 
20
  class ModelHandler:
21
  def __init__(self):
22
  self.pipeline = None
23
- self.app = None # InsightFace
24
  self.leres_detector = None
25
  self.lineart_anime_detector = None
26
- self.canny_detector = None
27
  self.face_analysis_loaded = False
28
- self.edge_type = Config.DEFAULT_EDGE_TYPE
29
 
30
  def load_face_analysis(self):
31
  """
@@ -41,7 +39,7 @@ class ModelHandler:
41
  try:
42
  snapshot_download(
43
  repo_id=Config.ANTELOPEV2_REPO,
44
- local_dir=model_path,
45
  )
46
  except Exception as e:
47
  print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
@@ -61,65 +59,25 @@ class ModelHandler:
61
  print(f" [WARNING] Face detection system failed to initialize: {e}")
62
  return False
63
 
64
- def load_models(self, edge_type="canny"):
65
- """
66
- Load all models with support for different edge detection types.
67
-
68
- Args:
69
- edge_type: "canny", "lineart", or "both"
70
- """
71
- self.edge_type = edge_type
72
-
73
  # 1. Load Face Analysis
74
  self.face_analysis_loaded = self.load_face_analysis()
75
 
76
- # 2. Load ControlNets based on edge_type
77
- print(f"Loading ControlNets (InstantID, Zoe, {edge_type.upper()})...")
78
  cn_instantid = ControlNetModel.from_pretrained(
79
  Config.INSTANTID_REPO,
80
  subfolder="ControlNetModel",
81
  torch_dtype=Config.DTYPE
82
  )
83
- cn_zoe = ControlNetModel.from_pretrained(
84
- Config.CN_ZOE_REPO,
85
- torch_dtype=Config.DTYPE
86
- )
87
-
88
- # Load edge ControlNet(s)
89
- controlnet_list = [cn_instantid, cn_zoe]
90
-
91
- if edge_type == "canny":
92
- cn_canny = ControlNetModel.from_pretrained(
93
- Config.CN_CANNY_REPO,
94
- torch_dtype=Config.DTYPE
95
- )
96
- controlnet_list.append(cn_canny)
97
- print(" [OK] Loaded Canny ControlNet")
98
-
99
- elif edge_type == "lineart":
100
- cn_lineart = ControlNetModel.from_pretrained(
101
- Config.CN_LINEART_REPO,
102
- torch_dtype=Config.DTYPE
103
- )
104
- controlnet_list.append(cn_lineart)
105
- print(" [OK] Loaded LineArt ControlNet")
106
-
107
- elif edge_type == "both":
108
- cn_canny = ControlNetModel.from_pretrained(
109
- Config.CN_CANNY_REPO,
110
- torch_dtype=Config.DTYPE
111
- )
112
- cn_lineart = ControlNetModel.from_pretrained(
113
- Config.CN_LINEART_REPO,
114
- torch_dtype=Config.DTYPE
115
- )
116
- controlnet_list.extend([cn_canny, cn_lineart])
117
- print(" [OK] Loaded both Canny and LineArt ControlNets")
118
 
119
  print("Wrapping ControlNets in MultiControlNetModel...")
 
120
  controlnet = MultiControlNetModel(controlnet_list)
121
 
122
- # 3. Load SDXL Pipeline
123
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
124
 
125
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
@@ -148,15 +106,18 @@ class ModelHandler:
148
  except Exception as e:
149
  print(f" [WARNING] Failed to enable xFormers: {e}")
150
 
151
- # 4. Set TCD Scheduler
152
- print("Configuring TCDScheduler...")
153
- self.pipeline.scheduler = TCDScheduler.from_config(self.pipeline.scheduler.config)
154
- print(" [OK] TCDScheduler loaded.")
 
 
 
155
 
156
  # 5. Load Adapters
157
  print("Loading Adapters...")
158
 
159
- # 5a. Load and Fuse Style LoRA
160
  print(f"Loading and Fusing Style LoRA ({Config.LORA_FILENAME})...")
161
  style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
162
  if not os.path.exists(style_lora_path):
@@ -170,7 +131,7 @@ class ModelHandler:
170
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
171
  print(" [OK] Style LoRA fused.")
172
 
173
- # 5b. Load IP-Adapter for InstantID
174
  ip_adapter_filename = "ip-adapter.bin"
175
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
176
  if not os.path.exists(ip_adapter_local_path):
@@ -181,19 +142,14 @@ class ModelHandler:
181
  local_dir_use_symlinks=False
182
  )
183
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
184
- print(" [OK] InstantID IP-Adapter loaded.")
 
 
185
 
186
- # 6. Load Preprocessors
187
- print("Loading Preprocessors...")
188
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
189
-
190
- if edge_type in ["canny", "both"]:
191
- self.canny_detector = CannyDetector()
192
- print(" [OK] Canny detector loaded")
193
-
194
- if edge_type in ["lineart", "both"]:
195
- self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
196
- print(" [OK] LineArt detector loaded")
197
 
198
  print("--- All models loaded successfully ---")
199
 
@@ -206,28 +162,8 @@ class ModelHandler:
206
  faces = self.app.get(cv2_img)
207
  if len(faces) == 0:
208
  return None
209
- faces = sorted(
210
- faces,
211
- key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]),
212
- reverse=True
213
- )
214
  return faces[0]
215
  except Exception as e:
216
  print(f"Face embedding extraction failed: {e}")
217
- return None
218
-
219
- def extract_depth(self, image):
220
- """Extract depth map using LeReS detector"""
221
- return self.leres_detector(image)
222
-
223
- def extract_canny(self, image, low_threshold=100, high_threshold=200):
224
- """Extract Canny edges"""
225
- if self.canny_detector is None:
226
- raise ValueError("Canny detector not loaded. Initialize with edge_type='canny' or 'both'")
227
- return self.canny_detector(image, low_threshold=low_threshold, high_threshold=high_threshold)
228
-
229
- def extract_lineart(self, image):
230
- """Extract LineArt edges"""
231
- if self.lineart_anime_detector is None:
232
- raise ValueError("LineArt detector not loaded. Initialize with edge_type='lineart' or 'both'")
233
- return self.lineart_anime_detector(image)
 
6
 
7
  from diffusers import (
8
  ControlNetModel,
9
+ DPMSolverMultistepScheduler,
10
  )
11
  from diffusers.models.controlnets.multicontrolnet import MultiControlNetModel
12
 
 
15
 
16
  from huggingface_hub import snapshot_download, hf_hub_download
17
  from insightface.app import FaceAnalysis
18
+ from controlnet_aux import LeresDetector, LineartAnimeDetector
19
 
20
  class ModelHandler:
21
  def __init__(self):
22
  self.pipeline = None
23
+ self.app = None # InsightFace
24
  self.leres_detector = None
25
  self.lineart_anime_detector = None
 
26
  self.face_analysis_loaded = False
 
27
 
28
  def load_face_analysis(self):
29
  """
 
39
  try:
40
  snapshot_download(
41
  repo_id=Config.ANTELOPEV2_REPO,
42
+ local_dir=model_path, # Download to the correct expected path
43
  )
44
  except Exception as e:
45
  print(f" [ERROR] Failed to download AntelopeV2 models: {e}")
 
59
  print(f" [WARNING] Face detection system failed to initialize: {e}")
60
  return False
61
 
62
+ def load_models(self):
 
 
 
 
 
 
 
 
63
  # 1. Load Face Analysis
64
  self.face_analysis_loaded = self.load_face_analysis()
65
 
66
+ # 2. Load ControlNets
67
+ print("Loading ControlNets (InstantID, Zoe, LineArt)...")
68
  cn_instantid = ControlNetModel.from_pretrained(
69
  Config.INSTANTID_REPO,
70
  subfolder="ControlNetModel",
71
  torch_dtype=Config.DTYPE
72
  )
73
+ cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
74
+ cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  print("Wrapping ControlNets in MultiControlNetModel...")
77
+ controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
78
  controlnet = MultiControlNetModel(controlnet_list)
79
 
80
+ # 3. Load SDXL Pipeline (Now from 'reality.safetensors')
81
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
82
 
83
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
 
106
  except Exception as e:
107
  print(f" [WARNING] Failed to enable xFormers: {e}")
108
 
109
+ # 4. Set DPMSolver++ Scheduler with Karras sigmas
110
+ print("Configuring DPMSolverMultistepScheduler...")
111
+ self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
112
+ self.pipeline.scheduler.config,
113
+ use_karras_sigmas=True
114
+ )
115
+ print(" [OK] DPMSolverMultistepScheduler loaded with Karras sigmas.")
116
 
117
  # 5. Load Adapters
118
  print("Loading Adapters...")
119
 
120
+ # 5b. Load and Fuse Style LoRA (lucasart)
121
  print(f"Loading and Fusing Style LoRA ({Config.LORA_FILENAME})...")
122
  style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
123
  if not os.path.exists(style_lora_path):
 
131
  self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
132
  print(" [OK] Style LoRA fused.")
133
 
134
+ # 5c. Load IP-Adapter (for InstantID) - *Must be loaded AFTER fusing*
135
  ip_adapter_filename = "ip-adapter.bin"
136
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
137
  if not os.path.exists(ip_adapter_local_path):
 
142
  local_dir_use_symlinks=False
143
  )
144
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
145
+ print(" [OK] IP-Adapter loaded.")
146
+
147
+ # --- END FIX ---
148
 
149
+ # 7. Load Preprocessors
150
+ print("Loading Preprocessors (LeReS, LineArtAnime)...")
151
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
152
+ self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
 
 
 
 
 
 
 
153
 
154
  print("--- All models loaded successfully ---")
155
 
 
162
  faces = self.app.get(cv2_img)
163
  if len(faces) == 0:
164
  return None
165
+ faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
 
 
 
 
166
  return faces[0]
167
  except Exception as e:
168
  print(f"Face embedding extraction failed: {e}")
169
+ return None