primerz committed on
Commit
a910636
·
verified ·
1 Parent(s): ff641c2

Update generator.py

Browse files
Files changed (1) hide show
  1. generator.py +17 -22
generator.py CHANGED
@@ -1,13 +1,12 @@
1
  import torch
2
  from config import Config
3
  from utils import resize_image_to_1mp, get_caption
4
- from PIL import Image # <-- Make sure this import is at the top
5
 
6
  class Generator:
7
  def __init__(self, model_handler):
8
  self.mh = model_handler
9
 
10
- # --- START FIX ---
11
  def prepare_control_images(self, image, width, height):
12
  """
13
  Generates conditioning maps, ensuring they are resized
@@ -16,29 +15,22 @@ class Generator:
16
  print(f"Generating control maps for {width}x{height}...")
17
 
18
  # Generate depth map
19
- # The detector might return a different size (e.g., 512x512)
20
  depth_map_raw = self.mh.zoe_detector(image)
21
 
22
  # Generate lineart map
23
  lineart_map_raw = self.mh.lineart_detector(image)
24
 
25
  # Manually resize maps to match the exact output resolution
26
- # This prevents the tensor mismatch error.
27
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
28
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
29
 
30
  return depth_map, lineart_map
31
- # --- END FIX ---
32
 
33
  def predict(self, input_image, user_prompt=""):
34
  # 1. Pre-process Inputs
35
  print("Processing Input...")
36
  processed_image = resize_image_to_1mp(input_image)
37
-
38
- # --- START FIX ---
39
- # Get the exact dimensions for the control maps
40
  target_width, target_height = processed_image.size
41
- # --- END FIX ---
42
 
43
  # 2. Get Face Embedding (Robust Mode)
44
  face_emb = self.mh.get_face_embedding(processed_image)
@@ -58,10 +50,7 @@ class Generator:
58
 
59
  # 4. Generate Control Maps (Structure)
60
  print("Generating Control Maps (Depth, LineArt)...")
61
- # --- START FIX ---
62
- # Pass target dimensions to the preprocessor
63
  depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
64
- # --- END FIX ---
65
 
66
  # 5. Logic for Face vs No-Face
67
  # ControlNet order: [InstantID, Zoe, LineArt]
@@ -73,29 +62,35 @@ class Generator:
73
  else:
74
  print("No face detected: Disabling InstantID.")
75
  controlnet_conditioning_scale = [0.0, 0.4, 0.4] # Disable InstantID weight
76
- control_guidance_end = [0.5, 0.8, 0.8] # Set end to avoid 0.0 >= 0.0 error
77
  self.mh.pipeline.set_ip_adapter_scale(0.0)
 
 
 
 
 
 
78
 
79
  # 6. Run Inference
80
  print("Running pipeline...")
81
  result = self.mh.pipeline(
82
  prompt=final_prompt,
83
- image=processed_image, # <-- Base image for Img2Img
84
-
85
- # All 3 images are now guaranteed to be the same size
86
- control_image=[processed_image, depth_map, lineart_map], # <-- ControlNet inputs
87
 
88
- image_embeds=face_emb, # <-- Face embedding for InstantID
89
-
90
- strength=0.85, # Img2Img strength (0.8-0.9 is good for style)
91
  controlnet_conditioning_scale=controlnet_conditioning_scale,
92
  control_guidance_end=control_guidance_end,
93
 
94
  # LCM settings
95
  num_inference_steps=8,
96
- guidance_scale=1.5,
 
 
97
 
98
- clip_skip=2
 
99
 
100
  ).images[0]
101
 
 
1
  import torch
2
  from config import Config
3
  from utils import resize_image_to_1mp, get_caption
4
+ from PIL import Image
5
 
6
  class Generator:
7
    def __init__(self, model_handler):
        # Keep a handle on the shared model holder; detectors
        # (zoe/lineart) and the diffusion pipeline are reached
        # through it as `self.mh` elsewhere in this class.
        self.mh = model_handler
9
 
 
10
  def prepare_control_images(self, image, width, height):
11
  """
12
  Generates conditioning maps, ensuring they are resized
 
15
  print(f"Generating control maps for {width}x{height}...")
16
 
17
  # Generate depth map
 
18
  depth_map_raw = self.mh.zoe_detector(image)
19
 
20
  # Generate lineart map
21
  lineart_map_raw = self.mh.lineart_detector(image)
22
 
23
  # Manually resize maps to match the exact output resolution
 
24
  depth_map = depth_map_raw.resize((width, height), Image.LANCZOS)
25
  lineart_map = lineart_map_raw.resize((width, height), Image.LANCZOS)
26
 
27
  return depth_map, lineart_map
 
28
 
29
  def predict(self, input_image, user_prompt=""):
30
  # 1. Pre-process Inputs
31
  print("Processing Input...")
32
  processed_image = resize_image_to_1mp(input_image)
 
 
 
33
  target_width, target_height = processed_image.size
 
34
 
35
  # 2. Get Face Embedding (Robust Mode)
36
  face_emb = self.mh.get_face_embedding(processed_image)
 
50
 
51
  # 4. Generate Control Maps (Structure)
52
  print("Generating Control Maps (Depth, LineArt)...")
 
 
53
  depth_map, lineart_map = self.prepare_control_images(processed_image, target_width, target_height)
 
54
 
55
  # 5. Logic for Face vs No-Face
56
  # ControlNet order: [InstantID, Zoe, LineArt]
 
62
  else:
63
  print("No face detected: Disabling InstantID.")
64
  controlnet_conditioning_scale = [0.0, 0.4, 0.4] # Disable InstantID weight
65
+ control_guidance_end = [0.5, 0.8, 0.8]
66
  self.mh.pipeline.set_ip_adapter_scale(0.0)
67
+
68
+ # --- START FIX for NoneType Error ---
69
+ # Create a dummy tensor instead of passing None
70
+ # Shape is (batch_size, embedding_dim)
71
+ face_emb = torch.zeros((1, 512), dtype=Config.DTYPE, device=Config.DEVICE)
72
+ # --- END FIX ---
73
 
74
  # 6. Run Inference
75
  print("Running pipeline...")
76
  result = self.mh.pipeline(
77
  prompt=final_prompt,
78
+ image=processed_image, # Base image for Img2Img
79
+ control_image=[processed_image, depth_map, lineart_map], # ControlNet inputs
80
+ image_embeds=face_emb, # Face embedding (or dummy)
 
81
 
82
+ strength=0.666, # <-- Img2Img strength
 
 
83
  controlnet_conditioning_scale=controlnet_conditioning_scale,
84
  control_guidance_end=control_guidance_end,
85
 
86
  # LCM settings
87
  num_inference_steps=8,
88
+ guidance_scale=1.75, # <-- CFG Scale
89
+
90
+ clip_skip=2,
91
 
92
+ # --- LoRA Strength ---
93
+ cross_attention_kwargs={"scale": 1.333}
94
 
95
  ).images[0]
96