Spaces:
Runtime error
Runtime error
Update gradio_app.py
Browse files- gradio_app.py +35 -24
gradio_app.py
CHANGED
|
@@ -208,10 +208,41 @@ with open("imgs/background.png", "rb") as f:
|
|
| 208 |
|
| 209 |
@spaces.GPU
|
| 210 |
def inference(id_image, hair_image):
|
| 211 |
-
#
|
| 212 |
-
|
| 213 |
-
raise RuntimeError("This demo requires a GPU Space. Please enable a GPU in this Space.")
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
(
|
| 216 |
log_validation,
|
| 217 |
UNet3DConditionModel,
|
|
@@ -225,26 +256,6 @@ def inference(id_image, hair_image):
|
|
| 225 |
bald_head,
|
| 226 |
) = _import_inference_bits()
|
| 227 |
|
| 228 |
-
# Disable StyleGAN2 custom CUDA ops to avoid JIT compiling (needs ninja/NVCC).
|
| 229 |
-
# ZeroGPU 下建议走纯 PyTorch 引用实现,避免扩展编译失败。
|
| 230 |
-
try:
|
| 231 |
-
from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import bias_act as _bias_act
|
| 232 |
-
_bias_act.USING_CUDA_TO_SPEED_UP = False
|
| 233 |
-
try:
|
| 234 |
-
from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import upfirdn2d as _upfirdn2d
|
| 235 |
-
if hasattr(_upfirdn2d, 'USING_CUDA_TO_SPEED_UP'):
|
| 236 |
-
_upfirdn2d.USING_CUDA_TO_SPEED_UP = False
|
| 237 |
-
except Exception:
|
| 238 |
-
pass
|
| 239 |
-
try:
|
| 240 |
-
from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import filtered_lrelu as _fl
|
| 241 |
-
if hasattr(_fl, 'USING_CUDA_TO_SPEED_UP'):
|
| 242 |
-
_fl.USING_CUDA_TO_SPEED_UP = False
|
| 243 |
-
except Exception:
|
| 244 |
-
pass
|
| 245 |
-
except Exception:
|
| 246 |
-
pass
|
| 247 |
-
|
| 248 |
os.makedirs("gradio_inputs", exist_ok=True)
|
| 249 |
os.makedirs("gradio_outputs", exist_ok=True)
|
| 250 |
|
|
@@ -291,7 +302,7 @@ def inference(id_image, hair_image):
|
|
| 291 |
|
| 292 |
args = Args()
|
| 293 |
|
| 294 |
-
device = torch.device("cuda"
|
| 295 |
|
| 296 |
logging.basicConfig(
|
| 297 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
|
|
| 208 |
|
| 209 |
@spaces.GPU
|
| 210 |
def inference(id_image, hair_image):
|
| 211 |
+
# ZeroGPU: 强制使用 'cuda' 设备(ZeroGPU 下 torch.cuda.is_available 可能为 False)。
|
| 212 |
+
device = torch.device("cuda")
|
|
|
|
| 213 |
|
| 214 |
+
# 先禁用 StyleGAN2 自定义 CUDA 算子(导入 HairMapper 前),避免触发 JIT 编译。
|
| 215 |
+
# 1) 禁用模块级开关
|
| 216 |
+
try:
|
| 217 |
+
from torch_utils.ops import bias_act as _bias_act2
|
| 218 |
+
_bias_act2.USING_CUDA_TO_SPEED_UP = False
|
| 219 |
+
except Exception:
|
| 220 |
+
pass
|
| 221 |
+
for _mod_name in ("upfirdn2d", "filtered_lrelu"):
|
| 222 |
+
try:
|
| 223 |
+
_m = __import__(f"torch_utils.ops.{_mod_name}", fromlist=["*"])
|
| 224 |
+
if hasattr(_m, 'USING_CUDA_TO_SPEED_UP'):
|
| 225 |
+
setattr(_m, 'USING_CUDA_TO_SPEED_UP', False)
|
| 226 |
+
except Exception:
|
| 227 |
+
pass
|
| 228 |
+
try:
|
| 229 |
+
from HairMapper.styleGAN2_ada_model.stylegan2_ada.torch_utils.ops import bias_act as _bias_act
|
| 230 |
+
_bias_act.USING_CUDA_TO_SPEED_UP = False
|
| 231 |
+
except Exception:
|
| 232 |
+
pass
|
| 233 |
+
|
| 234 |
+
# 2) 强制 bias_act 走 ref 实现(即便上层传 impl='cuda' 也改为 'ref')。
|
| 235 |
+
try:
|
| 236 |
+
import types
|
| 237 |
+
from torch_utils.ops import bias_act as _ba_mod
|
| 238 |
+
_orig_bias_act = _ba_mod.bias_act
|
| 239 |
+
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
| 240 |
+
return _orig_bias_act(x, b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp, impl='ref')
|
| 241 |
+
_ba_mod.bias_act = types.FunctionType(_bias_act_ref.__code__, globals(), name='bias_act')
|
| 242 |
+
except Exception:
|
| 243 |
+
pass
|
| 244 |
+
|
| 245 |
+
# 再导入依赖(此时已关闭自定义算子与强制 ref 实现)
|
| 246 |
(
|
| 247 |
log_validation,
|
| 248 |
UNet3DConditionModel,
|
|
|
|
| 256 |
bald_head,
|
| 257 |
) = _import_inference_bits()
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
os.makedirs("gradio_inputs", exist_ok=True)
|
| 260 |
os.makedirs("gradio_outputs", exist_ok=True)
|
| 261 |
|
|
|
|
| 302 |
|
| 303 |
args = Args()
|
| 304 |
|
| 305 |
+
device = torch.device("cuda")
|
| 306 |
|
| 307 |
logging.basicConfig(
|
| 308 |
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|