|
|
|
|
|
""" |
|
|
Patch for PyTorch 2.7 weights_only issue with doclayout_yolo models |
|
|
This allows loading the YOLO model weights safely |
|
|
""" |
|
|
|
|
|
import functools
import importlib

import torch
import torch.serialization
|
|
|
|
|
|
|
|
# Handle on the real torch.load. If this module is executed again (e.g. via
# importlib.reload after patch_torch_load() ran), torch.load is already the
# wrapper; reusing the previously captured value prevents the wrapper from
# ending up calling itself.
_original_torch_load = globals().get('_original_torch_load', torch.load)


@functools.wraps(_original_torch_load)
def patched_torch_load(*args, **kwargs):
    """Call the original torch.load, defaulting ``weights_only`` to False.

    PyTorch 2.6+ treats an unset (or ``None``) ``weights_only`` as True.
    That rejects checkpoints which pickle arbitrary classes, such as full
    model objects. This wrapper restores the old default. An explicit
    ``weights_only=True`` or ``weights_only=False`` from the caller is passed
    through unchanged.

    Security: ``weights_only=False`` uses full pickle loading, which can run
    arbitrary code embedded in the file. Once this function is installed as
    ``torch.load``, that applies to every caller in the process, so only load
    trusted checkpoints.

    Args:
        *args: Positional arguments forwarded to the original torch.load.
        **kwargs: Keyword arguments forwarded to the original torch.load.

    Returns:
        Whatever the original torch.load returns.
    """
    # torch uses None to mean "use the library default" (True on 2.6+), so
    # treat an explicit None the same as not passing the argument at all.
    if kwargs.get('weights_only') is None:
        kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)
|
|
|
|
|
# Dotted paths of classes registered as torch safe globals.
# NOTE(review): these paths are carried over unchanged; confirm they exist in
# the installed doclayout_yolo / ultralytics versions. Entries that cannot be
# imported are skipped.
_SAFE_GLOBAL_PATHS = (
    'doclayout_yolo.nn.tasks.YOLOv10DetectionModel',
    'doclayout_yolo.nn.modules.YOLOv10DetectionModel',
    'ultralytics.nn.tasks.DetectionModel',
    'ultralytics.nn.modules.Conv',
    'ultralytics.nn.modules.C2f',
    'ultralytics.nn.modules.SPPF',
    'ultralytics.nn.modules.Detect',
    'ultralytics.nn.modules.DFL',
)


def _resolve_dotted_path(path):
    """Import and return the object named by a dotted ``module.attr`` path.

    Returns None if the module cannot be imported or lacks the attribute.
    """
    module_name, _, attr_name = path.rpartition('.')
    try:
        return getattr(importlib.import_module(module_name), attr_name)
    except (ImportError, AttributeError, ValueError):
        # ValueError covers a path with no module part (empty module name).
        return None


def patch_torch_load():
    """Register doclayout_yolo classes as safe globals and patch torch.load.

    ``torch.serialization.add_safe_globals`` requires the class objects
    themselves, not their dotted names. Each path is therefore imported and
    resolved first, which imports the owning packages as a side effect.
    Registering bare strings would corrupt the allowlist: torch reads
    ``__module__``/``__qualname__`` from every entry on later
    ``weights_only=True`` loads.

    Registration is best-effort. Failures are printed, not raised, for
    example on torch versions without ``add_safe_globals``. Afterwards
    ``torch.load`` is replaced by ``patched_torch_load``, which defaults to
    ``weights_only=False``.
    """
    try:
        resolved = [
            obj
            for obj in map(_resolve_dotted_path, _SAFE_GLOBAL_PATHS)
            if obj is not None
        ]
        if resolved:
            torch.serialization.add_safe_globals(resolved)
            print("✅ PyTorch safe globals added for doclayout_yolo")
        else:
            print("⚠️ Safe globals skipped: no doclayout_yolo classes could be imported")
    except Exception as e:
        print(f"⚠️ Safe globals failed: {e}")

    torch.load = patched_torch_load
    print("✅ PyTorch load function patched for compatibility")
|
|
|
|
|
if __name__ == "__main__": |
|
|
patch_torch_load() |