alexander1i commited on
Commit
7286926
·
1 Parent(s): 3c0d2f8

fix handler

Browse files
Files changed (1) hide show
  1. handler.py +57 -36
handler.py CHANGED
@@ -1,5 +1,5 @@
1
- # HANDLER v5 — SDXL txt2img + inpaint, no 'variant' usage, robust inputs
2
- import os, io, json, base64
3
  from typing import Any, Dict
4
  from PIL import Image
5
  import torch
@@ -10,47 +10,53 @@ MODEL_ID = os.getenv("MODEL_ID", "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-
10
 
11
  class EndpointHandler:
12
  def __init__(self, path: str = "."):
13
- print("HANDLER v5: init start")
14
- token = os.getenv("HF_TOKEN") # optional, for gated/private repos
15
- # Download repo locally first to avoid variant resolution issues
16
  local_dir = snapshot_download(MODEL_ID, token=token)
17
- print(f"HANDLER v5: snapshot at {local_dir}")
18
 
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
  self.pipe_txt2img = None
21
  self.pipe_inpaint = None
22
  last_err = None
23
 
24
- # Try fp16 → bf16 → fp32
25
  for dtype in (torch.float16, torch.bfloat16, torch.float32):
26
  try:
27
- # Try to load txt2img (some inpaint repos still work for txt2img)
28
  try:
29
- self.pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
30
  local_dir, torch_dtype=dtype, use_safetensors=True
31
  ).to(self.device)
32
- print(f"HANDLER v5: txt2img OK ({dtype})")
 
 
 
 
 
 
 
 
 
 
33
  except Exception as e:
34
  self.pipe_txt2img = None
35
- print(f"HANDLER v5: txt2img failed on {dtype}: {e}")
36
 
37
  # Load inpaint (required)
38
  self.pipe_inpaint = StableDiffusionXLInpaintPipeline.from_pretrained(
39
  local_dir, torch_dtype=dtype, use_safetensors=True
40
  ).to(self.device)
41
- print(f"HANDLER v5: inpaint OK ({dtype})")
42
-
43
- break # success on this dtype
44
  except Exception as e:
45
  last_err = e
46
  self.pipe_txt2img = None
47
  self.pipe_inpaint = None
48
- print(f"HANDLER v5: inpaint failed on {dtype}: {e}")
49
 
50
  if self.pipe_inpaint is None:
51
  raise RuntimeError(f"Failed to load pipelines: {last_err}")
52
 
53
- # Light memory tweaks
54
  try:
55
  self.pipe_inpaint.enable_attention_slicing()
56
  if self.pipe_txt2img:
@@ -58,11 +64,10 @@ class EndpointHandler:
58
  except Exception:
59
  pass
60
 
61
- print("HANDLER v5: ready")
62
 
63
  # ---------- helpers ----------
64
  def _unwrap(self, data: Dict[str, Any]) -> Dict[str, Any]:
65
- # Accept {"inputs": {...}} or raw dict
66
  if "inputs" in data:
67
  inner = data["inputs"]
68
  if isinstance(inner, str):
@@ -74,12 +79,20 @@ class EndpointHandler:
74
  return inner
75
  return data
76
 
 
 
 
 
 
77
  def _to_pil(self, payload: Any, mode: str) -> Image.Image:
78
- # Accept pure base64 bytes OR data URLs
79
  if isinstance(payload, str):
80
- if payload.startswith("data:"):
81
- payload = payload.split(",", 1)[1]
82
- payload = base64.b64decode(payload)
 
 
 
83
  return Image.open(io.BytesIO(payload)).convert(mode)
84
 
85
  def _int(self, data, key, default):
@@ -111,12 +124,23 @@ class EndpointHandler:
111
  except Exception:
112
  generator = None
113
 
114
- # --------- decide mode ---------
115
- # (1) txt2img path: no init image provided
116
- if "image" not in data and "init_image" not in data:
 
 
 
 
 
 
 
 
 
 
 
 
117
  width = self._int(data, "width", 1024)
118
  height = self._int(data, "height", 1024)
119
- # SDXL likes multiples of 8
120
  width = max(64, (width // 8) * 8)
121
  height = max(64, (height // 8) * 8)
122
 
@@ -131,9 +155,9 @@ class EndpointHandler:
131
  generator=generator,
132
  ).images[0]
133
  else:
134
- # Fallback: synthesize from blank canvas with inpaint
135
  canvas = Image.new("RGB", (width, height), (255, 255, 255))
136
- mask = Image.new("L", (width, height), 255) # edit-all
137
  image = self.pipe_inpaint(
138
  prompt=prompt,
139
  image=canvas,
@@ -143,17 +167,14 @@ class EndpointHandler:
143
  guidance_scale=guidance,
144
  generator=generator,
145
  ).images[0]
146
-
147
- # (2) inpaint path: init image (and optional mask)
148
  else:
149
- init_key = "image" if "image" in data else "init_image"
150
- init_img = self._to_pil(data[init_key], "RGB")
151
 
152
- if "mask" in data:
153
- mask_img = self._to_pil(data["mask"], "L")
154
  else:
155
- # default to "edit-all" if mask omitted
156
- mask_img = Image.new("L", init_img.size, 255)
157
 
158
  strength = self._float(data, "strength", 0.85)
159
 
 
1
+ # HANDLER v6 — SDXL txt2img + inpaint, supports image_url/mask_url, guards UNet channels
2
+ import os, io, json, base64, requests
3
  from typing import Any, Dict
4
  from PIL import Image
5
  import torch
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path: str = "."):
13
+ print("HANDLER v6: init start")
14
+ token = os.getenv("HF_TOKEN")
 
15
  local_dir = snapshot_download(MODEL_ID, token=token)
16
+ print(f"HANDLER v6: snapshot at {local_dir}")
17
 
18
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
  self.pipe_txt2img = None
20
  self.pipe_inpaint = None
21
  last_err = None
22
 
 
23
  for dtype in (torch.float16, torch.bfloat16, torch.float32):
24
  try:
25
+ # Try to load txt2img
26
  try:
27
+ p = StableDiffusionXLPipeline.from_pretrained(
28
  local_dir, torch_dtype=dtype, use_safetensors=True
29
  ).to(self.device)
30
+ # Keep txt2img ONLY if UNet is 4-ch (proper base)
31
+ if getattr(p.unet.config, "in_channels", 4) == 4:
32
+ self.pipe_txt2img = p
33
+ print(f"HANDLER v6: txt2img OK ({dtype}, in_ch=4)")
34
+ else:
35
+ print("HANDLER v6: txt2img UNet in_ch != 4; disabling txt2img for this repo")
36
+ try:
37
+ p.to("cpu"); del p
38
+ except Exception:
39
+ pass
40
+ self.pipe_txt2img = None
41
  except Exception as e:
42
  self.pipe_txt2img = None
43
+ print(f"HANDLER v6: txt2img failed on {dtype}: {e}")
44
 
45
  # Load inpaint (required)
46
  self.pipe_inpaint = StableDiffusionXLInpaintPipeline.from_pretrained(
47
  local_dir, torch_dtype=dtype, use_safetensors=True
48
  ).to(self.device)
49
+ print(f"HANDLER v6: inpaint OK ({dtype}, in_ch={getattr(self.pipe_inpaint.unet.config, 'in_channels', 'NA')})")
50
+ break
 
51
  except Exception as e:
52
  last_err = e
53
  self.pipe_txt2img = None
54
  self.pipe_inpaint = None
55
+ print(f"HANDLER v6: inpaint failed on {dtype}: {e}")
56
 
57
  if self.pipe_inpaint is None:
58
  raise RuntimeError(f"Failed to load pipelines: {last_err}")
59
 
 
60
  try:
61
  self.pipe_inpaint.enable_attention_slicing()
62
  if self.pipe_txt2img:
 
64
  except Exception:
65
  pass
66
 
67
+ print("HANDLER v6: ready")
68
 
69
  # ---------- helpers ----------
70
  def _unwrap(self, data: Dict[str, Any]) -> Dict[str, Any]:
 
71
  if "inputs" in data:
72
  inner = data["inputs"]
73
  if isinstance(inner, str):
 
79
  return inner
80
  return data
81
 
82
+ def _fetch_url_bytes(self, url: str) -> bytes:
83
+ r = requests.get(url, timeout=60)
84
+ r.raise_for_status()
85
+ return r.content
86
+
87
  def _to_pil(self, payload: Any, mode: str) -> Image.Image:
88
+ # Accept: bytes, base64, or data URL, or HTTP(S) URL
89
  if isinstance(payload, str):
90
+ if payload.startswith("http://") or payload.startswith("https://"):
91
+ payload = self._fetch_url_bytes(payload)
92
+ else:
93
+ if payload.startswith("data:"):
94
+ payload = payload.split(",", 1)[1]
95
+ payload = base64.b64decode(payload)
96
  return Image.open(io.BytesIO(payload)).convert(mode)
97
 
98
  def _int(self, data, key, default):
 
124
  except Exception:
125
  generator = None
126
 
127
+ # Normalize keys for images/masks
128
+ # Accept: image / init_image / image_url ; mask / mask_url
129
+ init_img_payload = None
130
+ if "image" in data:
131
+ init_img_payload = data["image"]
132
+ elif "init_image" in data:
133
+ init_img_payload = data["init_image"]
134
+ elif "image_url" in data:
135
+ init_img_payload = data["image_url"]
136
+
137
+ mask_payload = data.get("mask") or data.get("mask_url")
138
+
139
+ # --------- choose mode ---------
140
+ if init_img_payload is None:
141
+ # txt2img mode (only if we truly have a 4-ch UNet)
142
  width = self._int(data, "width", 1024)
143
  height = self._int(data, "height", 1024)
 
144
  width = max(64, (width // 8) * 8)
145
  height = max(64, (height // 8) * 8)
146
 
 
155
  generator=generator,
156
  ).images[0]
157
  else:
158
+ # Fallback: blank-canvas inpaint (works with 9-ch UNet)
159
  canvas = Image.new("RGB", (width, height), (255, 255, 255))
160
+ mask = Image.new("L", (width, height), 255)
161
  image = self.pipe_inpaint(
162
  prompt=prompt,
163
  image=canvas,
 
167
  guidance_scale=guidance,
168
  generator=generator,
169
  ).images[0]
 
 
170
  else:
171
+ # inpaint mode
172
+ init_img = self._to_pil(init_img_payload, "RGB")
173
 
174
+ if mask_payload is not None:
175
+ mask_img = self._to_pil(mask_payload, "L").resize(init_img.size, Image.NEAREST)
176
  else:
177
+ mask_img = Image.new("L", init_img.size, 255) # edit-all default
 
178
 
179
  strength = self._float(data, "strength", 0.85)
180