LiuZichen committed on
Commit 47a0ec7 · verified · 1 Parent(s): 4129e4e

Upload 3 files

Files changed (3)
  1. app.py +313 -0
  2. requirements.txt +10 -0
  3. util.py +207 -0
app.py ADDED
@@ -0,0 +1,313 @@
+ import subprocess
+ import shlex
+ # Install the custom component if needed
+ subprocess.run(
+     shlex.split(
+         "pip install ./gradio_magicquillv2-0.0.1-py3-none-any.whl"
+     )
+ )
+ import sys
+ import os
+ import gradio as gr
+ import tempfile
+ import numpy as np
+ import io
+ import base64
+ import json
+ import uvicorn
+ import torch
+ from fastapi import FastAPI, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from gradio_client import Client, handle_file
+ from gradio_magicquillv2 import MagicQuillV2
+ from PIL import Image
+
+
+ from util import (
+     read_base64_image as read_base64_image_utils,
+     tensor_to_base64,
+     get_mask_bbox
+ )
+
+ # --- Configuration ---
+ # Set this to your backend Space (running app_backend.py), as a repo id or full URL
+ # Example: "https://huggingface.co/spaces/username/backend-space"
+ hf_token = os.environ.get("HF_TOKEN")
+ BACKEND_URL = "LiuZichen/MagicQuillV2"
+ SAM_URL = "LiuZichen/MagicQuillHelper"
+
+ print(f"Connecting to backend at: {BACKEND_URL}")
+
+ backend_client = Client(BACKEND_URL, token=hf_token)
+
+ print(f"Connecting to SAM client at: {SAM_URL}")
+ sam_client = Client(SAM_URL, token=hf_token)
+
+ # --- Helper Functions ---
+
+ def generate_image_handler(x, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg):
+     merged_image = x['from_frontend']['img']
+     total_mask = x['from_frontend']['total_mask']
+     original_image = x['from_frontend']['original_image']
+     add_color_image = x['from_frontend']['add_color_image']
+     add_edge_mask = x['from_frontend']['add_edge_mask']
+     remove_edge_mask = x['from_frontend']['remove_edge_mask']
+     fill_mask = x['from_frontend']['fill_mask']
+     add_prop_image = x['from_frontend']['add_prop_image']
+     positive_prompt = x['from_backend']['prompt']
+
+     if backend_client is None:
+         print("Backend client not initialized")
+         x["from_backend"]["generated_image"] = None
+         return x
+
+     try:
+         # Call the backend API
+         # The order of arguments must match the app_backend.py input list
+         res_base64 = backend_client.predict(
+             merged_image,       # merged_image
+             total_mask,         # total_mask
+             original_image,     # original_image
+             add_color_image,    # add_color_image
+             add_edge_mask,      # add_edge_mask
+             remove_edge_mask,   # remove_edge_mask
+             fill_mask,          # fill_mask
+             add_prop_image,     # add_prop_image
+             positive_prompt,    # positive_prompt
+             negative_prompt,    # negative_prompt
+             fine_edge,          # fine_edge
+             fix_perspective,    # fix_perspective
+             grow_size,          # grow_size
+             edge_strength,      # edge_strength
+             color_strength,     # color_strength
+             local_strength,     # local_strength
+             seed,               # seed
+             steps,              # steps
+             cfg,                # cfg
+             api_name="/generate"
+         )
+         x["from_backend"]["generated_image"] = res_base64
+     except Exception as e:
+         print(f"Error in generation: {e}")
+         x["from_backend"]["generated_image"] = None
+
+     return x
+
+ # --- Gradio UI ---
+
+ with gr.Blocks(title="MagicQuill V2") as demo:
+     with gr.Row():
+         ms = MagicQuillV2()
+
+     with gr.Row():
+         with gr.Column():
+             btn = gr.Button("Run", variant="primary")
+         with gr.Column():
+             with gr.Accordion("parameters", open=False):
+                 negative_prompt = gr.Textbox(label="Negative Prompt", value="", interactive=True)
+                 fine_edge = gr.Radio(label="Fine Edge", choices=['enable', 'disable'], value='disable', interactive=True)
+                 fix_perspective = gr.Radio(label="Fix Perspective", choices=['enable', 'disable'], value='disable', interactive=True)
+                 grow_size = gr.Slider(label="Grow Size", minimum=10, maximum=100, value=50, step=1, interactive=True)
+                 edge_strength = gr.Slider(label="Edge Strength", minimum=0.0, maximum=5.0, value=0.6, step=0.01, interactive=True)
+                 color_strength = gr.Slider(label="Color Strength", minimum=0.0, maximum=5.0, value=1.5, step=0.01, interactive=True)
+                 local_strength = gr.Slider(label="Local Strength", minimum=0.0, maximum=5.0, value=1.0, step=0.01, interactive=True)
+                 seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
+                 steps = gr.Slider(label="Steps", minimum=0, maximum=50, value=20, interactive=True)
+                 cfg = gr.Slider(label="CFG", minimum=0.0, maximum=20.0, value=3.5, step=0.1, interactive=True)
+
+     btn.click(
+         generate_image_handler,
+         inputs=[ms, negative_prompt, fine_edge, fix_perspective, grow_size, edge_strength, color_strength, local_strength, seed, steps, cfg],
+         outputs=ms
+     )
+
+ # --- FastAPI App ---
+
+ app = FastAPI()
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=['*'],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Helper to fix root path if running behind proxy (Spaces)
+ def get_root_url(request: Request, route_path: str, root_path: str | None):
+     return root_path
+
+ import gradio.route_utils
+ gr.route_utils.get_root_url = get_root_url
+
+ # Mount the Gradio app
+ gr.mount_gradio_app(app, demo, path="/demo", root_path="/demo")
+
+ @app.post("/magic_quill/generate_image")
+ async def generate_image(request: Request):
+     data = await request.json()
+
+     if backend_client is None:
+         return {'error': 'Backend client not connected'}
+
+     try:
+         res = backend_client.predict(
+             data["merged_image"],
+             data["total_mask"],
+             data["original_image"],
+             data["add_color_image"],
+             data["add_edge_mask"],
+             data["remove_edge_mask"],
+             data["fill_mask"],
+             data["add_prop_image"],
+             data["positive_prompt"],
+             data["negative_prompt"],
+             data["fine_edge"],
+             data["fix_perspective"],
+             data["grow_size"],
+             data["edge_strength"],
+             data["color_strength"],
+             data["local_strength"],
+             data["seed"],
+             data["steps"],
+             data["cfg"],
+             api_name="/generate"
+         )
+         return {'res': res}
+     except Exception as e:
+         print(f"Error in backend generation: {e}")
+         return {'error': str(e)}
+
+ @app.post("/magic_quill/process_background_img")
+ async def process_background_img(request: Request):
+     img = await request.json()
+     from util import process_background
+     # process_background returns tensor [1, H, W, 3] in uint8 or float
+     resized_img_tensor = process_background(img)
+
+     # tensor_to_base64 from util expects a tensor
+     resized_img_base64 = "data:image/webp;base64," + tensor_to_base64(
+         resized_img_tensor,
+         quality=80,
+         method=6
+     )
+     return resized_img_base64
+
+ @app.post("/magic_quill/segmentation")
+ async def segmentation(request: Request):
+     json_data = await request.json()
+     image_base64 = json_data.get("image", None)
+     coordinates_positive = json_data.get("coordinates_positive", None)
+     coordinates_negative = json_data.get("coordinates_negative", None)
+     bboxes = json_data.get("bboxes", None)
+
+     if sam_client is None:
+         return {"error": "sam client not initialized"}
+
+     # Process coordinates and bboxes (copied from original app.py)
+     pos_coordinates = None
+     if coordinates_positive and len(coordinates_positive) > 0:
+         pos_coordinates = []
+         for coord in coordinates_positive:
+             coord['x'] = int(round(coord['x']))
+             coord['y'] = int(round(coord['y']))
+             pos_coordinates.append({'x': coord['x'], 'y': coord['y']})
+         pos_coordinates = json.dumps(pos_coordinates)
+
+     neg_coordinates = None
+     if coordinates_negative and len(coordinates_negative) > 0:
+         neg_coordinates = []
+         for coord in coordinates_negative:
+             coord['x'] = int(round(coord['x']))
+             coord['y'] = int(round(coord['y']))
+             neg_coordinates.append({'x': coord['x'], 'y': coord['y']})
+         neg_coordinates = json.dumps(neg_coordinates)
+
+     bboxes_xyxy = None
+     if bboxes and len(bboxes) > 0:
+         bboxes_xyxy = []
+         for bbox in bboxes:
+             # Skip incomplete boxes
+             if (bbox.get("startX") is None or
+                 bbox.get("startY") is None or
+                 bbox.get("endX") is None or
+                 bbox.get("endY") is None):
+                 continue
+             x_min = max(min(int(bbox["startX"]), int(bbox["endX"])), 0)
+             y_min = max(min(int(bbox["startY"]), int(bbox["endY"])), 0)
+             x_max = max(int(bbox["startX"]), int(bbox["endX"]))
+             y_max = max(int(bbox["startY"]), int(bbox["endY"]))
+             bboxes_xyxy.append((x_min, y_min, x_max, y_max))
+
+         if bboxes_xyxy:
+             bboxes_xyxy = json.dumps(bboxes_xyxy)
+
+     print(f"Segmentation request: pos={pos_coordinates}, neg={neg_coordinates}, bboxes={bboxes_xyxy}")
+
+     try:
+         # Save base64 image to temp file
+         image_bytes = read_base64_image_utils(image_base64)
+         pil_image = Image.open(image_bytes)
+         with tempfile.NamedTemporaryFile(suffix=".webp", delete=False) as temp_in:
+             pil_image.save(temp_in.name, format="WEBP", quality=80)
+             temp_in_path = temp_in.name
+
+         # Execute segmentation via Client
+         result_path = sam_client.predict(
+             handle_file(temp_in_path),
+             pos_coordinates,
+             neg_coordinates,
+             bboxes_xyxy,
+             api_name="/segment"
+         )
+
+         os.unlink(temp_in_path)
+
+         if isinstance(result_path, (list, tuple)):
+             result_path = result_path[0]
+
+         if not result_path or not os.path.exists(result_path):
+             raise RuntimeError("Client returned invalid result path")
+
+         mask_pil = Image.open(result_path)
+         if mask_pil.mode != 'L':
+             mask_pil = mask_pil.convert('L')
+
+         pil_image = pil_image.convert("RGB")
+         if pil_image.size != mask_pil.size:
+             mask_pil = mask_pil.resize(pil_image.size, Image.NEAREST)
+
+         r, g, b = pil_image.split()
+         res_pil = Image.merge("RGBA", (r, g, b, mask_pil))
+
+         mask_tensor = torch.from_numpy(np.array(mask_pil) / 255.0).float().unsqueeze(0)
+         mask_bbox = get_mask_bbox(mask_tensor)
+         if mask_bbox:
+             x_min, y_min, x_max, y_max = mask_bbox
+             seg_bbox = {'startX': x_min, 'startY': y_min, 'endX': x_max, 'endY': y_max}
+         else:
+             seg_bbox = {'startX': 0, 'startY': 0, 'endX': 0, 'endY': 0}
+
+         print(seg_bbox)
+
+         buffered = io.BytesIO()
+         res_pil.save(buffered, format="PNG")
+         image_base64_res = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+         return {
+             "error": False,
+             "segmentation_image": "data:image/png;base64," + image_base64_res,
+             "segmentation_bbox": seg_bbox
+         }
+
+     except Exception as e:
+         print(f"Error in segmentation: {e}")
+         return {"error": str(e)}
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
+
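For reference, once app.py is running (locally it serves on port 7860, per the `__main__` block above), the FastAPI routes can be exercised with plain `requests` from requirements.txt. The sketch below is a minimal example under stated assumptions: the host/port and the sample file "input.png" are placeholders, not part of the repository.

# Minimal sketch: POST a base64 data URL to /magic_quill/process_background_img.
# Assumes app.py is running locally on port 7860 and "input.png" exists (both assumptions).
import base64
import requests

with open("input.png", "rb") as f:
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

# The endpoint reads its JSON body as the data-URL string itself.
resp = requests.post(
    "http://localhost:7860/magic_quill/process_background_img",
    json=data_url,
)
resp.raise_for_status()
resized_data_url = resp.json()  # "data:image/webp;base64,..." at a preferred resolution
print(resized_data_url[:64])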
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ fastapi
+ uvicorn
+ gradio==5.4.0
+ gradio_client
+ numpy
+ opencv-python
+ pillow
+ requests
+ torch
+ torchvision
util.py ADDED
@@ -0,0 +1,207 @@
+ import random
+ from collections import Counter
+ import numpy as np
+ from torchvision import transforms
+ import cv2  # OpenCV
+ import torch
+ import re
+ import io
+ import base64
+ from PIL import Image, ImageOps
+
+ PREFERRED_KONTEXT_RESOLUTIONS = [
+     (672, 1568),
+     (688, 1504),
+     (720, 1456),
+     (752, 1392),
+     (800, 1328),
+     (832, 1248),
+     (880, 1184),
+     (944, 1104),
+     (1024, 1024),
+     (1104, 944),
+     (1184, 880),
+     (1248, 832),
+     (1328, 800),
+     (1392, 752),
+     (1456, 720),
+     (1504, 688),
+     (1568, 672),
+ ]
+
+ def get_bounding_box_from_mask(mask, padded=False):
+     mask = mask.squeeze()
+     rows, cols = torch.where(mask > 0.5)
+     if len(rows) == 0 or len(cols) == 0:
+         return (0, 0, 0, 0)
+     height, width = mask.shape
+     if padded:
+         padded_size = max(width, height)
+         if width < height:
+             offset_x = (padded_size - width) / 2
+             offset_y = 0
+         else:
+             offset_y = (padded_size - height) / 2
+             offset_x = 0
+         top_left_x = round(float((torch.min(cols).item() + offset_x) / padded_size), 3)
+         bottom_right_x = round(float((torch.max(cols).item() + offset_x) / padded_size), 3)
+         top_left_y = round(float((torch.min(rows).item() + offset_y) / padded_size), 3)
+         bottom_right_y = round(float((torch.max(rows).item() + offset_y) / padded_size), 3)
+     else:
+         offset_x = 0
+         offset_y = 0
+
+         top_left_x = round(float(torch.min(cols).item() / width), 3)
+         bottom_right_x = round(float(torch.max(cols).item() / width), 3)
+         top_left_y = round(float(torch.min(rows).item() / height), 3)
+         bottom_right_y = round(float(torch.max(rows).item() / height), 3)
+
+     return (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
+
+ def extract_bbox(text):
+     pattern = r"\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]"
+     match = re.search(pattern, text)
+     if match is None:
+         return None
+     return (int(match.group(1)), int(match.group(2)), int(match.group(3)), int(match.group(4)))
+
+ def resize_bbox(bbox, width_ratio, height_ratio):
+     x1, y1, x2, y2 = bbox
+     new_x1 = int(x1 * width_ratio)
+     new_y1 = int(y1 * height_ratio)
+     new_x2 = int(x2 * width_ratio)
+     new_y2 = int(y2 * height_ratio)
+
+     return (new_x1, new_y1, new_x2, new_y2)
+
+
+ def tensor_to_base64(tensor, quality=80, method=6):
+     tensor = tensor.squeeze(0).clone().detach().cpu()
+
+     if tensor.dtype == torch.float32 or tensor.dtype == torch.float64 or tensor.dtype == torch.float16:
+         tensor *= 255
+         tensor = tensor.to(torch.uint8)
+
+     if tensor.ndim == 2:  # grayscale image
+         pil_image = Image.fromarray(tensor.numpy(), 'L')
+         pil_image = pil_image.convert('RGB')
+     elif tensor.ndim == 3:
+         if tensor.shape[2] == 1:  # single channel
+             pil_image = Image.fromarray(tensor.numpy().squeeze(2), 'L')
+             pil_image = pil_image.convert('RGB')
+         elif tensor.shape[2] == 3:  # RGB
+             pil_image = Image.fromarray(tensor.numpy(), 'RGB')
+         elif tensor.shape[2] == 4:  # RGBA
+             pil_image = Image.fromarray(tensor.numpy(), 'RGBA')
+         else:
+             raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
+     else:
+         raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
+
+     buffered = io.BytesIO()
+     pil_image.save(buffered, format="WEBP", quality=quality, method=method, lossless=False)
+     img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+     return img_str
+
+ def load_and_preprocess_image(image_path, convert_to='RGB', has_alpha=False):
+     image = Image.open(image_path)
+     image = ImageOps.exif_transpose(image)
+
+     if image.mode == 'RGBA':
+         background = Image.new('RGBA', image.size, (255, 255, 255, 255))
+         image = Image.alpha_composite(background, image)
+     image = image.convert(convert_to)
+     image_array = np.array(image).astype(np.float32) / 255.0
+
+     if has_alpha and convert_to == 'RGBA':
+         image_tensor = torch.from_numpy(image_array)[None,]
+     else:
+         if len(image_array.shape) == 3 and image_array.shape[2] > 3:
+             image_array = image_array[:, :, :3]
+         image_tensor = torch.from_numpy(image_array)[None,]
+
+     return image_tensor
+
+ def process_background(base64_image, convert_to='RGB', size=None):
+     image_data = read_base64_image(base64_image)
+     image = Image.open(image_data)
+     image = ImageOps.exif_transpose(image)
+     image = image.convert(convert_to)
+
+     # Select preferred size by closest aspect ratio, then snap to multiple_of
+     w0, h0 = image.size
+     aspect_ratio = (w0 / h0) if h0 != 0 else 1.0
+     # Choose the (w, h) whose aspect ratio is closest to the input
+     _, tw, th = min((abs(aspect_ratio - w / h), w, h) for (w, h) in PREFERRED_KONTEXT_RESOLUTIONS)
+     multiple_of = 16  # default: vae_scale_factor (8) * 2
+     tw = (tw // multiple_of) * multiple_of
+     th = (th // multiple_of) * multiple_of
+
+     if (w0, h0) != (tw, th):
+         image = image.resize((tw, th), resample=Image.BICUBIC)
+
+     image_array = np.array(image).astype(np.uint8)
+     image_tensor = torch.from_numpy(image_array)[None,]
+     return image_tensor
+
+ def read_base64_image(base64_image):
+     if base64_image.startswith("data:image/png;base64,"):
+         base64_image = base64_image.split(",")[1]
+     elif base64_image.startswith("data:image/jpeg;base64,"):
+         base64_image = base64_image.split(",")[1]
+     elif base64_image.startswith("data:image/webp;base64,"):
+         base64_image = base64_image.split(",")[1]
+     else:
+         raise ValueError("Unsupported image format.")
+     image_data = base64.b64decode(base64_image)
+     return io.BytesIO(image_data)
+
+ def create_alpha_mask(image_path):
+     """Create an alpha mask from the alpha channel of an image."""
+     image = Image.open(image_path)
+     image = ImageOps.exif_transpose(image)
+     mask = torch.zeros((1, image.height, image.width), dtype=torch.float32)
+     if 'A' in image.getbands():
+         alpha_channel = np.array(image.getchannel('A')).astype(np.float32) / 255.0
+         mask[0] = 1.0 - torch.from_numpy(alpha_channel)
+     return mask
+
+ def get_mask_bbox(mask_tensor, padding=10):
+     assert len(mask_tensor.shape) == 3 and mask_tensor.shape[0] == 1
+     _, H, W = mask_tensor.shape
+     mask_2d = mask_tensor.squeeze(0)
+
+     y_coords, x_coords = torch.where(mask_2d > 0)
+
+     if len(y_coords) == 0:
+         return None
+
+     x_min = int(torch.min(x_coords))
+     y_min = int(torch.min(y_coords))
+     x_max = int(torch.max(x_coords))
+     y_max = int(torch.max(y_coords))
+
+     x_min = max(0, x_min - padding)
+     y_min = max(0, y_min - padding)
+     x_max = min(W - 1, x_max + padding)
+     y_max = min(H - 1, y_max + padding)
+
+     return x_min, y_min, x_max, y_max
+
+ def tensor_to_pil(tensor):
+     tensor = tensor.squeeze(0).clone().detach().cpu()
+     if tensor.dtype in [torch.float32, torch.float64, torch.float16]:
+         if tensor.max() <= 1.0:
+             tensor *= 255
+         tensor = tensor.to(torch.uint8)
+
+     if tensor.ndim == 2:  # grayscale image [H, W]
+         return Image.fromarray(tensor.numpy(), 'L')
+     elif tensor.ndim == 3:
+         if tensor.shape[2] == 1:  # single channel [H, W, 1]
+             return Image.fromarray(tensor.numpy().squeeze(2), 'L')
+         elif tensor.shape[2] >= 3:  # RGB [H, W, 3]; drop any extra channels such as alpha
+             return Image.fromarray(tensor.numpy()[:, :, :3], 'RGB')
+         else:
+             raise ValueError(f"Unsupported number of channels: {tensor.shape[2]}")
+     else:
+         raise ValueError(f"Unsupported tensor dimensions: {tensor.ndim}")
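A small usage sketch for the helpers in util.py. Assumptions not in the repository: util.py is importable from the working directory and "photo.png" is a placeholder test image.

# Usage sketch for util.py (assumes "photo.png" exists; file name is a placeholder).
import base64
import torch
from util import process_background, tensor_to_base64, get_mask_bbox

# Encode a local file as the data-URL format read_base64_image() expects.
with open("photo.png", "rb") as f:
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

# Resize to the closest preferred Kontext resolution; returns a [1, H, W, 3] uint8 tensor.
img_tensor = process_background(data_url)
print(img_tensor.shape, img_tensor.dtype)

# Re-encode the tensor as a base64 WebP string (no data-URL prefix).
webp_b64 = tensor_to_base64(img_tensor, quality=80, method=6)

# get_mask_bbox expects a [1, H, W] float mask; here a dummy rectangular mask.
mask = torch.zeros((1, 512, 512))
mask[0, 100:200, 150:300] = 1.0
print(get_mask_bbox(mask, padding=10))  # (140, 90, 309, 209)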