primerz commited on
Commit
fa327ca
·
verified ·
1 Parent(s): 29c8100

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +27 -43
model.py CHANGED
@@ -65,28 +65,19 @@ class ModelHandler:
65
 
66
  # 2. Load ControlNets
67
  print("Loading ControlNets (InstantID, Zoe, LineArt)...")
68
-
69
- # Load the InstantID ControlNet from the correct subfolder
70
- print("Loading InstantID ControlNet from subfolder 'ControlNetModel'...")
71
  cn_instantid = ControlNetModel.from_pretrained(
72
  Config.INSTANTID_REPO,
73
  subfolder="ControlNetModel",
74
  torch_dtype=Config.DTYPE
75
  )
76
- print(" [OK] Loaded InstantID ControlNet.")
77
-
78
- # Load other ControlNets normally
79
- print("Loading Zoe and LineArt ControlNets...")
80
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
81
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
82
 
83
- # --- Manually wrap the list of models in a MultiControlNetModel ---
84
  print("Wrapping ControlNets in MultiControlNetModel...")
85
  controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
86
  controlnet = MultiControlNetModel(controlnet_list)
87
- # --- End wrapping ---
88
 
89
- # 3. Load SDXL Pipeline
90
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
91
 
92
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
@@ -109,21 +100,17 @@ class ModelHandler:
109
 
110
  self.pipeline.to(Config.DEVICE)
111
 
112
- # Enable xFormers
113
  try:
114
  self.pipeline.enable_xformers_memory_efficient_attention()
115
  print(" [OK] xFormers memory efficient attention enabled.")
116
  except Exception as e:
117
  print(f" [WARNING] Failed to enable xFormers: {e}")
118
 
119
- # 4. Set TCD Scheduler
120
  print("Configuring TCDScheduler...")
121
 
122
- # --- FIX STARTS HERE ---
123
- # Convert FrozenDict to a mutable standard Python dict
124
  tcd_config = dict(self.pipeline.scheduler.config)
125
-
126
- # Now we can update it safely
127
  tcd_config.update({
128
  "beta_start": 0.00085,
129
  "beta_end": 0.012,
@@ -137,34 +124,28 @@ class ModelHandler:
137
  use_karras_sigmas=True,
138
  timestep_spacing="trailing"
139
  )
140
- # --- FIX ENDS HERE ---
141
-
142
  print(" [OK] TCDScheduler loaded (Forced SDXL Defaults + Karras + Trailing).")
143
 
144
- # 5. Load Adapters (IP-Adapter, TCD-LoRA & Style LoRA)
145
  print("Loading Adapters...")
146
 
147
- # 5a. IP-Adapter
148
  ip_adapter_filename = "ip-adapter.bin"
149
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
150
-
151
  if not os.path.exists(ip_adapter_local_path):
152
- print(f"Downloading IP-Adapter to {ip_adapter_local_path}...")
153
  hf_hub_download(
154
  repo_id=Config.INSTANTID_REPO,
155
  filename=ip_adapter_filename,
156
  local_dir="./models",
157
  local_dir_use_symlinks=False
158
  )
159
-
160
- print(f"Loading IP-Adapter from local file: {ip_adapter_local_path}")
161
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
 
162
 
163
- # 5b. Load TCD LoRA (Correct Filename)
164
  print("Loading TCD-SDXL-LoRA...")
165
  tcd_lora_filename = "pytorch_lora_weights.safetensors"
166
  tcd_lora_path = os.path.join("./models", tcd_lora_filename)
167
-
168
  if not os.path.exists(tcd_lora_path):
169
  hf_hub_download(
170
  repo_id="h1t/TCD-SDXL-LoRA",
@@ -172,19 +153,28 @@ class ModelHandler:
172
  local_dir="./models",
173
  local_dir_use_symlinks=False
174
  )
175
- self.pipeline.load_lora_weights("./models", weight_name=tcd_lora_filename)
176
- self.pipeline.fuse_lora(lora_scale=1.0)
177
- print(" [OK] TCD LoRA fused.")
178
 
179
- # 5c. Load Style LoRA
180
- print("Loading Style LoRA weights...")
181
- self.pipeline.load_lora_weights(Config.REPO_ID, weight_name=Config.LORA_FILENAME)
182
-
183
- print(f"Fusing Style LoRA with scale {Config.LORA_STRENGTH}...")
184
- self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
185
- print(" [OK] Style LoRA fused.")
 
 
 
 
 
186
 
187
- # 6. Load Preprocessors
 
 
 
 
 
188
  print("Loading Preprocessors (LeReS, LineArtAnime)...")
189
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
190
  self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
@@ -195,18 +185,12 @@ class ModelHandler:
195
  """Extracts the largest face, returns insightface result object."""
196
  if not self.face_analysis_loaded:
197
  return None
198
-
199
  try:
200
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
201
  faces = self.app.get(cv2_img)
202
-
203
  if len(faces) == 0:
204
  return None
205
-
206
- # Sort by size (width * height) to find the main character
207
  faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
208
-
209
- # Return the largest face info
210
  return faces[0]
211
  except Exception as e:
212
  print(f"Face embedding extraction failed: {e}")
 
65
 
66
  # 2. Load ControlNets
67
  print("Loading ControlNets (InstantID, Zoe, LineArt)...")
 
 
 
68
  cn_instantid = ControlNetModel.from_pretrained(
69
  Config.INSTANTID_REPO,
70
  subfolder="ControlNetModel",
71
  torch_dtype=Config.DTYPE
72
  )
 
 
 
 
73
  cn_zoe = ControlNetModel.from_pretrained(Config.CN_ZOE_REPO, torch_dtype=Config.DTYPE)
74
  cn_lineart = ControlNetModel.from_pretrained(Config.CN_LINEART_REPO, torch_dtype=Config.DTYPE)
75
 
 
76
  print("Wrapping ControlNets in MultiControlNetModel...")
77
  controlnet_list = [cn_instantid, cn_zoe, cn_lineart]
78
  controlnet = MultiControlNetModel(controlnet_list)
 
79
 
80
+ # 3. Load SDXL Pipeline (Now from 'reality.safetensors')
81
  print(f"Loading SDXL Pipeline ({Config.CHECKPOINT_FILENAME})...")
82
 
83
  checkpoint_local_path = os.path.join("./models", Config.CHECKPOINT_FILENAME)
 
100
 
101
  self.pipeline.to(Config.DEVICE)
102
 
 
103
  try:
104
  self.pipeline.enable_xformers_memory_efficient_attention()
105
  print(" [OK] xFormers memory efficient attention enabled.")
106
  except Exception as e:
107
  print(f" [WARNING] Failed to enable xFormers: {e}")
108
 
109
+ # 4. Set TCD Scheduler (Sanitized Config)
110
  print("Configuring TCDScheduler...")
111
 
112
+ # Force standard SDXL config to prevent noise artifacts
 
113
  tcd_config = dict(self.pipeline.scheduler.config)
 
 
114
  tcd_config.update({
115
  "beta_start": 0.00085,
116
  "beta_end": 0.012,
 
124
  use_karras_sigmas=True,
125
  timestep_spacing="trailing"
126
  )
 
 
127
  print(" [OK] TCDScheduler loaded (Forced SDXL Defaults + Karras + Trailing).")
128
 
129
+ # 5. Load Adapters
130
  print("Loading Adapters...")
131
 
132
+ # 5a. IP-Adapter (for InstantID)
133
  ip_adapter_filename = "ip-adapter.bin"
134
  ip_adapter_local_path = os.path.join("./models", ip_adapter_filename)
 
135
  if not os.path.exists(ip_adapter_local_path):
 
136
  hf_hub_download(
137
  repo_id=Config.INSTANTID_REPO,
138
  filename=ip_adapter_filename,
139
  local_dir="./models",
140
  local_dir_use_symlinks=False
141
  )
 
 
142
  self.pipeline.load_ip_adapter_instantid(ip_adapter_local_path)
143
+ print(" [OK] IP-Adapter loaded.")
144
 
145
+ # 5b. TCD LoRA (for speed)
146
  print("Loading TCD-SDXL-LoRA...")
147
  tcd_lora_filename = "pytorch_lora_weights.safetensors"
148
  tcd_lora_path = os.path.join("./models", tcd_lora_filename)
 
149
  if not os.path.exists(tcd_lora_path):
150
  hf_hub_download(
151
  repo_id="h1t/TCD-SDXL-LoRA",
 
153
  local_dir="./models",
154
  local_dir_use_symlinks=False
155
  )
156
+ self.pipeline.load_lora_weights("./models", weight_name=tcd_lora_filename, adapter_name="tcd")
157
+ print(" [OK] TCD LoRA loaded.")
 
158
 
159
+ # 5c. Style LoRA (lucasart)
160
+ print(f"Loading Style LoRA ({Config.LORA_FILENAME})...")
161
+ style_lora_path = os.path.join("./models", Config.LORA_FILENAME)
162
+ if not os.path.exists(style_lora_path):
163
+ hf_hub_download(
164
+ repo_id=Config.REPO_ID,
165
+ filename=Config.LORA_FILENAME,
166
+ local_dir="./models",
167
+ local_dir_use_symlinks=False
168
+ )
169
+ self.pipeline.load_lora_weights("./models", weight_name=Config.LORA_FILENAME, adapter_name="style")
170
+ print(" [OK] Style LoRA loaded.")
171
 
172
+ # 6. Set Adapter Weights (TCD + Style)
173
+ # We set both adapters to run simultaneously
174
+ print(f"Setting adapter weights: TCD (1.0), Style ({Config.LORA_STRENGTH})")
175
+ self.pipeline.set_adapters(["tcd", "style"], adapter_weights=[1.0, Config.LORA_STRENGTH])
176
+
177
+ # 7. Load Preprocessors
178
  print("Loading Preprocessors (LeReS, LineArtAnime)...")
179
  self.leres_detector = LeresDetector.from_pretrained(Config.ANNOTATOR_REPO)
180
  self.lineart_anime_detector = LineartAnimeDetector.from_pretrained(Config.ANNOTATOR_REPO)
 
185
  """Extracts the largest face, returns insightface result object."""
186
  if not self.face_analysis_loaded:
187
  return None
 
188
  try:
189
  cv2_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
190
  faces = self.app.get(cv2_img)
 
191
  if len(faces) == 0:
192
  return None
 
 
193
  faces = sorted(faces, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]), reverse=True)
 
 
194
  return faces[0]
195
  except Exception as e:
196
  print(f"Face embedding extraction failed: {e}")