File size: 1,532 Bytes
18352e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
#!/usr/bin/env python3
"""
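Patch for the torch.load weights_only default (seen with PyTorch 2.7) that
breaks loading doclayout_yolo models.
This makes the YOLO model checkpoints loadable again, but because it defaults
torch.load to weights_only=False it gives up weights_only's protection against
malicious pickles: only load checkpoints from trusted sources.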
"""
import functools
import importlib

import torch
import torch.serialization
# Capture torch.load as it is when this module is imported. The wrapper below
# always delegates to this saved reference, so calling patch_torch_load()
# repeatedly does not stack wrappers.
_original_torch_load = torch.load


# functools.wraps copies torch.load's name, docstring and __wrapped__ onto the
# wrapper, so help(torch.load) / inspect.signature(torch.load) still describe
# the real API once the wrapper is installed. (The docstring below is therefore
# replaced at runtime and is mainly for readers of this source.)
@functools.wraps(_original_torch_load)
def patched_torch_load(*args, **kwargs):
    """Call the original torch.load with ``weights_only`` defaulting to False.

    Only the *default* changes: if the caller passes ``weights_only``
    explicitly (it is a keyword-only parameter of torch.load), that value --
    including None -- is forwarded unchanged.

    SECURITY: ``weights_only=False`` unpickles arbitrary objects, which can
    execute arbitrary code. Only load checkpoints from trusted sources.

    Returns:
        Whatever the original torch.load returns.
    """
    kwargs.setdefault("weights_only", False)
    return _original_torch_load(*args, **kwargs)
# Dotted import paths of the model classes that doclayout_yolo / ultralytics
# checkpoints are expected to pickle. They are resolved to real objects at
# patch time because torch.serialization.add_safe_globals takes the objects
# themselves; registering plain strings would corrupt the allowlist.
_DEFAULT_SAFE_GLOBAL_PATHS = (
    "doclayout_yolo.nn.tasks.YOLOv10DetectionModel",
    "doclayout_yolo.nn.modules.YOLOv10DetectionModel",
    "ultralytics.nn.tasks.DetectionModel",
    "ultralytics.nn.modules.Conv",
    "ultralytics.nn.modules.C2f",
    "ultralytics.nn.modules.SPPF",
    "ultralytics.nn.modules.Detect",
    "ultralytics.nn.modules.DFL",
)


def patch_torch_load(class_paths=_DEFAULT_SAFE_GLOBAL_PATHS):
    """Make doclayout_yolo checkpoints loadable via torch.load.

    Both steps are best-effort and report progress with print():

    1. Import each ``"package.module.Name"`` in *class_paths* and register the
       resulting objects with ``torch.serialization.add_safe_globals`` so that
       ``weights_only=True`` loads accept them. Paths that cannot be imported
       or looked up are reported and skipped.
    2. Replace ``torch.load`` with ``patched_torch_load``, which defaults
       ``weights_only`` to False for callers that do not pass it.

    Args:
        class_paths: Iterable of dotted paths to register as safe globals.
            Defaults to the doclayout_yolo / ultralytics detection classes.
    """
    safe_globals = []
    for dotted_path in class_paths:
        module_name, _, attr_name = dotted_path.rpartition(".")
        try:
            # Importing runs the package's import-time code; third-party
            # imports can fail in many ways, so any failure only skips the path.
            module = importlib.import_module(module_name)
            safe_globals.append(getattr(module, attr_name))
        except Exception as e:
            print(f"⚠️ Could not resolve safe global {dotted_path}: {e}")

    if safe_globals:
        try:
            torch.serialization.add_safe_globals(safe_globals)
        except AttributeError as e:
            # add_safe_globals is absent in older PyTorch releases.
            print(f"⚠️ Safe globals failed: {e}")
        else:
            print("✅ PyTorch safe globals added for doclayout_yolo")
    else:
        print("⚠️ No safe globals could be resolved; skipping registration")

    # Even if registration failed, loads that do not pass weights_only will
    # fall back to full (unsafe) unpickling through the wrapper.
    torch.load = patched_torch_load
    print("✅ PyTorch load function patched for compatibility")
if __name__ == "__main__":
    # The patch reassigns torch.load in the current process only, so running
    # this file directly has no lasting effect; code that loads models should
    # import this module and call patch_torch_load() itself.
    patch_torch_load()