File size: 1,532 Bytes
18352e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#!/usr/bin/env python3
"""
Patch for the PyTorch 2.7 weights_only issue with doclayout_yolo models.
This lets the YOLO model weights load by making torch.load default to
weights_only=False; only use it with checkpoints from a trusted source.
"""

import torch
import torch.serialization

# Keep a reference to the torch.load that is bound when this module is
# imported. patched_torch_load delegates to it, so the reference must be taken
# before patch_torch_load() rebinds torch.load to the wrapper.
_original_torch_load = torch.load

def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    All positional and keyword arguments are forwarded unchanged, except that
    ``weights_only`` is set to ``False`` when the caller omitted it or passed
    ``None``. An explicit ``True`` or ``False`` from the caller is respected.

    SECURITY: ``weights_only=False`` selects the unrestricted pickle-based
    loader, which can execute arbitrary code embedded in a checkpoint. Only
    load files that come from a trusted source.

    Returns:
        Whatever the original ``torch.load`` returns for the given arguments.
    """
    # torch.load declares weights_only=None as its "not specified" value and
    # then applies its own default, so an explicit None must be treated the
    # same as an omitted argument or the compatibility default is bypassed.
    if kwargs.get('weights_only') is None:
        kwargs['weights_only'] = False
    return _original_torch_load(*args, **kwargs)

def _resolve_safe_global(dotted_path):
    """Return the callable named by ``dotted_path`` ('pkg.module.Name'), or None.

    None is returned when the module cannot be imported, when it does not
    define the attribute, or when the attribute is not callable, so optional
    packages and stale class paths are skipped instead of raising.
    """
    import importlib  # local import: leaves the file's import block untouched

    module_name, _, attr_name = dotted_path.rpartition('.')
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    resolved = getattr(module, attr_name, None)
    return resolved if callable(resolved) else None


def patch_torch_load():
    """Patch torch.load to allow doclayout_yolo classes.

    Two things happen, in this order:

    1. ``torch.load`` is rebound to ``patched_torch_load`` so that loads
       default to ``weights_only=False``.
    2. As a best effort, the doclayout_yolo / ultralytics model classes that
       can be imported are registered with
       ``torch.serialization.add_safe_globals``. Failures here are reported
       on stdout and never raised.

    SECURITY: the rebinding affects every ``torch.load`` call made through the
    ``torch`` module attribute in this process, not just doclayout_yolo's, and
    disables the restricted unpickler by default. Only load trusted files.

    Returns:
        None.
    """
    # Install the wrapper before the model packages are imported below, so
    # code in those packages that looks up torch.load sees the patched version.
    torch.load = patched_torch_load

    candidate_paths = [
        'doclayout_yolo.nn.tasks.YOLOv10DetectionModel',
        'doclayout_yolo.nn.modules.YOLOv10DetectionModel',
        'ultralytics.nn.tasks.DetectionModel',
        'ultralytics.nn.modules.Conv',
        'ultralytics.nn.modules.C2f',
        'ultralytics.nn.modules.SPPF',
        'ultralytics.nn.modules.Detect',
        'ultralytics.nn.modules.DFL',
    ]
    try:
        # add_safe_globals expects the class objects themselves; dotted-path
        # strings would be stored but could never match an unpickled global.
        safe_classes = [
            cls
            for cls in map(_resolve_safe_global, candidate_paths)
            if cls is not None
        ]
        if safe_classes:
            torch.serialization.add_safe_globals(safe_classes)
            print("✅ PyTorch safe globals added for doclayout_yolo")
        else:
            print("⚠️ Safe globals failed: no doclayout_yolo/ultralytics classes could be imported")
    except Exception as e:
        # Deliberately broad: this step is optional (e.g. add_safe_globals is
        # missing on older PyTorch) and must not stop the patch from applying.
        print(f"⚠️ Safe globals failed: {e}")

    print("✅ PyTorch load function patched for compatibility")

# Running this file directly applies the patch to the current interpreter
# process only; other programs must import this module and call
# patch_torch_load() themselves before loading a model.
if __name__ == "__main__":
    patch_torch_load()