iljung1106 commited on
Commit
8dd354e
·
1 Parent(s): 4ccf22e

cv2.CacadeClassifier error fix attempt

Browse files
Files changed (1) hide show
  1. inference_utils.py +23 -4
inference_utils.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import torch
9
  from PIL import Image
10
  from torchvision import transforms
 
11
 
12
  # Shim for spaces to allow local execution without the package
13
  try:
@@ -69,7 +70,17 @@ class FaceEyeExtractor:
69
  self._yolo_model = None
70
  self._yolo_device = None
71
  self._stride = 32
72
- self._cascade = None
 
 
 
 
 
 
 
 
 
 
73
 
74
  def _patch_torch_load_for_old_ckpt(self):
75
  import torch as _torch
@@ -151,9 +162,10 @@ class FaceEyeExtractor:
151
  if not self.cascade_path.exists():
152
  raise RuntimeError(f"Cascade xml not found: {self.cascade_path}")
153
 
154
- self._cascade = cv2.CascadeClassifier(str(self.cascade_path))
155
- if self._cascade.empty():
156
  raise RuntimeError(f"cascade load failed: {self.cascade_path}")
 
157
 
158
  def _letterbox_compat(self, img0, new_shape, stride):
159
  from utils.datasets import letterbox # type: ignore
@@ -255,7 +267,14 @@ class FaceEyeExtractor:
255
  dyn_min = int(0.07 * min_side)
256
  min_sz = max(8, int(self.eye_min_size), dyn_min)
257
 
258
- raw = self._cascade.detectMultiScale(
 
 
 
 
 
 
 
259
  proc,
260
  scaleFactor=1.15,
261
  minNeighbors=self.neighbors,
 
8
  import torch
9
  from PIL import Image
10
  from torchvision import transforms
11
+ import threading
12
 
13
  # Shim for spaces to allow local execution without the package
14
  try:
 
70
  self._yolo_model = None
71
  self._yolo_device = None
72
  self._stride = 32
73
+ self._tl = threading.local()
74
+
75
+ def __getstate__(self):
76
+ state = self.__dict__.copy()
77
+ if "_tl" in state:
78
+ del state["_tl"]
79
+ return state
80
+
81
+ def __setstate__(self, state):
82
+ self.__dict__.update(state)
83
+ self._tl = threading.local()
84
 
85
  def _patch_torch_load_for_old_ckpt(self):
86
  import torch as _torch
 
162
  if not self.cascade_path.exists():
163
  raise RuntimeError(f"Cascade xml not found: {self.cascade_path}")
164
 
165
+ cascade = cv2.CascadeClassifier(str(self.cascade_path))
166
+ if cascade.empty():
167
  raise RuntimeError(f"cascade load failed: {self.cascade_path}")
168
+ self._tl.cascade = cascade
169
 
170
  def _letterbox_compat(self, img0, new_shape, stride):
171
  from utils.datasets import letterbox # type: ignore
 
267
  dyn_min = int(0.07 * min_side)
268
  min_sz = max(8, int(self.eye_min_size), dyn_min)
269
 
270
+ cascade = getattr(self._tl, 'cascade', None)
271
+ if cascade is None:
272
+ cascade = cv2.CascadeClassifier(str(self.cascade_path))
273
+ if cascade.empty():
274
+ raise RuntimeError(f"cascade load failed: {self.cascade_path}")
275
+ self._tl.cascade = cascade
276
+
277
+ raw = cascade.detectMultiScale(
278
  proc,
279
  scaleFactor=1.15,
280
  minNeighbors=self.neighbors,