chris-propeller committed on
Commit
471a44f
·
1 Parent(s): 6f66eec

add vectorize

Browse files
Files changed (1) hide show
  1. app.py +94 -3
app.py CHANGED
@@ -124,16 +124,58 @@ def api_predict(image, text_prompt, confidence_threshold):
124
  except Exception as e:
125
  return {"error": str(e)}
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  @spaces.GPU
128
  def sam2_compatible_api(data):
129
  """
130
  SAM2-compatible API endpoint with SAM3 extensions
131
  Supports text prompts (SAM3), points, and boxes (SAM2 compatible)
 
132
  """
133
  import numpy as np
134
  from PIL import Image
135
  import base64
136
  import io
 
137
 
138
  try:
139
  inputs_data = data.get("inputs", {})
@@ -145,6 +187,8 @@ def sam2_compatible_api(data):
145
  input_labels = inputs_data.get("labels", [])
146
  input_boxes = inputs_data.get("boxes", [])
147
  confidence_threshold = inputs_data.get("confidence_threshold", 0.5)
 
 
148
 
149
  # Validate inputs
150
  if not image_b64:
@@ -165,9 +209,11 @@ def sam2_compatible_api(data):
165
  image_b64 = image_b64.split(',')[1]
166
  image_bytes = base64.b64decode(image_b64)
167
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
168
 
169
  all_masks = []
170
  all_scores = []
 
171
 
172
  # Process text prompts (SAM3 feature)
173
  if has_text:
@@ -188,6 +234,12 @@ def sam2_compatible_api(data):
188
  all_masks.append(mask_b64)
189
  all_scores.append(score)
190
 
 
 
 
 
 
 
191
  # Process visual prompts (SAM2 compatibility) - Basic implementation
192
  if has_boxes or has_points:
193
  # For visual prompts, use a generic prompt to get masks
@@ -211,8 +263,14 @@ def sam2_compatible_api(data):
211
  all_masks.append(mask_b64)
212
  all_scores.append(score)
213
 
 
 
 
 
 
 
214
  # Build SAM2-compatible response
215
- return {
216
  "masks": all_masks,
217
  "scores": all_scores,
218
  "num_objects": len(all_masks),
@@ -220,6 +278,16 @@ def sam2_compatible_api(data):
220
  "success": True
221
  }
222
 
 
 
 
 
 
 
 
 
 
 
223
  except Exception as e:
224
  return {"error": str(e), "success": False, "sam_version": "3.0"}
225
 
@@ -340,6 +408,22 @@ response = requests.post(
340
  }
341
  )
342
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  result = response.json()
344
  </pre>
345
 
@@ -361,7 +445,9 @@ result = response.json()
361
  "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]], // Bounding boxes
362
 
363
  "multimask_output": false, // Optional, defaults to False
364
- "confidence_threshold": 0.5 // Optional, minimum confidence for returned masks
 
 
365
  }
366
  }
367
  </pre>
@@ -373,7 +459,12 @@ result = response.json()
373
  "scores": [0.95, 0.87],
374
  "num_objects": 2,
375
  "sam_version": "3.0",
376
- "success": true
 
 
 
 
 
377
  }
378
  </pre>
379
  """)
 
124
  except Exception as e:
125
  return {"error": str(e)}
126
 
127
def _mask_to_polygons_original_size(binary_mask, epsilon=1.0):
    """
    Convert binary mask to vector polygons (mask is already at original image size)

    Only the outer contour of each connected region is traced
    (cv2.RETR_EXTERNAL), so holes inside a region are not reported.

    Args:
        binary_mask: Mask array at original image size. Any non-zero value is
            treated as foreground, so bool, integer and float masks are accepted.
        epsilon: Polygon simplification epsilon (passed to cv2.approxPolyDP)

    Returns:
        List of polygons, where each polygon is a list of [x, y] points in pixel
        coordinates. Returns an empty list when the mask is None, empty, has no
        foreground pixels, or when extraction fails.
    """
    import numpy as np

    if binary_mask is None:
        return []

    try:
        # cv2.findContours needs a contiguous 8-bit single-channel image;
        # normalising here lets callers pass bool/float/sliced masks safely.
        mask = np.ascontiguousarray(np.asarray(binary_mask) != 0, dtype=np.uint8)
    except Exception as e:
        print(f"Warning: Polygon extraction failed: {e}")
        return []

    # Nothing to trace: skip the OpenCV call entirely.
    if not mask.any():
        return []

    import cv2

    try:
        # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
        # (contours, hierarchy); the contours are second-to-last in both.
        found = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = found[-2]

        polygons = []
        for contour in contours:
            if len(contour) < 3:  # Fewer than 3 points cannot form a polygon
                continue

            # Simplify polygon using Douglas-Peucker algorithm
            simplified = cv2.approxPolyDP(contour, epsilon, True)

            # approxPolyDP yields an (N, 1, 2) array; flatten to [x, y] pairs
            polygon_points = [
                [float(x), float(y)] for x, y in simplified.reshape(-1, 2)
            ]

            # Simplification may collapse a contour below 3 points; drop those
            if len(polygon_points) >= 3:
                polygons.append(polygon_points)

        return polygons

    except Exception as e:
        # Return empty list on error, but don't fail the entire request
        print(f"Warning: Polygon extraction failed: {e}")
        return []
166
+
167
  @spaces.GPU
168
  def sam2_compatible_api(data):
169
  """
170
  SAM2-compatible API endpoint with SAM3 extensions
171
  Supports text prompts (SAM3), points, and boxes (SAM2 compatible)
172
+ Includes vectorize option for polygon extraction
173
  """
174
  import numpy as np
175
  from PIL import Image
176
  import base64
177
  import io
178
+ import cv2
179
 
180
  try:
181
  inputs_data = data.get("inputs", {})
 
187
  input_labels = inputs_data.get("labels", [])
188
  input_boxes = inputs_data.get("boxes", [])
189
  confidence_threshold = inputs_data.get("confidence_threshold", 0.5)
190
+ vectorize = inputs_data.get("vectorize", False)
191
+ simplify_epsilon = inputs_data.get("simplify_epsilon", 1.0)
192
 
193
  # Validate inputs
194
  if not image_b64:
 
209
  image_b64 = image_b64.split(',')[1]
210
  image_bytes = base64.b64decode(image_b64)
211
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
212
+ original_image_size = image.size # Store for response metadata
213
 
214
  all_masks = []
215
  all_scores = []
216
+ all_polygons = []
217
 
218
  # Process text prompts (SAM3 feature)
219
  if has_text:
 
234
  all_masks.append(mask_b64)
235
  all_scores.append(score)
236
 
237
+ # Extract polygons if vectorize is enabled
238
+ if vectorize:
239
+ binary_mask = (mask_np > 0).astype(np.uint8)
240
+ polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
241
+ all_polygons.append(polygons)
242
+
243
  # Process visual prompts (SAM2 compatibility) - Basic implementation
244
  if has_boxes or has_points:
245
  # For visual prompts, use a generic prompt to get masks
 
263
  all_masks.append(mask_b64)
264
  all_scores.append(score)
265
 
266
+ # Extract polygons if vectorize is enabled
267
+ if vectorize:
268
+ binary_mask = (mask_np > 0).astype(np.uint8)
269
+ polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
270
+ all_polygons.append(polygons)
271
+
272
  # Build SAM2-compatible response
273
+ response = {
274
  "masks": all_masks,
275
  "scores": all_scores,
276
  "num_objects": len(all_masks),
 
278
  "success": True
279
  }
280
 
281
+ # Add polygon data if vectorize is enabled
282
+ if vectorize:
283
+ response.update({
284
+ "polygons": all_polygons,
285
+ "polygon_format": "pixel_coordinates",
286
+ "original_image_size": original_image_size
287
+ })
288
+
289
+ return response
290
+
291
  except Exception as e:
292
  return {"error": str(e), "success": False, "sam_version": "3.0"}
293
 
 
408
  }
409
  )
410
 
411
+ # SAM3 with Vectorize (returns both masks and polygons)
412
+ response = requests.post(
413
+ "https://your-username-sam3-api.hf.space/api/sam2_compatible",
414
+ json={
415
+ "data": {
416
+ "inputs": {
417
+ "image": image_b64,
418
+ "text_prompts": ["cat"],
419
+ "confidence_threshold": 0.5,
420
+ "vectorize": True,
421
+ "simplify_epsilon": 1.0
422
+ }
423
+ }
424
+ }
425
+ )
426
+
427
  result = response.json()
428
  </pre>
429
 
 
445
  "boxes": [[x1, y1, x2, y2], [x1, y1, x2, y2]], // Bounding boxes
446
 
447
  "multimask_output": false, // Optional, defaults to False
448
+ "confidence_threshold": 0.5, // Optional, minimum confidence for returned masks
449
+ "vectorize": false, // Optional, also return vector polygons in addition to the mask bitmaps
450
+ "simplify_epsilon": 1.0 // Optional, polygon simplification factor
451
  }
452
  }
453
  </pre>
 
459
  "scores": [0.95, 0.87],
460
  "num_objects": 2,
461
  "sam_version": "3.0",
462
+ "success": true,
463
+
464
+ // If vectorize=true, additional fields:
465
+ "polygons": [[[x1,y1],[x2,y2],...], [[x1,y1],...]], // Array of polygon arrays for each object
466
+ "polygon_format": "pixel_coordinates",
467
+ "original_image_size": [width, height]
468
  }
469
  </pre>
470
  """)