primerz commited on
Commit
60bf1c5
·
verified ·
1 Parent(s): 5a9aef6

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +34 -6
generator.py CHANGED
@@ -1,22 +1,45 @@
1
  import torch
2
  from config import Config
3
  from utils import resize_image_to_1mp, get_caption
 
4
 
5
  class Generator:
6
  def __init__(self, model_handler):
7
  self.mh = model_handler
8
 
9
- def prepare_control_images(self, image):
10
- """Generates the conditioning maps from the input image."""
11
- depth_map = self.mh.zoe_detector(image)
12
- lineart_map = self.mh.lineart_detector(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  return depth_map, lineart_map
 
14
 
15
  def predict(self, input_image, user_prompt=""):
16
  # 1. Pre-process Inputs
17
  print("Processing Input...")
18
  processed_image = resize_image_to_1mp(input_image)
19
 
 
 
 
 
 
20
  # 2. Get Face Embedding (Robust Mode)
21
  face_emb = self.mh.get_face_embedding(processed_image)
22
 
@@ -35,7 +58,10 @@ class Generator:
35
 
36
  # 4. Generate Control Maps (Structure)
37
  print("Generating Control Maps (Depth, LineArt)...")
38
- depth_map, lineart_map = self.prepare_control_images(processed_image)
 
 
 
39
 
40
  # 5. Logic for Face vs No-Face
41
  # ControlNet order: [InstantID, Zoe, LineArt]
@@ -55,7 +81,10 @@ class Generator:
55
  result = self.mh.pipeline(
56
  prompt=final_prompt,
57
  image=processed_image, # <-- Base image for Img2Img
 
 
58
  control_image=[processed_image, depth_map, lineart_map], # <-- ControlNet inputs
 
59
  image_embeds=face_emb, # <-- Face embedding for InstantID
60
 
61
  strength=0.85, # Img2Img strength (0.8-0.9 is good for style)
@@ -66,7 +95,6 @@ class Generator:
66
  num_inference_steps=8,
67
  guidance_scale=1.5,
68
 
69
- # --- ADDED ---
70
  clip_skip=2
71
 
72
  ).images[0]
 
1
  import torch
2
  from config import Config
3
  from utils import resize_image_to_1mp, get_caption
4
+ from PIL import Image # <-- Make sure this import is at the top
5
 
6
  class Generator:
7
  def __init__(self, model_handler):
8
  self.mh = model_handler
9
 
10
+ # --- START FIX ---
11
+ def prepare_control_images(self, image, width, height):
12
+ """
13
+ Generates conditioning maps, ensuring they are resized
14
+ to the exact target dimensions (width, height).
15
+ """
16
+ print(f"Generating control maps for {width}x{height}...")
17
+
18
+ # Generate depth map
19
+ # The detector might return a different size (e.g., 512x512)
20
+ depth_map_raw = self.mh.zoe_detector(image)
21
+
22
+ # Generate lineart map
23
+ lineart_map_raw = self.mh.lineart_detector(image)
24
+
25
+ # Manually resize maps to match the exact output resolution
26
+ # This prevents the tensor mismatch error.
27
+ depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
28
+ lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
29
+
30
  return depth_map, lineart_map
31
+ # --- END FIX ---
32
 
33
  def predict(self, input_image, user_prompt=""):
34
  # 1. Pre-process Inputs
35
  print("Processing Input...")
36
  processed_image = resize_image_to_1mp(input_image)
37
 
38
+ # --- START FIX ---
39
+ # Get the exact dimensions for the control maps
40
+ target_width, target_height = processed_image.size
41
+ # --- END FIX ---
42
+
43
  # 2. Get Face Embedding (Robust Mode)
44
  face_emb = self.mh.get_face_embedding(processed_image)
45
 
 
58
 
59
  # 4. Generate Control Maps (Structure)
60
  print("Generating Control Maps (Depth, LineArt)...")
61
+ # --- START FIX ---
62
+ # Pass target dimensions to the preprocessor
63
+ depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
64
+ # --- END FIX ---
65
 
66
  # 5. Logic for Face vs No-Face
67
  # ControlNet order: [InstantID, Zoe, LineArt]
 
81
  result = self.mh.pipeline(
82
  prompt=final_prompt,
83
  image=processed_image, # <-- Base image for Img2Img
84
+
85
+ # All 3 images are now guaranteed to be the same size
86
  control_image=[processed_image, depth_map, lineart_map], # <-- ControlNet inputs
87
+
88
  image_embeds=face_emb, # <-- Face embedding for InstantID
89
 
90
  strength=0.85, # Img2Img strength (0.8-0.9 is good for style)
 
95
  num_inference_steps=8,
96
  guidance_scale=1.5,
97
 
 
98
  clip_skip=2
99
 
100
  ).images[0]