alexander1i committed on
Commit
3c0d2f8
·
1 Parent(s): 5e07795

fix handler

Browse files
Files changed (1) hide show
  1. handler.py +171 -29
handler.py CHANGED
@@ -1,33 +1,175 @@
1
- # HANDLER_MARKER_2025-08-29
2
- import io, base64
 
3
  from PIL import Image
4
  import torch
5
- from diffusers import StableDiffusionXLInpaintPipeline
 
 
 
6
 
7
class EndpointHandler:
    """Inference-endpoint handler for SDXL inpainting.

    Loads the LUSTIFY SDXL inpainting checkpoint once at startup and serves
    requests of the form {"prompt", "image", "mask", ...} (images as base64
    strings or raw bytes), returning {"image_base64": "<base64 PNG>"}.
    """

    def __init__(self, path="."):
        print("HANDLER_MARKER_2025-08-29: loading WITHOUT variant arg")
        model_id = "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-INPAINTING"
        # Fix: don't hard-code "cuda" — fall back to CPU (with fp32, since
        # fp16 inference is unreliable on CPU) so the handler can start on
        # CPU-only hosts instead of crashing at load time.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        # NO variant= here
        self.pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
            model_id, torch_dtype=dtype, use_safetensors=True
        ).to(self.device)
        self.pipe.enable_attention_slicing()

    def _to_pil(self, data, mode):
        """Decode a base64 string (or raw bytes) into a PIL image in `mode`.

        Fixes: accepts data-URL payloads ("data:image/png;base64,...") and
        drops the redundant function-local `import base64, io` that shadowed
        the module-level imports.
        """
        if isinstance(data, str):
            if data.startswith("data:"):
                data = data.split(",", 1)[1]
            data = base64.b64decode(data)
        return Image.open(io.BytesIO(data)).convert(mode)

    def __call__(self, data):
        """Run one inpainting request and return the PNG result as base64."""
        prompt = data.get("prompt", "")
        init_img = self._to_pil(data["image"], "RGB")
        mask_img = self._to_pil(data["mask"], "L")  # white=repaint, black=keep
        steps = int(data.get("num_inference_steps", 30))
        guidance = float(data.get("guidance_scale", 7.0))
        strength = float(data.get("strength", 0.85))
        out = self.pipe(prompt=prompt, image=init_img, mask_image=mask_img,
                        num_inference_steps=steps, guidance_scale=guidance,
                        strength=strength).images[0]
        buf = io.BytesIO()
        out.save(buf, format="PNG")
        return {"image_base64": base64.b64encode(buf.getvalue()).decode()}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HANDLER v5 — SDXL txt2img + inpaint, no 'variant' usage, robust inputs
2
+ import os, io, json, base64
3
+ from typing import Any, Dict
4
  from PIL import Image
5
  import torch
6
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLInpaintPipeline
7
+ from huggingface_hub import snapshot_download
8
+
9
+ MODEL_ID = os.getenv("MODEL_ID", "andro-flock/LUSTIFY-SDXL-NSFW-checkpoint-v2-0-INPAINTING")
10
 
11
class EndpointHandler:
    """SDXL endpoint handler supporting both txt2img and inpainting.

    The checkpoint repo is snapshot-downloaded up front (sidestepping
    'variant' resolution issues), then pipelines are loaded trying
    fp16 -> bf16 -> fp32. Requests are dicts, optionally wrapped as
    {"inputs": ...}; responses are {"image_base64": "<base64 PNG>"}.
    """

    def __init__(self, path: str = "."):
        print("HANDLER v5: init start")
        token = os.getenv("HF_TOKEN")  # optional, for gated/private repos
        # Download repo locally first to avoid variant resolution issues
        local_dir = snapshot_download(MODEL_ID, token=token)
        print(f"HANDLER v5: snapshot at {local_dir}")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.pipe_txt2img = None
        self.pipe_inpaint = None
        last_err = None

        # Try fp16 -> bf16 -> fp32; first dtype that loads the (required)
        # inpaint pipeline wins.
        for dtype in (torch.float16, torch.bfloat16, torch.float32):
            try:
                # Try to load txt2img (some inpaint repos still work for txt2img)
                try:
                    self.pipe_txt2img = StableDiffusionXLPipeline.from_pretrained(
                        local_dir, torch_dtype=dtype, use_safetensors=True
                    ).to(self.device)
                    print(f"HANDLER v5: txt2img OK ({dtype})")
                except Exception as e:
                    # txt2img is optional — we can fall back to inpaint-on-canvas
                    self.pipe_txt2img = None
                    print(f"HANDLER v5: txt2img failed on {dtype}: {e}")

                # Load inpaint (required)
                self.pipe_inpaint = StableDiffusionXLInpaintPipeline.from_pretrained(
                    local_dir, torch_dtype=dtype, use_safetensors=True
                ).to(self.device)
                print(f"HANDLER v5: inpaint OK ({dtype})")

                break  # success on this dtype
            except Exception as e:
                last_err = e
                self.pipe_txt2img = None
                self.pipe_inpaint = None
                print(f"HANDLER v5: inpaint failed on {dtype}: {e}")

        if self.pipe_inpaint is None:
            raise RuntimeError(f"Failed to load pipelines: {last_err}")

        # Light memory tweaks (best-effort; older diffusers may differ)
        try:
            self.pipe_inpaint.enable_attention_slicing()
            if self.pipe_txt2img:
                self.pipe_txt2img.enable_attention_slicing()
        except Exception:
            pass

        print("HANDLER v5: ready")

    # ---------- helpers ----------
    def _unwrap(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize the request payload to a plain parameter dict.

        Accepts {"inputs": {...}}, {"inputs": "<json object>"}, the standard
        HF payload {"inputs": "<plain prompt>"}, or a raw dict.

        Fix: a non-JSON "inputs" string used to be dropped (returned {}),
        losing the prompt, and a JSON string that parsed to a non-dict was
        returned as-is, crashing later `.get` calls. Both now map to
        {"prompt": <string>} / only parsed dicts are returned directly.
        """
        if "inputs" in data:
            inner = data["inputs"]
            if isinstance(inner, str):
                try:
                    parsed = json.loads(inner)
                except Exception:
                    parsed = None
                if isinstance(parsed, dict):
                    return parsed
                # Plain prompt string (standard HF Inference payload)
                return {"prompt": inner}
            if isinstance(inner, dict):
                return inner
        return data

    def _to_pil(self, payload: Any, mode: str) -> "Image.Image":
        """Decode base64 (pure or data-URL) or raw bytes into a PIL image."""
        if isinstance(payload, str):
            if payload.startswith("data:"):
                payload = payload.split(",", 1)[1]
            payload = base64.b64decode(payload)
        return Image.open(io.BytesIO(payload)).convert(mode)

    def _int(self, data, key, default):
        """Fetch data[key] as int, falling back to `default` on any error."""
        try:
            return int(data.get(key, default))
        except Exception:
            return default

    def _float(self, data, key, default):
        """Fetch data[key] as float, falling back to `default` on any error."""
        try:
            return float(data.get(key, default))
        except Exception:
            return default

    # ---------- main entry ----------
    def __call__(self, data: Dict[str, Any]):
        """Dispatch one request: txt2img when no init image, else inpaint."""
        data = self._unwrap(data)

        prompt = data.get("prompt", "")
        negative_prompt = data.get("negative_prompt", None)
        steps = self._int(data, "num_inference_steps", 30)
        guidance = self._float(data, "guidance_scale", 7.0)
        seed = data.get("seed", None)

        generator = None
        if seed is not None:
            try:
                generator = torch.Generator(device=self.device).manual_seed(int(seed))
            except Exception:
                generator = None

        # --------- decide mode ---------
        # (1) txt2img path: no init image provided
        if "image" not in data and "init_image" not in data:
            width = self._int(data, "width", 1024)
            height = self._int(data, "height", 1024)
            # SDXL likes multiples of 8
            width = max(64, (width // 8) * 8)
            height = max(64, (height // 8) * 8)

            if self.pipe_txt2img is not None:
                image = self.pipe_txt2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    width=width,
                    height=height,
                    num_inference_steps=steps,
                    guidance_scale=guidance,
                    generator=generator,
                ).images[0]
            else:
                # Fallback: synthesize from blank canvas with inpaint
                canvas = Image.new("RGB", (width, height), (255, 255, 255))
                mask = Image.new("L", (width, height), 255)  # edit-all
                image = self.pipe_inpaint(
                    prompt=prompt,
                    image=canvas,
                    mask_image=mask,
                    negative_prompt=negative_prompt,
                    num_inference_steps=steps,
                    guidance_scale=guidance,
                    generator=generator,
                ).images[0]

        # (2) inpaint path: init image (and optional mask)
        else:
            init_key = "image" if "image" in data else "init_image"
            init_img = self._to_pil(data[init_key], "RGB")

            if "mask" in data:
                mask_img = self._to_pil(data["mask"], "L")
            else:
                # default to "edit-all" if mask omitted
                mask_img = Image.new("L", init_img.size, 255)

            strength = self._float(data, "strength", 0.85)

            image = self.pipe_inpaint(
                prompt=prompt,
                image=init_img,
                mask_image=mask_img,
                negative_prompt=negative_prompt,
                num_inference_steps=steps,
                guidance_scale=guidance,
                strength=strength,
                generator=generator,
            ).images[0]

        # Return PNG as base64
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        out_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
        return {"image_base64": out_b64}