Ayesha-Majeed commited on
Commit
4d0d001
·
verified ·
1 Parent(s): 0e213e5

Update binary_segmentation.py

Browse files
Files changed (1) hide show
  1. binary_segmentation.py +35 -35
binary_segmentation.py CHANGED
@@ -495,57 +495,57 @@ class BinarySegmenter:
495
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
496
  ])
497
 
498
- # def _load_birefnet(self):
499
- # """Load BiRefNet model (best accuracy, larger)"""
500
- # try:
501
- # from transformers import AutoModelForImageSegmentation
502
-
503
- # self.model = AutoModelForImageSegmentation.from_pretrained(
504
- # 'ZhengPeng7/BiRefNet',
505
- # trust_remote_code=True,
506
- # cache_dir=str(self.cache_dir),
507
- # torch_dtype=torch.float32,
508
- # low_cpu_mem_usage=False
509
- # )
510
-
511
- # self.transform = transforms.Compose([
512
- # transforms.Resize((320, 320)),
513
- # transforms.ToTensor(),
514
- # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
515
- # ])
516
- # except ImportError:
517
- # raise ImportError("BiRefNet requires: pip install transformers")
518
-
519
  def _load_birefnet(self):
520
  """Load BiRefNet model (best accuracy, larger)"""
521
  try:
522
  from transformers import AutoModelForImageSegmentation
523
-
524
  self.model = AutoModelForImageSegmentation.from_pretrained(
525
  'ZhengPeng7/BiRefNet',
526
  trust_remote_code=True,
527
  cache_dir=str(self.cache_dir),
528
- torch_dtype=torch.float32, # ✅ Keep FP32 for CPU
529
  low_cpu_mem_usage=False
530
  )
531
-
532
- # ✅ QUANTIZE to INT8 for CPU speedup
533
- if DEVICE == "cpu":
534
-
535
- self.model = torch.quantization.quantize_dynamic(
536
- self.model,
537
- {torch.nn.Linear, torch.nn.Conv2d},
538
- dtype=torch.qint8
539
- )
540
- logger.info("✅ BiRefNet quantized to INT8")
541
-
542
  self.transform = transforms.Compose([
543
- transforms.Resize((1024, 1024)),
544
  transforms.ToTensor(),
545
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
546
  ])
547
  except ImportError:
548
  raise ImportError("BiRefNet requires: pip install transformers")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
 
550
  def _load_rmbg(self):
551
  """Load RMBG model (good balance)"""
 
495
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
496
  ])
497
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
  def _load_birefnet(self):
499
  """Load BiRefNet model (best accuracy, larger)"""
500
  try:
501
  from transformers import AutoModelForImageSegmentation
502
+
503
  self.model = AutoModelForImageSegmentation.from_pretrained(
504
  'ZhengPeng7/BiRefNet',
505
  trust_remote_code=True,
506
  cache_dir=str(self.cache_dir),
507
+ torch_dtype=torch.float32,
508
  low_cpu_mem_usage=False
509
  )
510
+
 
 
 
 
 
 
 
 
 
 
511
  self.transform = transforms.Compose([
512
+ transforms.Resize((320, 320)),
513
  transforms.ToTensor(),
514
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
515
  ])
516
  except ImportError:
517
  raise ImportError("BiRefNet requires: pip install transformers")
518
+
519
+ # def _load_birefnet(self):
520
+ # """Load BiRefNet model (best accuracy, larger)"""
521
+ # try:
522
+ # from transformers import AutoModelForImageSegmentation
523
+
524
+ # self.model = AutoModelForImageSegmentation.from_pretrained(
525
+ # 'ZhengPeng7/BiRefNet',
526
+ # trust_remote_code=True,
527
+ # cache_dir=str(self.cache_dir),
528
+ # torch_dtype=torch.float32, # ✅ Keep FP32 for CPU
529
+ # low_cpu_mem_usage=False
530
+ # )
531
+
532
+ # # ✅ QUANTIZE to INT8 for CPU speedup
533
+ # if DEVICE == "cpu":
534
+
535
+ # self.model = torch.quantization.quantize_dynamic(
536
+ # self.model,
537
+ # {torch.nn.Linear, torch.nn.Conv2d},
538
+ # dtype=torch.qint8
539
+ # )
540
+ # logger.info("✅ BiRefNet quantized to INT8")
541
+
542
+ # self.transform = transforms.Compose([
543
+ # transforms.Resize((1024, 1024)),
544
+ # transforms.ToTensor(),
545
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
546
+ # ])
547
+ # except ImportError:
548
+ # raise ImportError("BiRefNet requires: pip install transformers")
549
 
550
  def _load_rmbg(self):
551
  """Load RMBG model (good balance)"""