chris-propeller commited on
Commit
4f603ce
·
1 Parent(s): a95034f
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -11,20 +11,33 @@ from transformers import Sam3Model, Sam3Processor
11
  import warnings
12
  warnings.filterwarnings("ignore")
13
 
14
- # Initialize model and processor (matching working space exactly)
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- model = Sam3Model.from_pretrained(
17
- "facebook/sam3",
18
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
19
- ).to(device)
20
- processor = Sam3Processor.from_pretrained("facebook/sam3")
 
 
 
 
 
 
 
 
 
 
21
 
22
  @spaces.GPU
23
  def sam3_inference(image, text_prompt, confidence_threshold=0.5):
24
  """
25
- Standalone GPU function matching working space pattern
26
  """
27
  try:
 
 
 
28
  # Handle base64 input (for API)
29
  if isinstance(image, str):
30
  if image.startswith('data:image'):
@@ -32,14 +45,14 @@ def sam3_inference(image, text_prompt, confidence_threshold=0.5):
32
  image_bytes = base64.b64decode(image)
33
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
34
 
35
- # Process with SAM3 (matching working space exactly)
36
  inputs = processor(
37
  images=image,
38
  text=text_prompt.strip(),
39
  return_tensors="pt"
40
  ).to(device)
41
 
42
- # Convert dtype to match model (following working space pattern)
43
  for key in inputs:
44
  if inputs[key].dtype == torch.float32:
45
  inputs[key] = inputs[key].to(model.dtype)
@@ -64,7 +77,7 @@ class SAM3Handler:
64
  """SAM3 handler for both UI and API access"""
65
 
66
  def __init__(self):
67
- print(f"SAM3 handler initialized with device: {device}")
68
 
69
  def predict(self, image, text_prompt, confidence_threshold=0.5):
70
  """
 
11
  import warnings
12
  warnings.filterwarnings("ignore")
13
 
14
+ # Global variables for lazy initialization
15
+ _model = None
16
+ _processor = None
17
+ _device = None
18
+
19
+ def get_model_and_processor():
20
+ """Lazy initialization of model and processor"""
21
+ global _model, _processor, _device
22
+ if _model is None:
23
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ _model = Sam3Model.from_pretrained(
25
+ "facebook/sam3",
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
27
+ ).to(_device)
28
+ _processor = Sam3Processor.from_pretrained("facebook/sam3")
29
+ print(f"Model loaded on device: {_device}")
30
+ return _model, _processor, _device
31
 
32
  @spaces.GPU
33
  def sam3_inference(image, text_prompt, confidence_threshold=0.5):
34
  """
35
+ Standalone GPU function with lazy model initialization for Spaces Stateless GPU
36
  """
37
  try:
38
+ # Initialize model inside GPU function (required for Stateless GPU)
39
+ model, processor, device = get_model_and_processor()
40
+
41
  # Handle base64 input (for API)
42
  if isinstance(image, str):
43
  if image.startswith('data:image'):
 
45
  image_bytes = base64.b64decode(image)
46
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
47
 
48
+ # Process with SAM3
49
  inputs = processor(
50
  images=image,
51
  text=text_prompt.strip(),
52
  return_tensors="pt"
53
  ).to(device)
54
 
55
+ # Convert dtype to match model
56
  for key in inputs:
57
  if inputs[key].dtype == torch.float32:
58
  inputs[key] = inputs[key].to(model.dtype)
 
77
  """SAM3 handler for both UI and API access"""
78
 
79
  def __init__(self):
80
+ print("SAM3 handler initialized (models will be loaded lazily)")
81
 
82
  def predict(self, image, text_prompt, confidence_threshold=0.5):
83
  """