File size: 2,580 Bytes
5f631c6
 
 
ab35335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn
from typing import List, Tuple


class Exp:
    """
    Configuration class for the table structure model.

    Holds the settings used to build the YOLOX-based table structure
    detection model (activation, depth/width multipliers, class labels) along
    with inference and NMS/thresholding parameters. The model itself is built
    on demand by :meth:`get_model` and cached on the instance as ``model``.
    """

    def __init__(self) -> None:
        """Initialize the configuration with default parameters."""
        # NOTE(review): the name says "page-element" while the labels and the
        # class docstring describe table structure - confirm this is intended.
        self.name: str = "page-element-v3"
        self.ckpt: str = "weights.pth"
        # Use the first CUDA device when one is available, otherwise the CPU.
        self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"

        # YOLOX architecture parameters
        self.act: str = "silu"
        self.depth: float = 1.00
        self.width: float = 1.00
        self.labels: List[str] = [
            "border",  # not used
            "cell",
            "row",
            "column",
            "header",  # not used
        ]
        # Derived from the label list so the two can never drift apart.
        self.num_classes: int = len(self.labels)

        # Inference parameters
        self.size: Tuple[int, int] = (1024, 1024)
        self.min_bbox_size: int = 0
        self.normalize_boxes: bool = True

        # NMS & thresholding. These can be updated
        self.conf_thresh: float = 0.01
        self.iou_thresh: float = 0.25
        self.class_agnostic: bool = False

        # NOTE(review): presumably the score cut-off applied to final
        # detections; its consumer is not visible here - confirm with callers.
        self.threshold: float = 0.05

    def get_model(self) -> nn.Module:
        """
        Get the YOLOX model.

        Builds the YOLOX model from the configured architecture on the first
        call and stores it in ``self.model``; later calls (or a ``model``
        attribute assigned beforehand) reuse that instance. On every call,
        each ``nn.BatchNorm2d`` layer of the model gets ``eps = 1e-3`` and
        ``momentum = 0.03``.

        Returns:
            nn.Module: The YOLOX model with configured parameters.

        Raises:
            ImportError: If the model has to be built and the ``yolox``
                package cannot be imported.
        """
        # Build model (only when no model has been attached yet)
        if getattr(self, "model", None) is None:
            # Imported lazily: the dependency is only needed to construct the
            # network, not to return an already-built one.
            from yolox import YOLOX, YOLOPAFPN, YOLOXHead

            in_channels = [256, 512, 1024]
            backbone = YOLOPAFPN(
                self.depth, self.width, in_channels=in_channels, act=self.act
            )
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=in_channels, act=self.act
            )
            self.model = YOLOX(backbone, head)

        # Update batch-norm parameters. ``modules()`` already yields every
        # nested submodule once, so a single pass suffices; wrapping the same
        # walk in ``Module.apply`` would revisit each layer once per ancestor.
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1e-3
                module.momentum = 0.03

        return self.model