File size: 3,700 Bytes
9954323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Hex color codes used by plot_sample to draw each class's boxes and label
# backgrounds; colors are assigned to classes in list order.
COLORS = [
    "#003EFF",
    "#FF8F00",
    "#079700",
    "#A123FF",
    "#87CEEB",
    "#FF5733",
    "#C70039",
    "#900C3F",
    "#581845",
    "#11998E",
]


def reformat_for_plotting(labels, bboxes, scores, shape, num_classes):
    """
    Convert predictions into per-class pixel boxes ready for plotting.

    The x columns (0 and 2) of ``bboxes`` are multiplied by ``shape[1]`` and
    the y columns (1 and 3) by ``shape[0]``. The result is cast to ``int`` and
    columns 2 and 3 are then turned into width and height by subtracting
    columns 0 and 1. The input array is not modified.

    Args:
        labels (np.ndarray): Class index of each prediction.
        bboxes (np.ndarray): Boxes, one row per prediction, with columns
            [x_min, y_min, x_max, y_max].
        scores (np.ndarray): Confidence score of each prediction.
        shape (tuple): Image shape; index 0 scales y, index 1 scales x.
        num_classes (int): Number of classes to split the predictions into.

    Returns:
        list[np.ndarray]: Per-class boxes as [x_min, y_min, width, height].
        list[np.ndarray]: Per-class confidence scores.
    """
    # Work on a copy so the caller's array is left untouched.
    pixel_boxes = bboxes.copy()

    # Scale x coordinates by shape[1] and y coordinates by shape[0].
    pixel_boxes[:, [0, 2]] *= shape[1]
    pixel_boxes[:, [1, 3]] *= shape[0]
    pixel_boxes = pixel_boxes.astype(int)

    # (x_max, y_max) -> (width, height), computed on the integer coordinates.
    pixel_boxes[:, 2:4] -= pixel_boxes[:, 0:2]

    boxes_per_class = []
    for class_id in range(num_classes):
        boxes_per_class.append(pixel_boxes[labels == class_id])

    scores_per_class = []
    for class_id in range(num_classes):
        scores_per_class.append(scores[labels == class_id])

    return boxes_per_class, scores_per_class


def plot_sample(img, boxes_list, confs_list, labels):
    """
    Plots an image with bounding boxes.
    Coordinates are expected in format [x_min, y_min, width, height].

    Each class is drawn in its own color taken from COLORS; when there are
    more classes than colors, the palette is reused cyclically so that no
    class is silently skipped.

    Args:
        img (numpy.ndarray): The input image to be plotted. Its first two
            dimensions are used as (height, width), so both 2D and
            channel-last images are accepted.
        boxes_list (list[np.ndarray]): List of box bounding boxes per class.
        confs_list (list[np.ndarray]): List of confidence scores per class.
        labels (list): List of class labels, in the same order as
            boxes_list and confs_list.
    """
    plt.imshow(img, cmap="gray")
    plt.axis(False)

    # Only height and width are needed; slicing keeps 2D images working.
    h, w = img.shape[:2]
    ax = plt.gca()

    per_class = zip(boxes_list, confs_list, labels)
    for class_idx, (boxes, confs, label) in enumerate(per_class):
        # Wrap around the palette instead of dropping classes beyond its size.
        col = COLORS[class_idx % len(COLORS)]

        for box_idx, box in enumerate(boxes):
            # Better display around boundaries: keep the outline (drawn with
            # linewidth 2) from being cut off at the image border.
            box = np.copy(box)
            box[:2] = np.clip(box[:2], 2, max(h, w))
            box[2] = min(box[2], w - 2 - box[0])
            box[3] = min(box[3], h - 2 - box[1])

            rect = Rectangle(
                (box[0], box[1]),
                box[2],
                box[3],
                linewidth=2,
                facecolor="none",
                edgecolor=col,
            )
            ax.add_patch(rect)

            # Add class and index label with proper alignment
            ax.text(
                box[0],
                box[1],
                f"{label}_{box_idx}   conf={confs[box_idx]:.3f}",
                color="white",
                fontsize=8,
                bbox=dict(
                    facecolor=col, alpha=1, edgecolor=col, pad=0, linewidth=2
                ),
                verticalalignment="bottom",
                horizontalalignment="left",
            )


def postprocess_preds_page_element(preds, thresholds_per_class, class_labels):
    """
    Post process predictions for the page element task.
    - Applies thresholding

    A prediction is kept only when its score is strictly greater than the
    threshold configured for its class.

    Args:
        preds (dict): Predictions. Keys are "scores", "boxes", "labels";
            values are tensors supporting ``.detach().cpu().numpy()``.
        thresholds_per_class (dict): Thresholds per class, keyed by the
            class names found in ``class_labels``.
        class_labels (list): List of class labels, indexed by the integer
            values in ``preds["labels"]``.

    Returns:
        labels (numpy.ndarray): Array of labels.
        bboxes (numpy.ndarray): Array of bounding boxes.
        scores (numpy.ndarray): Array of scores.

    Raises:
        KeyError: If a predicted class has no entry in thresholds_per_class.
        IndexError: If a predicted label is out of range for class_labels.
    """
    # detach() makes the conversion work even if the tensors track gradients.
    labels = preds["labels"].detach().cpu().numpy()
    boxes = preds["boxes"].detach().cpu().numpy()
    scores = preds["scores"].detach().cpu().numpy()

    # Threshold per class: look up each prediction's threshold by class name.
    thresholds = np.array(
        [thresholds_per_class[class_labels[int(x)]] for x in labels]
    )

    # Compute the mask once and apply it to all three arrays.
    keep = scores > thresholds
    return labels[keep], boxes[keep], scores[keep]