File size: 10,063 Bytes
861422e
3fdea04
 
 
861422e
3fdea04
861422e
 
 
 
3fdea04
 
861422e
 
 
 
e34b7d2
861422e
 
 
bfafb94
ae866cf
861422e
 
603d886
 
 
 
 
 
 
 
 
 
861422e
3fdea04
861422e
3fdea04
 
861422e
 
 
 
 
 
3fdea04
 
861422e
3fdea04
861422e
 
 
3fdea04
861422e
 
 
 
 
 
3fdea04
861422e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3fdea04
861422e
 
3fdea04
861422e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93317dd
 
cd5a36a
 
 
 
 
 
 
 
861422e
 
 
cd5a36a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16407bf
861422e
 
16407bf
 
 
861422e
16407bf
 
861422e
16407bf
 
 
 
 
 
 
861422e
16407bf
 
 
 
 
 
 
 
 
 
 
 
 
 
861422e
 
 
16407bf
861422e
 
16407bf
 
 
cd5a36a
16407bf
 
 
 
cd5a36a
 
 
 
16407bf
 
 
861422e
 
 
 
 
 
 
 
 
 
 
 
 
 
3fdea04
861422e
3fdea04
861422e
 
3fdea04
861422e
 
 
3fdea04
861422e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import logging
import os
from io import BytesIO

# Load environment variables from a local .env file when python-dotenv is
# installed (local-development convenience). Strictly best-effort: a missing
# package is expected in production and is ignored silently, while any other
# failure is logged and ignored so that importing this module never breaks.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
except Exception:
    logging.getLogger(__name__).warning(
        "python-dotenv could not be imported; skipping .env loading",
        exc_info=True,
    )
else:
    try:
        load_dotenv()
    except Exception:
        logging.getLogger(__name__).warning(
            "Failed to load .env file; continuing with the process environment",
            exc_info=True,
        )

import base64
import cv2
import numpy as np
from PIL import Image
import google.generativeai as genai

log = logging.getLogger(__name__)

# Remote inference configuration: only a Gemini API key is needed (no Vertex AI).
# Model IDs noted as options: gemini-3-pro-image-preview, gemini-2.5-flash-image,
# imagen-4.0-generate-001. Override the default via GEMINI_IMAGE_MODEL.
DEFAULT_MODEL_ID = os.environ.get(
    "GEMINI_IMAGE_MODEL",
    "gemini-3-pro-image-preview",
)

# Instruction text placed ahead of the images; override via GEMINI_IMAGE_PROMPT.
DEFAULT_PROMPT = os.environ.get(
    "GEMINI_IMAGE_PROMPT",
    "Remove ONLY the white areas shown in the mask. Keep everything else EXACTLY as it is.\n\n"
    "CRITICAL RULES:\n"
    "1. White pixels in mask = REMOVE (inpaint with background)\n"
    "2. Black pixels in mask = KEEP UNCHANGED\n"
    "3. Do NOT remove similar objects outside the mask\n"
    "4. Do NOT recreate or add new objects\n"
    "5. Only modify pixels where mask is white\n"
    "6. Fill removed areas with surrounding background texture\n"
    "7. Match lighting, colors, and texture perfectly\n\n"
    "Follow the mask EXACTLY. Do not interpret semantically.",
)

# Cached Gemini model client; stays None until first requested.
_GENAI_MODEL: genai.GenerativeModel | None = None


def _resize_mask(mask: np.ndarray, target_hw: tuple[int, int]) -> np.ndarray:
    """Resize mask to match the target height/width."""
    target_h, target_w = target_hw
    if mask.shape[:2] == (target_h, target_w):
        return mask
    return cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)


def _binary_mask_from_rgba(mask: np.ndarray, invert_mask: bool) -> np.ndarray:
    """
    Normalize incoming RGBA masks to a 0/255 binary mask.
    - Transparent alpha (0) is treated as "remove"
    - White/bright RGB is treated as "remove" when alpha is mostly opaque
    """
    if mask.shape[2] == 3:
        alpha_channel = np.ones(mask.shape[:2], dtype=np.uint8) * 255
        rgb_channels = mask
    else:
        alpha_channel = mask[:, :, 3]
        rgb_channels = mask[:, :, :3]

    # If alpha carries information, prefer it
    if alpha_channel.mean() < 240:
        mask_bw = np.where(alpha_channel < 128, 255, 0).astype(np.uint8)
    else:
        gray = cv2.cvtColor(rgb_channels, cv2.COLOR_RGB2GRAY)
        mask_bw = np.where(gray > 128, 255, 0).astype(np.uint8)

    if not invert_mask:
        mask_bw = 255 - mask_bw

    return mask_bw


def _pil_to_png_bytes(img: Image.Image) -> bytes:
    """Serialize *img* as PNG and return the encoded bytes (used for Gemini request parts)."""
    with BytesIO() as stream:
        img.save(stream, format="PNG")
        return stream.getvalue()


def _get_gemini_model() -> genai.GenerativeModel:
    """
    Return the shared Gemini model client, building it on first use.

    The API key is the first non-empty value among GEMINI_API_KEY,
    GOOGLE_API_KEY and GOOGLE_GENAI_API_KEY. The model ID is read from
    GEMINI_IMAGE_MODEL at first-call time, falling back to DEFAULT_MODEL_ID.
    The created instance is cached in the module-level ``_GENAI_MODEL``.

    Raises:
        RuntimeError: if none of the API key variables is set to a non-empty value.
    """
    global _GENAI_MODEL
    if _GENAI_MODEL is not None:
        return _GENAI_MODEL

    api_key = None
    for env_var in ("GEMINI_API_KEY", "GOOGLE_API_KEY", "GOOGLE_GENAI_API_KEY"):
        api_key = os.environ.get(env_var)
        if api_key:
            break
    if not api_key:
        raise RuntimeError("Set Gemini API key via GEMINI_API_KEY / GOOGLE_API_KEY / GOOGLE_GENAI_API_KEY")

    genai.configure(api_key=api_key)
    _GENAI_MODEL = genai.GenerativeModel(os.environ.get("GEMINI_IMAGE_MODEL", DEFAULT_MODEL_ID))
    return _GENAI_MODEL


# Finish reasons (enum names) meaning Gemini withheld output for policy reasons.
_BLOCKED_FINISH_REASON_NAMES = frozenset(
    {"SAFETY", "BLOCKLIST", "PROHIBITED_CONTENT", "SPII", "IMAGE_SAFETY", "IMAGE_PROHIBITED_CONTENT"}
)
# Finish reasons (enum names) that indicate a normal completion.
_NORMAL_FINISH_REASON_NAMES = frozenset({"FINISH_REASON_UNSPECIFIED", "STOP"})
# Numeric fallbacks for bare ints the installed client may not map to an enum
# member: 0 = FINISH_REASON_UNSPECIFIED, 1 = STOP, 3 = SAFETY. 17 was treated as
# a block by earlier revisions of this module; its enum name is unconfirmed.
_NORMAL_FINISH_REASON_CODES = frozenset({0, 1})
_BLOCKED_FINISH_REASON_CODES = frozenset({3, 17})


def _classify_finish_reason(finish_reason: object) -> str:
    """
    Classify a candidate ``finish_reason`` as ``"normal"``, ``"blocked"`` or ``"other"``.

    Gemini's ``FinishReason`` enum is numbered FINISH_REASON_UNSPECIFIED=0,
    STOP=1, MAX_TOKENS=2, SAFETY=3, RECITATION=4, ... So STOP is not zero, and
    2 is a token-limit stop rather than a safety block. Enum names are
    preferred; plain ints and strings are also accepted.
    """
    if finish_reason is None:
        return "normal"
    if isinstance(finish_reason, str):
        name = finish_reason
    else:
        name = getattr(finish_reason, "name", None)
    if name in _BLOCKED_FINISH_REASON_NAMES or finish_reason in _BLOCKED_FINISH_REASON_CODES:
        return "blocked"
    if name in _NORMAL_FINISH_REASON_NAMES or finish_reason in _NORMAL_FINISH_REASON_CODES:
        return "normal"
    return "other"


def _raise_on_blocked_candidates(candidates) -> None:
    """
    Raise if any candidate was stopped by a safety/policy filter.

    Candidates that stopped for any other abnormal reason (e.g. MAX_TOKENS) are
    only logged; image extraction is still attempted for them.

    Raises:
        RuntimeError: when a candidate's finish_reason is classified as blocked.
    """
    for candidate in candidates:
        finish_reason = getattr(candidate, "finish_reason", None)
        verdict = _classify_finish_reason(finish_reason)
        if verdict == "blocked":
            safety_ratings = getattr(candidate, "safety_ratings", [])
            log.error(
                "Gemini blocked the request. Finish reason: %s, Safety ratings: %s",
                finish_reason,
                safety_ratings,
            )
            raise RuntimeError(
                f"Gemini API blocked the content (finish_reason={finish_reason}). "
                "The image may violate safety policies."
            )
        if verdict == "other":
            log.warning("Gemini finished with unexpected reason: %s", finish_reason)


def _extract_first_image(candidates) -> Image.Image | None:
    """
    Return the first decodable inline image across all candidates, or None.

    Text parts are logged (they often carry a refusal/explanation). A part whose
    payload cannot be decoded is logged and skipped so later parts still get a chance.
    """
    log.debug("Number of candidates: %d", len(candidates))
    for idx, candidate in enumerate(candidates):
        content = getattr(candidate, "content", None)
        if not content:
            log.debug("Candidate %d has no content", idx)
            continue
        parts = getattr(content, "parts", None)
        if not parts:
            log.debug("Candidate %d content has no parts", idx)
            continue
        log.debug("Candidate %d has %d parts", idx, len(parts))

        for part_idx, part in enumerate(parts):
            inline = getattr(part, "inline_data", None)
            if not inline:
                # Text instead of an image is usually an explanation or refusal.
                text = getattr(part, "text", None)
                if text:
                    log.warning("Gemini returned text instead of image in part %d: %s", part_idx, text[:200])
                continue

            log.debug("Part %d has inline_data, mime_type: %s", part_idx, getattr(inline, "mime_type", None))
            data = getattr(inline, "data", None)
            if not data:
                continue
            try:
                # Defensive: some transports deliver the payload base64-encoded.
                if isinstance(data, str):
                    data = base64.b64decode(data)
                image = Image.open(BytesIO(data)).convert("RGB")
            except Exception:
                log.warning(
                    "Failed to decode image in candidate %d part %d", idx, part_idx, exc_info=True
                )
                continue
            log.info("Successfully extracted image from Gemini response")
            return image
    return None


def _log_missing_image(response, candidates) -> None:
    """Log best-effort diagnostics when no image was extracted; never raises."""
    try:
        log.error(
            "Gemini generate_content returned no image. Full response (first 1000 chars): %s",
            str(response)[:1000],
        )
        if hasattr(response, "prompt_feedback"):
            log.error("Prompt feedback: %s", response.prompt_feedback)
        for idx, candidate in enumerate(candidates):
            log.error("Candidate %d finish_reason: %s", idx, getattr(candidate, "finish_reason", None))
    except Exception:
        # Diagnostics must never mask the RuntimeError the caller raises next.
        log.debug("Failed to log Gemini response diagnostics", exc_info=True)


def _call_gemini_edit(
    image_rgb: np.ndarray,
    mask_bw: np.ndarray,
    prompt: str | None,
    target_size: tuple[int, int],
) -> Image.Image:
    """
    Ask Gemini to remove the masked region of ``image_rgb`` and return the result.

    After the text prompt, two PNG images are sent in this order:
      A. a guidance copy of the photo with the removal region painted white;
      B. the binary mask itself (non-zero = remove, zero = keep).
    The unpainted original photo is not sent.

    Args:
        image_rgb: RGB image array of shape (H, W, 3); presumably uint8 — TODO confirm.
        mask_bw: (H, W) mask; every non-zero pixel marks the removal region.
        prompt: Custom instruction text; ``DEFAULT_PROMPT`` is used when None or empty.
        target_size: (width, height) the returned image is resized to if needed.

    Returns:
        The first image found in Gemini's response, as an RGB PIL image of ``target_size``.

    Raises:
        RuntimeError: if the API call fails, no candidates come back, a candidate
            was blocked by a safety/policy filter, or no image could be extracted.
    """
    model = _get_gemini_model()

    # Paint the removal region white so the model also sees it in context.
    guidance_rgb = image_rgb.copy()
    guidance_rgb[mask_bw > 0] = 255
    guidance_bytes = _pil_to_png_bytes(Image.fromarray(guidance_rgb).convert("RGB"))
    mask_bytes = _pil_to_png_bytes(Image.fromarray(mask_bw).convert("L"))

    # Enrich the prompt so it explicitly describes the two images being sent.
    effective_prompt = (
        (prompt or DEFAULT_PROMPT).strip()
        + "\nIMAGE ORDER:\n"
        + "Image A: Original photo with the removal region painted white.\n"
        + "Image B: Binary mask (white=remove, black=keep). Use this mask to decide what to remove.\n"
    )
    log.info(
        "Calling Gemini generate_content model=%s (mask-guided remove) mask_pixels=%d",
        model.model_name,
        int((mask_bw > 0).sum()),
    )

    # Part order must match the IMAGE ORDER text above: prompt, A (guidance), B (mask).
    content = [
        effective_prompt,
        {"mime_type": "image/png", "data": guidance_bytes},
        {"mime_type": "image/png", "data": mask_bytes},
    ]

    # response_mime_type does not support image/png in the google.generativeai
    # package; images come back as inline_data parts instead.
    try:
        response = model.generate_content(content, stream=False)
    except Exception as gen_err:
        log.error("Gemini generate_content raised exception: %s", gen_err, exc_info=True)
        raise RuntimeError(f"Gemini API error: {gen_err}") from gen_err

    candidates = getattr(response, "candidates", [])
    if not candidates:
        log.error("Gemini returned no candidates")
        raise RuntimeError("Gemini API returned no candidates. The request may have been blocked.")

    _raise_on_blocked_candidates(candidates)

    output_img = _extract_first_image(candidates)
    if output_img is None:
        _log_missing_image(response, candidates)
        raise RuntimeError("Gemini generate_content returned no image. Check logs for details.")

    # Gemini may return a different resolution; restore the caller's size.
    if output_img.size != target_size:
        output_img = output_img.resize(target_size, Image.Resampling.LANCZOS)

    return output_img


def process_inpaint(
    image: np.ndarray,
    mask: np.ndarray,
    invert_mask: bool = True,
    prompt: str | None = None,
) -> np.ndarray:
    """
    Remove the masked region of ``image`` via Gemini and return the edited image.

    Args:
        image: Source image array accepted by ``PIL.Image.fromarray``; it is
            converted to RGB, so any alpha channel is dropped.
        mask: Mask array accepted by ``PIL.Image.fromarray``. It is normalized to
            0/255 by ``_binary_mask_from_rgba`` and resized to the image size.
        invert_mask: Forwarded to ``_binary_mask_from_rgba``; note that the mask
            is inverted there when this is ``False``.
        prompt: Optional instruction text; ``DEFAULT_PROMPT`` is used when None or empty.

    Returns:
        RGB ``uint8`` array. ``_call_gemini_edit`` is asked to return an image of
        the source's width and height. When the normalized mask selects no
        pixels, the source RGB array is returned without calling Gemini.

    Raises:
        RuntimeError: propagated from the Gemini call (API error, block, no image).
    """
    # Routing through RGBA mirrors the historical conversion path for all PIL modes.
    image_rgb = np.array(Image.fromarray(image).convert("RGBA").convert("RGB"))

    mask_rgba = np.array(Image.fromarray(mask).convert("RGBA"))
    mask_bw = _binary_mask_from_rgba(mask_rgba, invert_mask)
    mask_bw = _resize_mask(mask_bw, image_rgb.shape[:2])

    # Nothing to remove: skip the remote call, which is costly and may still alter the image.
    if not mask_bw.any():
        log.info("Mask selects no pixels; returning the source image without calling Gemini")
        return image_rgb

    target_size = (image_rgb.shape[1], image_rgb.shape[0])  # PIL sizes are (width, height)
    edited_image = _call_gemini_edit(image_rgb, mask_bw, prompt, target_size)
    return np.array(edited_image)