chris-propeller committed on
Commit
a45e44e
·
1 Parent(s): acd640e

attempt to fix multiple text prompts

Browse files
Files changed (1) hide show
  1. app.py +82 -58
app.py CHANGED
@@ -279,61 +279,74 @@ def sam2_compatible_api(data):
279
  if has_points or has_boxes:
280
  prompt_types.append("visual")
281
 
282
- # Prepare inputs for combined SAM3 inference call
283
- combined_text_prompt = None
284
- combined_boxes = None
285
- combined_box_labels = None
286
- combined_points = None
287
- combined_point_labels = None
288
-
289
- # Handle text prompts - combine multiple text prompts into one
290
  if has_text:
291
- # For multiple text prompts, join them (SAM3 can handle combined descriptions)
292
- combined_text_prompt = ", ".join(text_prompts)
293
-
294
- # Handle box prompts
295
- if has_boxes:
296
- combined_boxes = input_boxes
297
- # Create box labels (default to positive boxes if not provided)
298
- combined_box_labels = inputs_data.get("box_labels", [1] * len(input_boxes))
299
-
300
- # Handle point prompts
301
- if has_points:
302
- combined_points = input_points
303
- combined_point_labels = input_labels
304
-
305
- # Make single combined inference call with all prompt types
306
- results = sam3_inference(
307
- image=image,
308
- text_prompt=combined_text_prompt,
309
- boxes=combined_boxes,
310
- box_labels=combined_box_labels,
311
- points=combined_points,
312
- point_labels=combined_point_labels,
313
- confidence_threshold=confidence_threshold
314
- )
315
-
316
- # Process results
317
- if results and len(results["masks"]) > 0:
318
- for i in range(len(results["masks"])):
319
- mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
320
- score = results["scores"][i].item()
321
-
322
- if score >= confidence_threshold:
323
- # Convert mask to base64
324
- mask_image = Image.fromarray(mask_np, mode='L')
325
- buffer = io.BytesIO()
326
- mask_image.save(buffer, format='PNG')
327
- mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
328
-
329
- all_masks.append(mask_b64)
330
- all_scores.append(score)
331
-
332
- # Extract polygons if vectorize is enabled
333
- if vectorize:
334
- binary_mask = (mask_np > 0).astype(np.uint8)
335
- polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
336
- all_polygons.append(polygons)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
  # Build SAM2-compatible response
339
  response = {
@@ -472,13 +485,24 @@ response = requests.post(
472
  }
473
  )
474
 
475
- # SAM3 Combined Prompts (Text + Visual) - NEW CAPABILITY!
 
 
 
 
 
 
 
 
 
 
 
 
476
  response = requests.post(
477
  "https://your-username-sam3-api.hf.space/api/sam2_compatible",
478
  json={
479
  "inputs": {
480
  "image": image_b64,
481
- "text_prompts": ["cat"], # Text description
482
  "boxes": [[50, 50, 150, 150]], # Bounding box
483
  "box_labels": [0], # 0=negative (exclude this area)
484
  "points": [[200, 200]], # Point prompt
@@ -512,8 +536,8 @@ result = response.json()
512
  "inputs": {
513
  "image": "base64_encoded_image_string",
514
 
515
- // SAM3 NEW: Text-based prompts (can be combined with visual prompts)
516
- "text_prompts": ["person", "car"], // List of text descriptions
517
 
518
  // SAM2 COMPATIBLE: Point-based prompts (can be combined with text/boxes)
519
  "points": [[x1, y1], [x2, y2]], // Individual points (not nested arrays)
 
279
  if has_points or has_boxes:
280
  prompt_types.append("visual")
281
 
282
+ # Process text prompts individually (SAM3 works best with individual text prompts)
 
 
 
 
 
 
 
283
  if has_text:
284
+ for text_prompt in text_prompts:
285
+ if text_prompt.strip(): # Skip empty prompts
286
+ results = sam3_inference(
287
+ image=image,
288
+ text_prompt=text_prompt.strip(),
289
+ confidence_threshold=confidence_threshold
290
+ )
291
+
292
+ if results and len(results["masks"]) > 0:
293
+ for i in range(len(results["masks"])):
294
+ mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
295
+ score = results["scores"][i].item()
296
+
297
+ if score >= confidence_threshold:
298
+ # Convert mask to base64
299
+ mask_image = Image.fromarray(mask_np, mode='L')
300
+ buffer = io.BytesIO()
301
+ mask_image.save(buffer, format='PNG')
302
+ mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
303
+
304
+ all_masks.append(mask_b64)
305
+ all_scores.append(score)
306
+
307
+ # Extract polygons if vectorize is enabled
308
+ if vectorize:
309
+ binary_mask = (mask_np > 0).astype(np.uint8)
310
+ polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
311
+ all_polygons.append(polygons)
312
+
313
+ # Process visual prompts (boxes and/or points) - can be combined in a single call
314
+ if has_boxes or has_points:
315
+ combined_boxes = input_boxes if has_boxes else None
316
+ combined_box_labels = inputs_data.get("box_labels", [1] * len(input_boxes)) if has_boxes else None
317
+ combined_points = input_points if has_points else None
318
+ combined_point_labels = input_labels if has_points else None
319
+
320
+ results = sam3_inference(
321
+ image=image,
322
+ text_prompt=None,
323
+ boxes=combined_boxes,
324
+ box_labels=combined_box_labels,
325
+ points=combined_points,
326
+ point_labels=combined_point_labels,
327
+ confidence_threshold=confidence_threshold
328
+ )
329
+
330
+ if results and len(results["masks"]) > 0:
331
+ for i in range(len(results["masks"])):
332
+ mask_np = results["masks"][i].cpu().numpy().astype(np.uint8) * 255
333
+ score = results["scores"][i].item()
334
+
335
+ if score >= confidence_threshold:
336
+ # Convert mask to base64
337
+ mask_image = Image.fromarray(mask_np, mode='L')
338
+ buffer = io.BytesIO()
339
+ mask_image.save(buffer, format='PNG')
340
+ mask_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
341
+
342
+ all_masks.append(mask_b64)
343
+ all_scores.append(score)
344
+
345
+ # Extract polygons if vectorize is enabled
346
+ if vectorize:
347
+ binary_mask = (mask_np > 0).astype(np.uint8)
348
+ polygons = _mask_to_polygons_original_size(binary_mask, simplify_epsilon)
349
+ all_polygons.append(polygons)
350
 
351
  # Build SAM2-compatible response
352
  response = {
 
485
  }
486
  )
487
 
488
+ # SAM3 with Multiple Text Prompts (processed individually)
489
+ response = requests.post(
490
+ "https://your-username-sam3-api.hf.space/api/sam2_compatible",
491
+ json={
492
+ "inputs": {
493
+ "image": image_b64,
494
+ "text_prompts": ["cat", "dog"], # Each prompt processed separately
495
+ "confidence_threshold": 0.5
496
+ }
497
+ }
498
+ )
499
+
500
+ # SAM3 Combined Visual Prompts (boxes + points in single call)
501
  response = requests.post(
502
  "https://your-username-sam3-api.hf.space/api/sam2_compatible",
503
  json={
504
  "inputs": {
505
  "image": image_b64,
 
506
  "boxes": [[50, 50, 150, 150]], # Bounding box
507
  "box_labels": [0], # 0=negative (exclude this area)
508
  "points": [[200, 200]], # Point prompt
 
536
  "inputs": {
537
  "image": "base64_encoded_image_string",
538
 
539
+ // SAM3 NEW: Text-based prompts (each processed individually for best results)
540
+ "text_prompts": ["person", "car"], // List of text descriptions - each processed separately
541
 
542
  // SAM2 COMPATIBLE: Point-based prompts (can be combined with text/boxes)
543
  "points": [[x1, y1], [x2, y2]], // Individual points (not nested arrays)