File size: 2,650 Bytes
9954323
 
9e99f59
9954323
 
 
 
 
9e99f59
 
 
 
9954323
9e99f59
 
 
 
 
 
9954323
 
9e99f59
 
 
 
 
 
 
 
 
 
 
 
9954323
 
9e99f59
 
 
9954323
 
9e99f59
 
 
9954323
9e99f59
9954323
 
 
 
9e99f59
9954323
 
 
9e99f59
9954323
 
9e99f59
 
 
 
 
 
9954323
 
 
 
 
 
9e99f59
 
 
 
 
 
9954323
 
 
9e99f59
9954323
 
 
 
9e99f59
9954323
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import torch.nn as nn
from typing import Dict, List, Tuple


class Exp:
    """
    Configuration class for the page element model.

    This class contains all configuration parameters for the YOLOX-based
    page element detection model, including architecture settings, inference
    parameters, and class-specific thresholds.

    The network is not created in ``__init__``; ``get_model`` builds it on
    first use and stores it on ``self.model``.
    """

    def __init__(self) -> None:
        """Initialize the configuration with default parameters."""
        self.name: str = "page-element-v3"
        self.ckpt: str = "weights.pth"
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters
        self.act: str = "silu"
        self.depth: float = 1.00
        self.width: float = 1.00
        self.labels: List[str] = [
            "table",
            "chart",
            "title",
            "infographic",
            "text",
            "header_footer",
        ]
        # Derived from the label list so the two cannot drift apart.
        self.num_classes: int = len(self.labels)

        # Inference parameters
        self.size: Tuple[int, int] = (1024, 1024)
        self.min_bbox_size: int = 0
        self.normalize_boxes: bool = True

        # NMS & thresholding. These can be updated
        self.conf_thresh: float = 0.01
        self.iou_thresh: float = 0.5
        self.class_agnostic: bool = True

        # Keyed by label name; one entry for every name in ``self.labels``.
        self.thresholds_per_class: Dict[str, float] = {
            "table": 0.1,
            "chart": 0.01,
            "infographic": 0.01,
            "title": 0.1,
            "text": 0.1,
            "header_footer": 0.1,
        }

    def get_model(self) -> nn.Module:
        """
        Get the YOLOX model.

        Builds a YOLOX model with the configured architecture on the first
        call and caches it on ``self.model``; later calls return that same
        instance. On every call, ``eps`` and ``momentum`` of each
        ``nn.BatchNorm2d`` layer in the model are set to 1e-3 and 0.03.

        Returns:
            nn.Module: The YOLOX model with configured parameters.
        """
        # Build model
        if getattr(self, "model", None) is None:
            # Imported only when a model has to be built, so an instance whose
            # ``model`` was assigned beforehand does not need ``yolox``.
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act
            )
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=in_channels, act=self.act
            )
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters. ``modules()`` already yields every
        # nested submodule, so one pass is enough (wrapping this loop in
        # ``Module.apply`` would revisit each layer once per ancestor).
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        return self.model