Ryan Chesler committed
Commit 12ea31e · 1 Parent(s): 697b917

Clean up repository structure and update for pip install

Demo.ipynb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3fbac9002e8052dd97f8de19ede6a60fa119d237a08aa13e09fc294c73708489
-size 1085041
+oid sha256:002e6edf2b37d18f5eb2499fa653b8543b95964b122be8726cf214a2cf5500ba
+size 848913
README.md CHANGED
@@ -134,7 +134,12 @@ git clone https://huggingface.co/nvidia/nemotron-table-structure-v1
 ```
 git clone git@hf.co:nvidia/nemotron-table-structure-v1
 ```
-
+Optional:
+This can be installed as a package using pip
+```
+cd nemotron-table-structure-v1
+pip install -e .
+```
 2. Run the model using the following code:
 
 ```
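Since the README now documents an editable pip install, a quick import-level sanity check may be useful here. The snippet below is a minimal sketch, not taken from the repository: the entry point `define_model` and the wrapper's `preprocess`/`forward` behavior are inferred from `nemotron_table_structure_v1/model.py` in this commit, and the package import path plus the default config resolution are assumptions.

```
# Minimal sketch (assumed API, inferred from nemotron_table_structure_v1/model.py)
import numpy as np
from nemotron_table_structure_v1.model import define_model

model = define_model(verbose=True)                # builds YOLOX and loads weights.pth
image = np.zeros((1200, 800, 3), dtype=np.uint8)  # H x W x 3 uint8 page image
x = model.preprocess(image)                       # resize + pad, preserving aspect ratio
preds = model(x, orig_sizes=image.shape)          # [{"labels", "boxes", "scores"}]
```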
model.py DELETED
@@ -1,222 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-import os
-import sys
-import torch
-import importlib
-import numpy as np
-import numpy.typing as npt
-import torch.nn as nn
-import torch.nn.functional as F
-from typing import Dict, List, Tuple, Union
-from yolox.boxes import postprocess
-
-
-def define_model(config_name: str = "page_element_v3", verbose: bool = True) -> nn.Module:
-    """
-    Defines and initializes the model based on the configuration.
-
-    Args:
-        config_name (str): Configuration name. Defaults to "page_element_v3".
-        verbose (bool): Whether to print verbose output. Defaults to True.
-
-    Returns:
-        torch.nn.Module: The initialized YOLOX model.
-    """
-    # Load model from exp_file
-    sys.path.append(os.path.dirname(config_name))
-    exp_module = importlib.import_module(os.path.basename(config_name).split(".")[0])
-
-    config = exp_module.Exp()
-    model = config.get_model()
-
-    # Load weights
-    if verbose:
-        print(" -> Loading weights from", config.ckpt)
-
-    ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
-    model.load_state_dict(ckpt["model"], strict=True)
-
-    model = YoloXWrapper(model, config)
-    return model.eval().to(config.device)
-
-
-def resize_pad(img: torch.Tensor, size: tuple) -> torch.Tensor:
-    """
-    Resizes and pads an image to a given size.
-    The goal is to preserve the aspect ratio of the image.
-
-    Args:
-        img (torch.Tensor[C x H x W]): The image to resize and pad.
-        size (tuple[2]): The size to resize and pad the image to.
-
-    Returns:
-        torch.Tensor: The resized and padded image.
-    """
-    img = img.float()
-    _, h, w = img.shape
-    scale = min(size[0] / h, size[1] / w)
-    nh = int(h * scale)
-    nw = int(w * scale)
-    img = F.interpolate(
-        img.unsqueeze(0), size=(nh, nw), mode="bilinear", align_corners=False
-    ).squeeze(0)
-    img = torch.clamp(img, 0, 255)
-    pad_b = size[0] - nh
-    pad_r = size[1] - nw
-    img = F.pad(img, (0, pad_r, 0, pad_b), value=114.0)
-    return img
-
-
-class YoloXWrapper(nn.Module):
-    """
-    Wrapper for YoloX models.
-    """
-    def __init__(self, model: nn.Module, config) -> None:
-        """
-        Constructor
-
-        Args:
-            model (torch model): Yolo model.
-            config (Config): Config object containing model parameters.
-        """
-        super().__init__()
-        self.model = model
-        self.config = config
-
-        # Copy config parameters
-        self.device = config.device
-        self.img_size = config.size
-        self.min_bbox_size = config.min_bbox_size
-        self.normalize_boxes = config.normalize_boxes
-        self.conf_thresh = config.conf_thresh
-        self.iou_thresh = config.iou_thresh
-        self.class_agnostic = config.class_agnostic
-        self.threshold = config.threshold
-        self.labels = config.labels
-        self.num_classes = config.num_classes
-
-    def reformat_input(
-        self,
-        x: torch.Tensor,
-        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Reformats the input data and original sizes to the correct format.
-
-        Args:
-            x (torch.Tensor[BS x C x H x W]): Input image batch.
-            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.
-        Returns:
-            torch tensor [BS x C x H x W]: Input image batch.
-            torch tensor [BS x 2]: Original image sizes (before resizing and padding).
-        """
-        # Convert image size to tensor
-        if isinstance(orig_sizes, (list, tuple)):
-            orig_sizes = np.array(orig_sizes)
-        if orig_sizes.shape[-1] == 3:  # remove channel
-            orig_sizes = orig_sizes[..., :2]
-        if isinstance(orig_sizes, np.ndarray):
-            orig_sizes = torch.from_numpy(orig_sizes).to(self.device)
-
-        # Add batch dimension if not present
-        if len(x.size()) == 3:
-            x = x.unsqueeze(0)
-        if len(orig_sizes.size()) == 1:
-            orig_sizes = orig_sizes.unsqueeze(0)
-
-        return x, orig_sizes
-
-    def preprocess(self, image: Union[torch.Tensor, npt.NDArray]) -> torch.Tensor:
-        """
-        YoloX preprocessing function:
-        - Resizes to the longest edge to img_size while preserving the aspect ratio
-        - Pads the shortest edge to img_size
-
-        Args:
-            image (torch tensor or np array [H x W x 3]): Input images in uint8 format.
-
-        Returns:
-            torch tensor [3 x H x W]: Processed image.
-        """
-        if not isinstance(image, torch.Tensor):
-            image = torch.from_numpy(image)
-        image = image.permute(2, 0, 1)  # [H, W, 3] -> [3, H, W]
-        image = resize_pad(image, self.img_size)
-        return image.float()
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        orig_sizes: Union[torch.Tensor, List, Tuple, npt.NDArray]
-    ) -> List[Dict[str, torch.Tensor]]:
-        """
-        Forward pass of the model.
-        Applies NMS and reformats the predictions.
-
-        Args:
-            x (torch.Tensor[BS x C x H x W]): Input image batch.
-            orig_sizes (torch.Tensor or list or np.ndarray): Original image sizes.
-
-        Returns:
-            list[dict]: List of prediction dictionaries. Each dictionary contains:
-                - labels (torch.Tensor[N]): Class labels
-                - boxes (torch.Tensor[N x 4]): Bounding boxes
-                - scores (torch.Tensor[N]): Confidence scores.
-        """
-        x, orig_sizes = self.reformat_input(x, orig_sizes)
-
-        # Scale to 0-255 if in range 0-1
-        if x.max() <= 1:
-            x *= 255
-
-        pred_boxes = self.model(x.to(self.device))
-
-        # NMS
-        pred_boxes = postprocess(
-            pred_boxes,
-            self.config.num_classes,
-            self.conf_thresh,
-            self.iou_thresh,
-            class_agnostic=self.class_agnostic,
-        )
-
-        # Reformat output
-        preds = []
-        for i, (p, size) in enumerate(zip(pred_boxes, orig_sizes)):
-            if p is None:  # No detections
-                preds.append({
-                    "labels": torch.empty(0),
-                    "boxes": torch.empty((0, 4)),
-                    "scores": torch.empty(0),
-                })
-                continue
-
-            p = p.view(-1, p.size(-1))
-            ratio = min(self.img_size[0] / size[0], self.img_size[1] / size[1])
-            boxes = p[:, :4] / ratio
-
-            # Clip
-            boxes[:, [0, 2]] = torch.clamp(boxes[:, [0, 2]], 0, size[1])
-            boxes[:, [1, 3]] = torch.clamp(boxes[:, [1, 3]], 0, size[0])
-
-            # Remove too small
-            kept = (
-                (boxes[:, 2] - boxes[:, 0] > self.min_bbox_size) &
-                (boxes[:, 3] - boxes[:, 1] > self.min_bbox_size)
-            )
-            boxes = boxes[kept]
-            p = p[kept]
-
-            # Normalize to 0-1
-            if self.normalize_boxes:
-                boxes[:, [0, 2]] /= size[1]
-                boxes[:, [1, 3]] /= size[0]
-
-            scores = p[:, 4] * p[:, 5]
-            labels = p[:, 6]
-
-            preds.append({"labels": labels, "boxes": boxes, "scores": scores})
-
-        return preds
nemotron_table_structure_v1/model.py CHANGED
@@ -36,10 +36,11 @@ def define_model(config_name: str = "page_element_v3", verbose: bool = True) ->
     if verbose:
         print(" -> Loading weights from", config.ckpt)
 
-    # Use importlib.resources to locate 'weights.pth' inside the module's directory (nmtron_page_elements_v3)
-    with importlib.resources.path("table_structure_v1", "weights.pth") as weights_path:
-        ckpt = torch.load(str(weights_path), map_location="cpu", weights_only=False)
-        model.load_state_dict(ckpt["model"], strict=True)
+    # Find package directory and load weights (nemotron_table_structure_v1)
+    package_dir = os.path.dirname(os.path.abspath(__file__))
+    weights_path = os.path.join(package_dir, "weights.pth")
+    state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
+    model.load_state_dict(state_dict["model"], strict=True)
 
     model = YoloXWrapper(model, config)
     return model.eval().to(config.device)
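The committed fix reads `weights.pth` relative to `__file__`, which also works for editable installs. The package argument in the old call ("table_structure_v1") does not match the installed package directory (`nemotron_table_structure_v1`), which is presumably why it was replaced. For reference only, the sketch below shows the equivalent lookup via the modern `importlib.resources.files` API with the matching package name; this is an assumption, not part of the commit, and it requires Python 3.9+.

```
# Alternative sketch (not in the repo): importlib.resources with the matching package name
import torch
from importlib.resources import as_file, files

with as_file(files("nemotron_table_structure_v1") / "weights.pth") as weights_path:
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
```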
post_processing/table_struct_pp.py DELETED
@@ -1,222 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-import re
-from typing import List, Union, Optional, Literal
-import numpy as np
-import numpy.typing as npt
-import pandas as pd
-
-
-def assign_boxes(
-    box: Union[List[float], npt.NDArray[np.float64]],
-    candidate_boxes: npt.NDArray[np.float64],
-    delta: float = 2.0,
-    min_overlap: float = 0.25,
-    mode: Literal["cell", "row", "column"] = "cell",
-) -> npt.NDArray[np.int_]:
-    """
-    Assigns the best candidate boxes to a reference `box` based on overlap.
-
-    If mode is "cell", the overlap is calculated using surface area overlap.
-    If mode is "row", the overlap is calculated using row height overlap.
-    If mode is "column", the overlap is calculated using column width overlap.
-
-    If delta > 1, it will look for multiple matches,
-    using candidates with score >= max_overlap / delta.
-
-    Args:
-        box (list or numpy.ndarray): Reference bounding box [x_min, y_min, x_max, y_max].
-        candidate_boxes (numpy.ndarray [N, 4]): Array of candidate bounding boxes.
-        delta (float, optional): Factor for matches relative to the best overlap. Defaults to 2.0.
-        min_overlap (float, optional): Minimum required overlap for a match. Defaults to 0.25.
-        mode (str, optional): Mode to assign boxes ("cell", "row", or "column"). Defaults to "cell".
-
-    Returns:
-        numpy.ndarray [M]: Indices of the matched boxes sorted by decreasing overlap.
-            Returns an empty array if no matches are found.
-    """
-    if not len(candidate_boxes):
-        return np.array([], dtype=np.int_)
-
-    x0_1, y0_1, x1_1, y1_1 = box
-    x0_2, y0_2, x1_2, y1_2 = (
-        candidate_boxes[:, 0],
-        candidate_boxes[:, 1],
-        candidate_boxes[:, 2],
-        candidate_boxes[:, 3],
-    )
-
-    # Intersection
-    inter_y0 = np.maximum(y0_1, y0_2)
-    inter_y1 = np.minimum(y1_1, y1_2)
-    inter_x0 = np.maximum(x0_1, x0_2)
-    inter_x1 = np.minimum(x1_1, x1_2)
-
-    if mode == "cell":
-        inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0)
-        box_area = (y1_1 - y0_1) * (x1_1 - x0_1)
-        overlap = inter_area / (box_area + 1e-6)
-    elif mode == "row":
-        inter_area = np.maximum(0, inter_y1 - inter_y0)
-        box_area = y1_1 - y0_1
-        overlap = inter_area / (box_area + 1e-6)
-    elif mode == "column":
-        inter_area = np.maximum(0, inter_x1 - inter_x0)
-        box_area = x1_1 - x0_1
-        overlap = inter_area / (box_area + 1e-6)
-    else:
-        raise ValueError(f"Invalid mode: {mode}")
-
-    max_overlap = np.max(overlap)
-    if max_overlap <= min_overlap:  # No match
-        return np.array([], dtype=np.int_)
-
-    n = len(np.where(overlap >= (max_overlap / delta))[0]) if delta > 1 else 1
-    matches = np.argsort(-overlap)[:n]
-    return matches
-
-
-def merge_text_in_cell(df_cell: pd.DataFrame) -> pd.DataFrame:
-    """
-    Merges text from multiple rows into a single cell and recalculates its bounding box.
-    Values are sorted by rounded (y, x) coordinates.
-
-    Args:
-        df_cell (pandas.DataFrame): DataFrame containing cells to merge.
-
-    Returns:
-        pandas.DataFrame: Updated DataFrame with merged text and a single bounding box.
-    """
-    boxes = np.stack(df_cell["box"].values)
-
-    df_cell["x"] = (boxes[:, 0] - boxes[:, 0].min()) // 10
-    df_cell["y"] = (boxes[:, 1] - boxes[:, 1].min()) // 10
-    df_cell = df_cell.sort_values(["y", "x"])
-
-    text = " ".join(df_cell["text"].values.tolist())
-    df_cell["text"] = text
-    df_cell = df_cell.head(1)
-    df_cell["box"] = df_cell["cell"]
-    df_cell.drop(["x", "y"], axis=1, inplace=True)
-
-    return df_cell
-
-
-def remove_empty_row(mat: List[List[str]]) -> List[List[str]]:
-    """
-    Remove empty rows from a matrix.
-
-    Args:
-        mat (list[list]): The matrix to remove empty rows from.
-
-    Returns:
-        list[list]: The matrix with empty rows removed.
-    """
-    mat_filter = []
-    for row in mat:
-        if max([len(c) for c in row]):
-            mat_filter.append(row)
-    return mat_filter
-
-
-def build_markdown(
-    df: pd.DataFrame,
-    remove_empty: bool = True,
-    n_rows: Optional[int] = None,
-    repeat_single: bool = False,
-) -> Union[List[List[str]], npt.NDArray[np.str_]]:
-    """
-    Convert a dataframe into a markdown table.
-
-    Args:
-        df (pandas.DataFrame): The dataframe to convert with columns 'col_ids',
-            'row_ids', and 'text'.
-        remove_empty (bool, optional): Whether to remove empty rows & cols. Defaults to True.
-        n_rows (int, optional): Number of rows. Inferred from df if None. Defaults to None.
-        repeat_single (bool, optional): Whether to repeat single element in rows.
-            Defaults to False.
-
-    Returns:
-        list[list[str]] or numpy.ndarray: A list of lists or array representing the markdown table.
-    """
-    df = df.reset_index(drop=True)
-    n_cols = max([np.max(c) for c in df['col_ids'].values])
-    if n_rows is None:
-        n_rows = max([np.max(c) for c in df['row_ids'].values])
-    else:
-        n_rows = max(
-            n_rows - 1,
-            max([np.max(c) for c in df['row_ids'].values])
-        )
-
-    mat = np.empty((n_rows + 1, n_cols + 1), dtype=str).tolist()
-
-    for i in range(len(df)):
-        if isinstance(df["row_ids"][i], int) or isinstance(df["col_ids"][i], int):
-            continue
-        for r in df["row_ids"][i]:
-            for c in df["col_ids"][i]:
-                mat[r][c] = (mat[r][c] + " " + df["text"][i]).strip()
-
-    # Remove empty rows & columns
-    if remove_empty:
-        mat = remove_empty_row(mat)
-        mat = np.array(remove_empty_row(np.array(mat).T.tolist())).T.tolist()
-
-    if repeat_single:
-        new_mat = []
-        for row in mat:
-            if sum([len(c) > 0 for c in row]) == 1:
-                txt = [c for c in row if len(c)][0]
-                new_mat.append([txt for _ in range(len(row))])
-            else:
-                new_mat.append(row)
-        mat = np.array(new_mat)
-
-    return mat
-
-
-def display_markdown(
-    data: List[List[str]], show: bool = True, use_header: bool = True
-) -> str:
-    """
-    Convert a list of lists of strings into a markdown table.
-    If show is True, use_header will be set to True.
-
-    Args:
-        data (list[list[str]]): The table data. The first sublist should contain headers.
-        show (bool, optional): Whether to display the table. Defaults to True.
-        use_header (bool, optional): Whether to use the first sublist as headers. Defaults to True.
-
-    Returns:
-        str: A markdown-formatted table as a string.
-    """
-    if show:
-        use_header = True
-    data = [[re.sub(r'\n', ' ', c) for c in row] for row in data]
-
-    if not len(data):
-        return "EMPTY TABLE"
-
-    max_cols = max(len(row) for row in data)
-    data = [row + [""] * (max_cols - len(row)) for row in data]
-
-    if use_header:
-        header = "| " + " | ".join(data[0]) + " |"
-        separator = "| " + " | ".join(["---"] * max_cols) + " |"
-        body = "\n".join("| " + " | ".join(row) + " |" for row in data[1:])
-        markdown_table = (
-            f"{header}\n{separator}\n{body}" if body else f"{header}\n{separator}"
-        )
-
-        if show:
-            from IPython.display import display, Markdown
-            markdown_table = re.sub(r'\$', r'\\$', markdown_table)
-            markdown_table = re.sub(r'\%', r'\\%', markdown_table)
-            display(Markdown(markdown_table))
-
-    else:
-        markdown_table = "\n".join("| " + " | ".join(row) + " |" for row in data)
-
-    return markdown_table
post_processing/wbf.py DELETED
@@ -1,324 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-# Adapted from:
-# https://github.com/ZFTurbo/Weighted-Boxes-Fusion/blob/master/ensemble_boxes/ensemble_boxes_wbf.py
-
-import warnings
-from typing import Dict, List, Tuple, Union, Literal
-import numpy as np
-import numpy.typing as npt
-
-
-def prefilter_boxes(
-    boxes: List[npt.NDArray[np.float64]],
-    scores: List[npt.NDArray[np.float64]],
-    labels: List[npt.NDArray[np.int_]],
-    weights: List[float],
-    thr: float,
-    class_agnostic: bool = False,
-) -> Dict[Union[str, int], npt.NDArray[np.float64]]:
-    """
-    Reformats and filters boxes.
-    Output is a dict of boxes to merge separately.
-
-    Args:
-        boxes (list[np array[n x 4]]): List of boxes. One list per model.
-        scores (list[np array[n]]): List of confidences.
-        labels (list[np array[n]]): List of labels.
-        weights (list): Model weights.
-        thr (float): Confidence threshold
-        class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.
-
-    Returns:
-        dict[np array [? x 8]]: Filtered boxes.
-    """
-    # Create dict with boxes stored by its label
-    new_boxes = dict()
-
-    for t in range(len(boxes)):
-        assert len(boxes[t]) == len(scores[t]), "len(boxes) != len(scores)"
-        assert len(boxes[t]) == len(labels[t]), "len(boxes) != len(labels)"
-
-        for j in range(len(boxes[t])):
-            score = scores[t][j]
-            if score < thr:
-                continue
-            label = int(labels[t][j])
-            box_part = boxes[t][j]
-            x1 = float(box_part[0])
-            y1 = float(box_part[1])
-            x2 = float(box_part[2])
-            y2 = float(box_part[3])
-
-            # Box data checks
-            if x2 < x1:
-                warnings.warn("X2 < X1 value in box. Swap them.")
-                x1, x2 = x2, x1
-            if y2 < y1:
-                warnings.warn("Y2 < Y1 value in box. Swap them.")
-                y1, y2 = y2, y1
-
-            array = np.array([x1, x2, y1, y2])
-            if array.min() < 0 or array.max() > 1:
-                warnings.warn("Coordinates outside [0, 1]")
-                array = np.clip(array, 0, 1)
-                x1, x2, y1, y2 = array
-
-            if (x2 - x1) * (y2 - y1) == 0.0:
-                warnings.warn("Zero area box skipped: {}.".format(box_part))
-                continue
-
-            # [label, score, weight, model index, x1, y1, x2, y2]
-            b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2]
-
-            label_k = "*" if class_agnostic else label
-            if label_k not in new_boxes:
-                new_boxes[label_k] = []
-            new_boxes[label_k].append(b)
-
-    # Sort each list in dict by score and transform it to numpy array
-    for k in new_boxes:
-        current_boxes = np.array(new_boxes[k])
-        new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
-
-    return new_boxes
-
-
-def merge_labels(
-    labels: npt.NDArray[np.int_], confs: npt.NDArray[np.float64]
-) -> int:
-    """
-    Custom function for merging labels.
-    If all labels are the same, return the unique value.
-    Else, return the label of the most confident non-title (class 2) box.
-
-    Args:
-        labels (np array [n]): Labels.
-        confs (np array [n]): Confidence.
-
-    Returns:
-        int: Label.
-    """
-    if len(np.unique(labels)) == 1:
-        return labels[0]
-    else:  # Most confident and not a title
-        confs = confs[confs != 2]
-        labels = labels[labels != 2]
-        return labels[np.argmax(confs)]
-
-
-def get_weighted_box(
-    boxes: npt.NDArray[np.float64], conf_type: Literal["avg", "max"] = "avg"
-) -> npt.NDArray[np.float64]:
-    """
-    Merges boxes by using the weighted fusion.
-
-    Args:
-        boxes (np array [n x 8]): Boxes to merge.
-        conf_type (str, optional): Confidence merging type. Defaults to "avg".
-
-    Returns:
-        np array [8]: Merged box.
-    """
-    box = np.zeros(8, dtype=np.float32)
-    conf = 0
-    conf_list = []
-    w = 0
-    for b in boxes:
-        box[4:] += b[1] * b[4:]
-        conf += b[1]
-        conf_list.append(b[1])
-        w += b[2]
-
-    box[0] = merge_labels(
-        np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes])
-    )
-
-    box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
-    box[2] = w
-    box[3] = -1  # model index field is retained for consistency but is not used.
-    box[4:] /= conf
-    return box
-
-
-def get_biggest_box(
-    boxes: npt.NDArray[np.float64], conf_type: Literal["avg", "max"] = "avg"
-) -> npt.NDArray[np.float64]:
-    """
-    Merges boxes by using the biggest box.
-
-    Args:
-        boxes (np array [n x 8]): Boxes to merge.
-        conf_type (str, optional): Confidence merging type. Defaults to "avg".
-
-    Returns:
-        np array [8]: Merged box.
-    """
-    box = np.zeros(8, dtype=np.float32)
-    box[4:] = boxes[0][4:]
-    conf_list = []
-    w = 0
-    for b in boxes:
-        box[4] = min(box[4], b[4])
-        box[5] = min(box[5], b[5])
-        box[6] = max(box[6], b[6])
-        box[7] = max(box[7], b[7])
-        conf_list.append(b[1])
-        w += b[2]
-
-    box[0] = merge_labels(
-        np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes])
-    )
-    # print(box[0], np.array([b[0] for b in boxes]))
-
-    box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
-    box[2] = w
-    box[3] = -1  # model index field is retained for consistency but is not used.
-    return box
-
-
-def find_matching_box_fast(
-    boxes_list: npt.NDArray[np.float64],
-    new_box: npt.NDArray[np.float64],
-    match_iou: float,
-) -> Tuple[int, float]:
-    """
-    Reimplementation of find_matching_box with numpy instead of loops.
-    Gives significant speed up for larger arrays (~100x).
-    This was previously the bottleneck since the function is called for every entry in the array.
-
-    Args:
-        boxes_list (np.ndarray): Array of boxes with shape (N, 8).
-        new_box (np.ndarray): New box to match with shape (8,).
-        match_iou (float): IoU threshold for matching.
-
-    Returns:
-        Tuple[int, float]: Index of best matching box (-1 if no match) and IoU value.
-    """
-
-    def bb_iou_array(
-        boxes: npt.NDArray[np.float64], new_box: npt.NDArray[np.float64]
-    ) -> npt.NDArray[np.float64]:
-        # bb interesection over union
-        xA = np.maximum(boxes[:, 0], new_box[0])
-        yA = np.maximum(boxes[:, 1], new_box[1])
-        xB = np.minimum(boxes[:, 2], new_box[2])
-        yB = np.minimum(boxes[:, 3], new_box[3])
-
-        interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
-
-        # compute the area of both the prediction and ground-truth rectangles
-        boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
-        boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
-
-        iou = interArea / (boxAArea + boxBArea - interArea)
-
-        return iou
-
-    if boxes_list.shape[0] == 0:
-        return -1, match_iou
-
-    ious = bb_iou_array(boxes_list[:, 4:], new_box[4:])
-    # ious[boxes[:, 0] != new_box[0]] = -1
-
-    best_idx = np.argmax(ious)
-    best_iou = ious[best_idx]
-
-    if best_iou <= match_iou:
-        best_iou = match_iou
-        best_idx = -1
-
-    return best_idx, best_iou
-
-
-def weighted_boxes_fusion(
-    boxes_list: List[npt.NDArray[np.float64]],
-    labels_list: List[npt.NDArray[np.int_]],
-    scores_list: List[npt.NDArray[np.float64]],
-    iou_thr: float = 0.5,
-    skip_box_thr: float = 0.0,
-    conf_type: Literal["avg", "max"] = "avg",
-    merge_type: Literal["weighted", "biggest"] = "weighted",
-    class_agnostic: bool = False,
-) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.int_]]:
-    """
-    Custom WBF implementation that supports a class_agnostic mode and a biggest box fusion.
-    Boxes are expected to be in normalized (x0, y0, x1, y1) format.
-
-    Args:
-        boxes_list (list[np.ndarray[n x 4]]): List of boxes. One list per model.
-        labels_list (list[np.ndarray[n]]): List of labels.
-        scores_list (list[np.ndarray[n]]): List of confidences.
-        iou_thr (float, optional): IoU threshold for matching. Defaults to 0.55.
-        skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0.
-        conf_type (str, optional): Confidence merging type ("avg" or "max"). Defaults to "avg".
-        merge_type (str, optional): Merge type ("weighted" or "biggest"). Defaults to "weighted".
-        class_agnostic (bool, optional): Merge boxes from different classes. Defaults to False.
-
-    Returns:
-        numpy.ndarray [N x 4]: Array of bounding boxes.
-        numpy.ndarray [N]: Array of labels.
-        numpy.ndarray [N]: Array of scores.
-    """
-    weights = np.ones(len(boxes_list))
-
-    assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"'
-    assert merge_type in ["weighted", "biggest"], 'Conf type must be "weighted" or "biggest"'
-
-    filtered_boxes = prefilter_boxes(
-        boxes_list,
-        scores_list,
-        labels_list,
-        weights,
-        skip_box_thr,
-        class_agnostic=class_agnostic,
-    )
-    if len(filtered_boxes) == 0:
-        return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
-
-    overall_boxes = []
-    for label in filtered_boxes:
-        boxes = filtered_boxes[label]
-        clusters = []
-
-        # Clusterize boxes
-        for j in range(len(boxes)):
-            ids = [i for i in range(len(boxes)) if i != j]
-            index, best_iou = find_matching_box_fast(boxes[ids], boxes[j], iou_thr)
-
-            if index != -1:
-                index = ids[index]
-                cluster_idx = [
-                    clust_idx
-                    for clust_idx, clust in enumerate(clusters)
-                    if (j in clust or index in clust)
-                ]
-                if len(cluster_idx):
-                    cluster_idx = cluster_idx[0]
-                    clusters[cluster_idx] = list(
-                        set(clusters[cluster_idx] + [index, j])
-                    )
-                else:
-                    clusters.append([index, j])
-            else:
-                clusters.append([j])
-
-        for j, c in enumerate(clusters):
-            if merge_type == "weighted":
-                weighted_box = get_weighted_box(boxes[c], conf_type)
-            elif merge_type == "biggest":
-                weighted_box = get_biggest_box(boxes[c], conf_type)
-
-            if conf_type == "max":
-                weighted_box[1] = weighted_box[1] / weights.max()
-            else:  # avg
-                weighted_box[1] = weighted_box[1] * len(c) / weights.sum()
-            overall_boxes.append(weighted_box)
-
-    overall_boxes = np.array(overall_boxes)
-    overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
-    boxes = overall_boxes[:, 4:]
-    scores = overall_boxes[:, 1]
-    labels = overall_boxes[:, 0]
-    return boxes, labels, scores
table_structure_v1.py DELETED
@@ -1,81 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-import torch
-import torch.nn as nn
-from typing import List, Tuple
-
-
-class Exp:
-    """
-    Configuration class for the table structure model.
-
-    This class contains all configuration parameters for the YOLOX-based
-    table structure detection model, including architecture settings, inference
-    parameters, and class-specific thresholds.
-    """
-
-    def __init__(self) -> None:
-        """Initialize the configuration with default parameters."""
-        self.name: str = "page-element-v3"
-        self.ckpt: str = "weights.pth"
-        self.device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
-
-        # YOLOX architecture parameters
-        self.act: str = "silu"
-        self.depth: float = 1.00
-        self.width: float = 1.00
-        self.labels: List[str] = [
-            "border",  # not used
-            "cell",
-            "row",
-            "column",
-            "header"  # not used
-        ]
-        self.num_classes: int = len(self.labels)
-
-        # Inference parameters
-        self.size: Tuple[int, int] = (1024, 1024)
-        self.min_bbox_size: int = 0
-        self.normalize_boxes: bool = True
-
-        # NMS & thresholding. These can be updated
-        self.conf_thresh: float = 0.01
-        self.iou_thresh: float = 0.25
-        self.class_agnostic: bool = False
-
-        self.threshold: float = 0.05
-
-    def get_model(self) -> nn.Module:
-        """
-        Get the YOLOX model.
-
-        Builds and returns a YOLOX model with the configured architecture.
-        Also updates batch normalization parameters for optimal inference.
-
-        Returns:
-            nn.Module: The YOLOX model with configured parameters.
-        """
-        from yolox import YOLOX, YOLOPAFPN, YOLOXHead
-
-        # Build model
-        if getattr(self, "model", None) is None:
-            in_channels = [256, 512, 1024]
-            backbone = YOLOPAFPN(
-                self.depth, self.width, in_channels=in_channels, act=self.act
-            )
-            head = YOLOXHead(
-                self.num_classes, self.width, in_channels=in_channels, act=self.act
-            )
-            self.model = YOLOX(backbone, head)
-
-        # Update batch-norm parameters
-        def init_yolo(M: nn.Module) -> None:
-            for m in M.modules():
-                if isinstance(m, nn.BatchNorm2d):
-                    m.eps = 1e-3
-                    m.momentum = 0.03
-
-        self.model.apply(init_yolo)
-
-        return self.model
yolox/__init__.py DELETED
@@ -1,10 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-
-from .yolo_head import YOLOXHead
-from .yolo_pafpn import YOLOPAFPN
-from .yolox import YOLOX
yolox/boxes.py DELETED
@@ -1,58 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python3
-# Copyright (c) Megvii Inc. All rights reserved.
-
-import torch
-import torchvision
-
-
-def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
-    """
-    Copied from YOLOX/yolox/utils/boxes.py
-    """
-    box_corner = prediction.new(prediction.shape)
-    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
-    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
-    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
-    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
-    prediction[:, :, :4] = box_corner[:, :, :4]
-
-    output = [None for _ in range(len(prediction))]
-    for i, image_pred in enumerate(prediction):
-
-        # If none are remaining => process next image
-        if not image_pred.size(0):
-            continue
-        # Get score and class with highest confidence
-        class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True)
-
-        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
-        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
-        detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
-        detections = detections[conf_mask]
-        if not detections.size(0):
-            continue
-
-        if class_agnostic:
-            nms_out_index = torchvision.ops.nms(
-                detections[:, :4],
-                detections[:, 4] * detections[:, 5],
-                nms_thre,
-            )
-        else:
-            nms_out_index = torchvision.ops.batched_nms(
-                detections[:, :4],
-                detections[:, 4] * detections[:, 5],
-                detections[:, 6],
-                nms_thre,
-            )
-
-        detections = detections[nms_out_index]
-        if output[i] is None:
-            output[i] = detections
-        else:
-            output[i] = torch.cat((output[i], detections))
-
-    return output
yolox/darknet.py DELETED
@@ -1,182 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-
-from torch import nn
-
-from .network_blocks import BaseConv, CSPLayer, DWConv, Focus, ResLayer, SPPBottleneck
-
-
-class Darknet(nn.Module):
-    # number of blocks from dark2 to dark5.
-    depth2blocks = {21: [1, 2, 2, 1], 53: [2, 8, 8, 4]}
-
-    def __init__(
-        self,
-        depth,
-        in_channels=3,
-        stem_out_channels=32,
-        out_features=("dark3", "dark4", "dark5"),
-    ):
-        """
-        Args:
-            depth (int): depth of darknet used in model, usually use [21, 53] for this param.
-            in_channels (int): number of input channels, for example, use 3 for RGB image.
-            stem_out_channels (int): number of output channels of darknet stem.
-                It decides channels of darknet layer2 to layer5.
-            out_features (Tuple[str]): desired output layer name.
-        """
-        super().__init__()
-        assert out_features, "please provide output features of Darknet"
-        self.out_features = out_features
-        self.stem = nn.Sequential(
-            BaseConv(in_channels, stem_out_channels, ksize=3, stride=1, act="lrelu"),
-            *self.make_group_layer(stem_out_channels, num_blocks=1, stride=2),
-        )
-        in_channels = stem_out_channels * 2  # 64
-
-        num_blocks = Darknet.depth2blocks[depth]
-        # create darknet with `stem_out_channels` and `num_blocks` layers.
-        # to make model structure more clear, we don't use `for` statement in python.
-        self.dark2 = nn.Sequential(
-            *self.make_group_layer(in_channels, num_blocks[0], stride=2)
-        )
-        in_channels *= 2  # 128
-        self.dark3 = nn.Sequential(
-            *self.make_group_layer(in_channels, num_blocks[1], stride=2)
-        )
-        in_channels *= 2  # 256
-        self.dark4 = nn.Sequential(
-            *self.make_group_layer(in_channels, num_blocks[2], stride=2)
-        )
-        in_channels *= 2  # 512
-
-        self.dark5 = nn.Sequential(
-            *self.make_group_layer(in_channels, num_blocks[3], stride=2),
-            *self.make_spp_block([in_channels, in_channels * 2], in_channels * 2),
-        )
-
-    def make_group_layer(self, in_channels: int, num_blocks: int, stride: int = 1):
-        "starts with conv layer then has `num_blocks` `ResLayer`"
-        return [
-            BaseConv(in_channels, in_channels * 2, ksize=3, stride=stride, act="lrelu"),
-            *[(ResLayer(in_channels * 2)) for _ in range(num_blocks)],
-        ]
-
-    def make_spp_block(self, filters_list, in_filters):
-        m = nn.Sequential(
-            *[
-                BaseConv(in_filters, filters_list[0], 1, stride=1, act="lrelu"),
-                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
-                SPPBottleneck(
-                    in_channels=filters_list[1],
-                    out_channels=filters_list[0],
-                    activation="lrelu",
-                ),
-                BaseConv(filters_list[0], filters_list[1], 3, stride=1, act="lrelu"),
-                BaseConv(filters_list[1], filters_list[0], 1, stride=1, act="lrelu"),
-            ]
-        )
-        return m
-
-    def forward(self, x):
-        outputs = {}
-        x = self.stem(x)
-        outputs["stem"] = x
-        x = self.dark2(x)
-        outputs["dark2"] = x
-        x = self.dark3(x)
-        outputs["dark3"] = x
-        x = self.dark4(x)
-        outputs["dark4"] = x
-        x = self.dark5(x)
-        outputs["dark5"] = x
-        return {k: v for k, v in outputs.items() if k in self.out_features}
-
-
-class CSPDarknet(nn.Module):
-    def __init__(
-        self,
-        dep_mul,
-        wid_mul,
-        out_features=("dark3", "dark4", "dark5"),
-        depthwise=False,
-        act="silu",
-    ):
-        super().__init__()
-        assert out_features, "please provide output features of Darknet"
-        self.out_features = out_features
-        Conv = DWConv if depthwise else BaseConv
-
-        base_channels = int(wid_mul * 64)  # 64
-        base_depth = max(round(dep_mul * 3), 1)  # 3
-
-        # stem
-        self.stem = Focus(3, base_channels, ksize=3, act=act)
-
-        # dark2
-        self.dark2 = nn.Sequential(
-            Conv(base_channels, base_channels * 2, 3, 2, act=act),
-            CSPLayer(
-                base_channels * 2,
-                base_channels * 2,
-                n=base_depth,
-                depthwise=depthwise,
-                act=act,
-            ),
-        )
-
-        # dark3
-        self.dark3 = nn.Sequential(
-            Conv(base_channels * 2, base_channels * 4, 3, 2, act=act),
-            CSPLayer(
-                base_channels * 4,
-                base_channels * 4,
-                n=base_depth * 3,
-                depthwise=depthwise,
-                act=act,
-            ),
-        )
-
-        # dark4
-        self.dark4 = nn.Sequential(
-            Conv(base_channels * 4, base_channels * 8, 3, 2, act=act),
-            CSPLayer(
-                base_channels * 8,
-                base_channels * 8,
-                n=base_depth * 3,
-                depthwise=depthwise,
-                act=act,
-            ),
-        )
-
-        # dark5
-        self.dark5 = nn.Sequential(
-            Conv(base_channels * 8, base_channels * 16, 3, 2, act=act),
-            SPPBottleneck(base_channels * 16, base_channels * 16, activation=act),
-            CSPLayer(
-                base_channels * 16,
-                base_channels * 16,
-                n=base_depth,
-                shortcut=False,
-                depthwise=depthwise,
-                act=act,
-            ),
-        )
-
-    def forward(self, x):
-        outputs = {}
-        x = self.stem(x)
-        outputs["stem"] = x
-        x = self.dark2(x)
-        outputs["dark2"] = x
-        x = self.dark3(x)
-        outputs["dark3"] = x
-        x = self.dark4(x)
-        outputs["dark4"] = x
-        x = self.dark5(x)
-        outputs["dark5"] = x
-        return {k: v for k, v in outputs.items() if k in self.out_features}
yolox/network_blocks.py DELETED
@@ -1,213 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-
-import torch
-import torch.nn as nn
-
-
-class SiLU(nn.Module):
-    """export-friendly version of nn.SiLU()"""
-
-    @staticmethod
-    def forward(x):
-        return x * torch.sigmoid(x)
-
-
-def get_activation(name="silu", inplace=True):
-    if name == "silu":
-        module = nn.SiLU(inplace=inplace)
-    elif name == "relu":
-        module = nn.ReLU(inplace=inplace)
-    elif name == "lrelu":
-        module = nn.LeakyReLU(0.1, inplace=inplace)
-    else:
-        raise AttributeError("Unsupported act type: {}".format(name))
-    return module
-
-
-class BaseConv(nn.Module):
-    """A Conv2d -> Batchnorm -> silu/leaky relu block"""
-
-    def __init__(
-        self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
-    ):
-        super().__init__()
-        # same padding
-        pad = (ksize - 1) // 2
-        self.conv = nn.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size=ksize,
-            stride=stride,
-            padding=pad,
-            groups=groups,
-            bias=bias,
-        )
-        self.bn = nn.BatchNorm2d(out_channels)
-        self.act = get_activation(act, inplace=True)
-
-    def forward(self, x):
-        return self.act(self.bn(self.conv(x)))
-
-    def fuseforward(self, x):
-        return self.act(self.conv(x))
-
-
-class DWConv(nn.Module):
-    """Depthwise Conv + Conv"""
-
-    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
-        super().__init__()
-        self.dconv = BaseConv(
-            in_channels,
-            in_channels,
-            ksize=ksize,
-            stride=stride,
-            groups=in_channels,
-            act=act,
-        )
-        self.pconv = BaseConv(
-            in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
-        )
-
-    def forward(self, x):
-        x = self.dconv(x)
-        return self.pconv(x)
-
-
-class Bottleneck(nn.Module):
-    # Standard bottleneck
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        shortcut=True,
-        expansion=0.5,
-        depthwise=False,
-        act="silu",
-    ):
-        super().__init__()
-        hidden_channels = int(out_channels * expansion)
-        Conv = DWConv if depthwise else BaseConv
-        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
-        self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
-        self.use_add = shortcut and in_channels == out_channels
-
-    def forward(self, x):
-        y = self.conv2(self.conv1(x))
-        if self.use_add:
-            y = y + x
-        return y
-
-
-class ResLayer(nn.Module):
-    "Residual layer with `in_channels` inputs."
-
-    def __init__(self, in_channels: int):
-        super().__init__()
-        mid_channels = in_channels // 2
-        self.layer1 = BaseConv(
-            in_channels, mid_channels, ksize=1, stride=1, act="lrelu"
-        )
-        self.layer2 = BaseConv(
-            mid_channels, in_channels, ksize=3, stride=1, act="lrelu"
-        )
-
-    def forward(self, x):
-        out = self.layer2(self.layer1(x))
-        return x + out
-
-
-class SPPBottleneck(nn.Module):
-    """Spatial pyramid pooling layer used in YOLOv3-SPP"""
-
-    def __init__(
-        self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"
-    ):
-        super().__init__()
-        hidden_channels = in_channels // 2
-        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation)
-        self.m = nn.ModuleList(
-            [
-                nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2)
-                for ks in kernel_sizes
-            ]
-        )
-        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
-        self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation)
-
-    def forward(self, x):
-        x = self.conv1(x)
-        x = torch.cat([x] + [m(x) for m in self.m], dim=1)
-        x = self.conv2(x)
-        return x
-
-
-class CSPLayer(nn.Module):
-    """C3 in yolov5, CSP Bottleneck with 3 convolutions"""
-
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        n=1,
-        shortcut=True,
-        expansion=0.5,
-        depthwise=False,
-        act="silu",
-    ):
-        """
-        Args:
-            in_channels (int): input channels.
-            out_channels (int): output channels.
-            n (int): number of Bottlenecks. Default value: 1.
-        """
-        # ch_in, ch_out, number, shortcut, groups, expansion
-        super().__init__()
-        hidden_channels = int(out_channels * expansion)  # hidden channels
-        self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
-        self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
-        self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
-        module_list = [
-            Bottleneck(
-                hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
-            )
-            for _ in range(n)
-        ]
-        self.m = nn.Sequential(*module_list)
-
-    def forward(self, x):
-        x_1 = self.conv1(x)
-        x_2 = self.conv2(x)
-        x_1 = self.m(x_1)
-        x = torch.cat((x_1, x_2), dim=1)
-        return self.conv3(x)
-
-
-class Focus(nn.Module):
-    """Focus width and height information into channel space."""
-
-    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
-        super().__init__()
-        self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
-
-    def forward(self, x):
-        # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
-        patch_top_left = x[..., ::2, ::2]
-        patch_top_right = x[..., ::2, 1::2]
-        patch_bot_left = x[..., 1::2, ::2]
-        patch_bot_right = x[..., 1::2, 1::2]
-        x = torch.cat(
-            (
-                patch_top_left,
-                patch_bot_left,
-                patch_top_right,
-                patch_bot_right,
-            ),
-            dim=1,
-        )
-        return self.conv(x)
yolox/yolo_fpn.py DELETED
@@ -1,87 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-
-import torch
-import torch.nn as nn
-
-from .darknet import Darknet
-from .network_blocks import BaseConv
-
-
-class YOLOFPN(nn.Module):
-    """
-    YOLOFPN module. Darknet 53 is the default backbone of this model.
-    """
-
-    def __init__(
-        self,
-        depth=53,
-        in_features=["dark3", "dark4", "dark5"],
-    ):
-        super().__init__()
-
-        self.backbone = Darknet(depth)
-        self.in_features = in_features
-
-        # out 1
-        self.out1_cbl = self._make_cbl(512, 256, 1)
-        self.out1 = self._make_embedding([256, 512], 512 + 256)
-
-        # out 2
-        self.out2_cbl = self._make_cbl(256, 128, 1)
-        self.out2 = self._make_embedding([128, 256], 256 + 128)
-
-        # upsample
-        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
-
-    def _make_cbl(self, _in, _out, ks):
-        return BaseConv(_in, _out, ks, stride=1, act="lrelu")
-
-    def _make_embedding(self, filters_list, in_filters):
-        m = nn.Sequential(
-            *[
-                self._make_cbl(in_filters, filters_list[0], 1),
-                self._make_cbl(filters_list[0], filters_list[1], 3),
-                self._make_cbl(filters_list[1], filters_list[0], 1),
-                self._make_cbl(filters_list[0], filters_list[1], 3),
-                self._make_cbl(filters_list[1], filters_list[0], 1),
-            ]
-        )
-        return m
-
-    def load_pretrained_model(self, filename="./weights/darknet53.mix.pth"):
-        with open(filename, "rb") as f:
-            state_dict = torch.load(f, map_location="cpu")
-        print("loading pretrained weights...")
-        self.backbone.load_state_dict(state_dict)
-
-    def forward(self, inputs):
-        """
-        Args:
-            inputs (Tensor): input image.
-
-        Returns:
-            Tuple[Tensor]: FPN output features..
-        """
-        # backbone
-        out_features = self.backbone(inputs)
-        x2, x1, x0 = [out_features[f] for f in self.in_features]
-
-        # yolo branch 1
-        x1_in = self.out1_cbl(x0)
-        x1_in = self.upsample(x1_in)
-        x1_in = torch.cat([x1_in, x1], 1)
-        out_dark4 = self.out1(x1_in)
-
-        # yolo branch 2
-        x2_in = self.out2_cbl(out_dark4)
-        x2_in = self.upsample(x2_in)
-        x2_in = torch.cat([x2_in, x2], 1)
-        out_dark3 = self.out2(x2_in)
-
-        outputs = (out_dark3, out_dark4, x0)
-        return outputs
yolox/yolo_head.py DELETED
@@ -1,238 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-
-import torch
-import torch.nn as nn
-from .network_blocks import BaseConv, DWConv
-
-
-_TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
-
-
-def meshgrid(*tensors):
-    """
-    Copied from YOLOX/yolox/utils/compat.py
-    """
-    if _TORCH_VER >= [1, 10]:
-        return torch.meshgrid(*tensors, indexing="ij")
-    else:
-        return torch.meshgrid(*tensors)
-
-
-def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
-    """
-    Copied from YOLOX/yolox/utils/boxes.py
-    """
-    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
-        raise IndexError
-
-    if xyxy:
-        tl = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
-        br = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
-        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
-        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
-    else:
-        tl = torch.max(
-            (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2),
-            (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2),
-        )
-        br = torch.min(
-            (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2),
-            (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2),
-        )
-
-        area_a = torch.prod(bboxes_a[:, 2:], 1)
-        area_b = torch.prod(bboxes_b[:, 2:], 1)
-    en = (tl < br).type(tl.type()).prod(dim=2)
-    area_i = torch.prod(br - tl, 2) * en  # * ((tl < br).all())
-    return area_i / (area_a[:, None] + area_b - area_i)
-
-
-class YOLOXHead(nn.Module):
-    def __init__(
-        self,
-        num_classes,
-        width=1.0,
-        strides=[8, 16, 32],
-        in_channels=[256, 512, 1024],
-        act="silu",
-        depthwise=False,
-    ):
-        """
-        Args:
-            act (str): activation type of conv. Defalut value: "silu".
-            depthwise (bool): whether apply depthwise conv in conv branch. Defalut value: False.
-        """
-        super().__init__()
-
-        self.num_classes = num_classes
-        self.decode_in_inference = True  # for deploy, set to False
-
-        self.cls_convs = nn.ModuleList()
-        self.reg_convs = nn.ModuleList()
-        self.cls_preds = nn.ModuleList()
-        self.reg_preds = nn.ModuleList()
-        self.obj_preds = nn.ModuleList()
-        self.stems = nn.ModuleList()
-        Conv = DWConv if depthwise else BaseConv
-
-        for i in range(len(in_channels)):
-            self.stems.append(
-                BaseConv(
-                    in_channels=int(in_channels[i] * width),
-                    out_channels=int(256 * width),
-                    ksize=1,
-                    stride=1,
-                    act=act,
-                )
-            )
-            self.cls_convs.append(
-                nn.Sequential(
-                    *[
-                        Conv(
-                            in_channels=int(256 * width),
-                            out_channels=int(256 * width),
-                            ksize=3,
-                            stride=1,
-                            act=act,
-                        ),
-                        Conv(
-                            in_channels=int(256 * width),
-                            out_channels=int(256 * width),
-                            ksize=3,
-                            stride=1,
-                            act=act,
-                        ),
-                    ]
-                )
-            )
-            self.reg_convs.append(
-                nn.Sequential(
-                    *[
-                        Conv(
-                            in_channels=int(256 * width),
-                            out_channels=int(256 * width),
-                            ksize=3,
-                            stride=1,
-                            act=act,
-                        ),
-                        Conv(
-                            in_channels=int(256 * width),
-                            out_channels=int(256 * width),
-                            ksize=3,
-                            stride=1,
-                            act=act,
-                        ),
-                    ]
-                )
-            )
-            self.cls_preds.append(
-                nn.Conv2d(
-                    in_channels=int(256 * width),
-                    out_channels=self.num_classes,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                )
-            )
-            self.reg_preds.append(
-                nn.Conv2d(
-                    in_channels=int(256 * width),
-                    out_channels=4,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                )
-            )
-            self.obj_preds.append(
-                nn.Conv2d(
-                    in_channels=int(256 * width),
-                    out_channels=1,
-                    kernel_size=1,
-                    stride=1,
-                    padding=0,
-                )
-            )
-
-        self.use_l1 = False
-        self.l1_loss = nn.L1Loss(reduction="none")
-        self.bcewithlog_loss = nn.BCEWithLogitsLoss(reduction="none")
-        self.iou_loss = None
-        self.strides = strides
-        self.grids = [torch.zeros(1)] * len(in_channels)
-
-    def forward(self, xin, labels=None, imgs=None):
-        outputs = []
-        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
-            zip(self.cls_convs, self.reg_convs, self.strides, xin)
-        ):
-            x = self.stems[k](x)
-            cls_x = x
-            reg_x = x
-
-            cls_feat = cls_conv(cls_x)
-            cls_output = self.cls_preds[k](cls_feat)
-
-            reg_feat = reg_conv(reg_x)
-            reg_output = self.reg_preds[k](reg_feat)
-            obj_output = self.obj_preds[k](reg_feat)
-
-            output = torch.cat(
-                [reg_output, obj_output.sigmoid(), cls_output.sigmoid()], 1
-            )
-
-            outputs.append(output)
-
-        self.hw = [x.shape[-2:] for x in outputs]
-        # [batch, n_anchors_all, 85]
-        outputs = torch.cat(
-            [x.flatten(start_dim=2) for x in outputs], dim=2
-        ).permute(0, 2, 1)
-        if self.decode_in_inference:
-            return self.decode_outputs(outputs, dtype=xin[0].type())
-        else:
-            return outputs
-
-    def get_output_and_grid(self, output, k, stride, dtype):
-        grid = self.grids[k]
-
-        batch_size = output.shape[0]
-        n_ch = 5 + self.num_classes
-        hsize, wsize = output.shape[-2:]
-        if grid.shape[2:4] != output.shape[2:4]:
-            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
-            grid = torch.stack((xv, yv), 2).view(1, 1, hsize, wsize, 2).type(dtype)
-            self.grids[k] = grid
-
-        output = output.view(batch_size, 1, n_ch, hsize, wsize)
-        output = output.permute(0, 1, 3, 4, 2).reshape(
-            batch_size, hsize * wsize, -1
-        )
-        grid = grid.view(1, -1, 2)
-        output[..., :2] = (output[..., :2] + grid) * stride
-        output[..., 2:4] = torch.exp(output[..., 2:4]) * stride
-        return output, grid
-
-    def decode_outputs(self, outputs, dtype):
-        grids = []
-        strides = []
-        for (hsize, wsize), stride in zip(self.hw, self.strides):
-            yv, xv = meshgrid([torch.arange(hsize), torch.arange(wsize)])
-            grid = torch.stack((xv, yv), 2).view(1, -1, 2)
-            grids.append(grid)
-            shape = grid.shape[:2]
-            strides.append(torch.full((*shape, 1), stride))
-
-        grids = torch.cat(grids, dim=1).type(dtype)
-        strides = torch.cat(strides, dim=1).type(dtype)
-
-        outputs = torch.cat([
-            (outputs[..., 0:2] + grids) * strides,
-            torch.exp(outputs[..., 2:4]) * strides,
-            outputs[..., 4:]
-        ], dim=-1)
-        return outputs
yolox/yolo_pafpn.py DELETED
@@ -1,119 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-
-import torch
-import torch.nn as nn
-
-from .darknet import CSPDarknet
-from .network_blocks import BaseConv, CSPLayer, DWConv
-
-
-class YOLOPAFPN(nn.Module):
-    """
-    YOLOv3 model. Darknet 53 is the default backbone of this model.
-    """
-
-    def __init__(
-        self,
-        depth=1.0,
-        width=1.0,
-        in_features=("dark3", "dark4", "dark5"),
-        in_channels=[256, 512, 1024],
-        depthwise=False,
-        act="silu",
-    ):
-        super().__init__()
-        self.backbone = CSPDarknet(depth, width, depthwise=depthwise, act=act)
-        self.in_features = in_features
-        self.in_channels = in_channels
-        Conv = DWConv if depthwise else BaseConv
-
-        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
-        self.lateral_conv0 = BaseConv(
-            int(in_channels[2] * width), int(in_channels[1] * width), 1, 1, act=act
-        )
-        self.C3_p4 = CSPLayer(
-            int(2 * in_channels[1] * width),
-            int(in_channels[1] * width),
-            round(3 * depth),
-            False,
-            depthwise=depthwise,
-            act=act,
-        )  # cat
-
-        self.reduce_conv1 = BaseConv(
-            int(in_channels[1] * width), int(in_channels[0] * width), 1, 1, act=act
-        )
-        self.C3_p3 = CSPLayer(
-            int(2 * in_channels[0] * width),
-            int(in_channels[0] * width),
-            round(3 * depth),
-            False,
-            depthwise=depthwise,
-            act=act,
-        )
-
-        # bottom-up conv
-        self.bu_conv2 = Conv(
-            int(in_channels[0] * width), int(in_channels[0] * width), 3, 2, act=act
-        )
-        self.C3_n3 = CSPLayer(
-            int(2 * in_channels[0] * width),
-            int(in_channels[1] * width),
-            round(3 * depth),
-            False,
-            depthwise=depthwise,
-            act=act,
-        )
-
-        # bottom-up conv
-        self.bu_conv1 = Conv(
-            int(in_channels[1] * width), int(in_channels[1] * width), 3, 2, act=act
-        )
-        self.C3_n4 = CSPLayer(
-            int(2 * in_channels[1] * width),
-            int(in_channels[2] * width),
-            round(3 * depth),
-            False,
-            depthwise=depthwise,
-            act=act,
-        )
-
-    def forward(self, input):
-        """
-        Args:
-            inputs: input images.
-
-        Returns:
-            Tuple[Tensor]: FPN feature.
-        """
-
-        # backbone
-        out_features = self.backbone(input)
-        features = [out_features[f] for f in self.in_features]
-        [x2, x1, x0] = features
-
-        fpn_out0 = self.lateral_conv0(x0)  # 1024->512/32
-        f_out0 = self.upsample(fpn_out0)  # 512/16
-        f_out0 = torch.cat([f_out0, x1], 1)  # 512->1024/16
-        f_out0 = self.C3_p4(f_out0)  # 1024->512/16
-
-        fpn_out1 = self.reduce_conv1(f_out0)  # 512->256/16
-        f_out1 = self.upsample(fpn_out1)  # 256/8
-        f_out1 = torch.cat([f_out1, x2], 1)  # 256->512/8
-        pan_out2 = self.C3_p3(f_out1)  # 512->256/8
-
-        p_out1 = self.bu_conv2(pan_out2)  # 256->256/16
-        p_out1 = torch.cat([p_out1, fpn_out1], 1)  # 256->512/16
-        pan_out1 = self.C3_n3(p_out1)  # 512->512/16
-
-        p_out0 = self.bu_conv1(pan_out1)  # 512->512/32
-        p_out0 = torch.cat([p_out0, fpn_out0], 1)  # 512->1024/32
-        pan_out0 = self.C3_n4(p_out0)  # 1024->1024/32
-
-        outputs = (pan_out2, pan_out1, pan_out0)
-        return outputs
yolox/yolox.py DELETED
@@ -1,35 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-# Copyright (c) Megvii Inc. All rights reserved.
-
-import torch.nn as nn
-
-from .yolo_head import YOLOXHead
-from .yolo_pafpn import YOLOPAFPN
-
-
-class YOLOX(nn.Module):
-    """
-    YOLOX model module. The module list is defined by create_yolov3_modules function.
-    The network returns loss values from three YOLO layers during training
-    and detection results during test.
-    """
-
-    def __init__(self, backbone=None, head=None):
-        super().__init__()
-        if backbone is None:
-            backbone = YOLOPAFPN()
-        if head is None:
-            head = YOLOXHead(80)
-
-        self.backbone = backbone
-        self.head = head
-
-    def forward(self, x, targets=None):
-        assert not self.training, "Training mode not supported, please refer to the YOLOX repo"
-        fpn_outs = self.backbone(x)
-        outputs = self.head(fpn_outs)
-        return outputs