# coding=utf-8
# Copyright 2021 The Deeplab2 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the Segmentation and Tracking Quality (STQ) metric."""

import collections
from typing import Any, Dict, MutableMapping, Sequence, Text

import numpy as np
import tensorflow as tf

def _update_dict_stats(stat_dict: MutableMapping[int, tf.Tensor],
                       id_array: tf.Tensor):
  """Updates a given dict with corresponding counts."""
  ids, _, counts = tf.unique_with_counts(id_array)
  for idx, count in zip(ids.numpy(), counts):
    if idx in stat_dict:
      stat_dict[idx] += count
    else:
      stat_dict[idx] = count
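# Example usage of `_update_dict_stats` (hypothetical values): with an empty
# stat_dict and id_array = tf.constant([7, 7, 9]), the dict afterwards holds
# {7: 2, 9: 1}; a second call with tf.constant([9, 9]) updates it to
# {7: 2, 9: 3}.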

class STQuality(object):
  """Metric class for the Segmentation and Tracking Quality (STQ).

  The metric computes the geometric mean of two terms.
  - Association Quality: This term measures the quality of the track ID
      assignment for `thing` classes. It is formulated as a weighted IoU
      measure.
  - Segmentation Quality: This term measures the semantic segmentation
      quality. The standard class IoU measure is used for this.

  Example usage:

  stq_obj = segmentation_tracking_quality.STQuality(num_classes, things_list,
    ignore_label, max_instances_per_category, offset)
  stq_obj.update_state(y_true_1, y_pred_1)
  stq_obj.update_state(y_true_2, y_pred_2)
  ...
  result = stq_obj.result()
  """
  def __init__(self,
               num_classes: int,
               things_list: Sequence[int],
               ignore_label: int,
               max_instances_per_category: int,
               offset: int,
               name='stq'):
    """Initialization of the STQ metric.

    Args:
      num_classes: Number of classes in the dataset as an integer.
      things_list: A sequence of class ids that belong to `things`.
      ignore_label: The class id to be ignored in evaluation as an integer or
        integer tensor.
      max_instances_per_category: The maximum number of instances for each
        class as an integer or integer tensor.
      offset: The maximum number of unique labels as an integer or integer
        tensor.
      name: An optional name. (default: 'stq')
    """
    self._name = name
    self._num_classes = num_classes
    self._ignore_label = ignore_label
    self._things_list = things_list
    self._max_instances_per_category = max_instances_per_category

    if ignore_label >= num_classes:
      self._confusion_matrix_size = num_classes + 1
      self._include_indices = np.arange(self._num_classes)
    else:
      self._confusion_matrix_size = num_classes
      self._include_indices = np.array(
          [i for i in range(num_classes) if i != self._ignore_label])

    self._iou_confusion_matrix_per_sequence = collections.OrderedDict()
    self._predictions = collections.OrderedDict()
    self._ground_truth = collections.OrderedDict()
    self._intersections = collections.OrderedDict()
    self._sequence_length = collections.OrderedDict()
    self._offset = offset
    lower_bound = num_classes * max_instances_per_category
    if offset < lower_bound:
      raise ValueError('The provided offset %d is too small. No guarantees '
                       'about the correctness of the results can be made. '
                       'Please choose an offset that is higher than '
                       'num_classes * max_instances_per_category = %d' %
                       (offset, lower_bound))
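    # Example (hypothetical values): with num_classes=19 and
    # max_instances_per_category=1000, any offset >= 19 * 1000 = 19000 is
    # valid, e.g. offset=2**16 (= 65536).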

  def update_state(self, y_true: tf.Tensor, y_pred: tf.Tensor,
                   sequence_id=0):
    """Accumulates the segmentation and tracking quality statistics.

    Args:
      y_true: The ground-truth panoptic label map for a particular video frame
        (defined as semantic_map * max_instances_per_category + instance_map).
      y_pred: The predicted panoptic label map for a particular video frame
        (defined as semantic_map * max_instances_per_category + instance_map).
      sequence_id: The optional ID of the sequence the frames belong to. When
        no sequence is given, all frames are considered to belong to the same
        sequence (default: 0).
    """
    y_true = tf.cast(y_true, dtype=tf.int64)
    y_pred = tf.cast(y_pred, dtype=tf.int64)
    semantic_label = y_true // self._max_instances_per_category
    semantic_prediction = y_pred // self._max_instances_per_category
    # Check if the ignore value is outside the range [0, num_classes]. If yes,
    # map `_ignore_label` to `_num_classes`, so it can be used to create the
    # confusion matrix.
    if self._ignore_label > self._num_classes:
      semantic_label = tf.where(
          tf.not_equal(semantic_label, self._ignore_label), semantic_label,
          self._num_classes)
      semantic_prediction = tf.where(
          tf.not_equal(semantic_prediction, self._ignore_label),
          semantic_prediction, self._num_classes)
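    # For example (hypothetical values): with num_classes=19 and
    # ignore_label=255, pixels labeled 255 are remapped to class id 19, which
    # indexes the extra row/column of the confusion matrix.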
    if sequence_id in self._iou_confusion_matrix_per_sequence:
      self._iou_confusion_matrix_per_sequence[sequence_id] += (
          tf.math.confusion_matrix(
              tf.reshape(semantic_label, [-1]),
              tf.reshape(semantic_prediction, [-1]),
              self._confusion_matrix_size,
              dtype=tf.int64))
      self._sequence_length[sequence_id] += 1
    else:
      self._iou_confusion_matrix_per_sequence[sequence_id] = (
          tf.math.confusion_matrix(
              tf.reshape(semantic_label, [-1]),
              tf.reshape(semantic_prediction, [-1]),
              self._confusion_matrix_size,
              dtype=tf.int64))
      self._predictions[sequence_id] = {}
      self._ground_truth[sequence_id] = {}
      self._intersections[sequence_id] = {}
      self._sequence_length[sequence_id] = 1

    instance_label = y_true % self._max_instances_per_category

    label_mask = tf.zeros_like(semantic_label, dtype=tf.bool)
    prediction_mask = tf.zeros_like(semantic_prediction, dtype=tf.bool)
    for things_class_id in self._things_list:
      label_mask = tf.logical_or(label_mask,
                                 tf.equal(semantic_label, things_class_id))
      prediction_mask = tf.logical_or(
          prediction_mask, tf.equal(semantic_prediction, things_class_id))

    # Select the `crowd` region of the current class. This region is encoded
    # with instance id `0`.
    is_crowd = tf.logical_and(tf.equal(instance_label, 0), label_mask)
    # Select the non-crowd region of the corresponding class, as the `crowd`
    # region is ignored for the tracking term.
    label_mask = tf.logical_and(label_mask, tf.logical_not(is_crowd))
    # Do not punish id assignment for regions that are annotated as `crowd`
    # in the ground-truth.
    prediction_mask = tf.logical_and(prediction_mask, tf.logical_not(is_crowd))

    seq_preds = self._predictions[sequence_id]
    seq_gts = self._ground_truth[sequence_id]
    seq_intersects = self._intersections[sequence_id]

    # Compute and update areas of ground-truth, predictions and intersections.
    _update_dict_stats(seq_preds, y_pred[prediction_mask])
    _update_dict_stats(seq_gts, y_true[label_mask])

    non_crowd_intersection = tf.logical_and(label_mask, prediction_mask)
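    # Encode each (ground-truth tube, prediction tube) pair as a single id:
    # gt_id * offset + pred_id. Because `offset` is at least
    # num_classes * max_instances_per_category (checked in __init__), no two
    # pairs collide.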
    intersection_ids = (
        y_true[non_crowd_intersection] * self._offset +
        y_pred[non_crowd_intersection])
    _update_dict_stats(seq_intersects, intersection_ids)

  def result(self) -> Dict[Text, Any]:
    """Computes the segmentation and tracking quality.

    Returns:
      A dictionary containing:
        - 'STQ': The total STQ score.
        - 'AQ': The total association quality (AQ) score.
        - 'IoU': The total mean IoU.
        - 'STQ_per_seq': A list of the STQ score per sequence.
        - 'AQ_per_seq': A list of the AQ score per sequence.
        - 'IoU_per_seq': A list of mean IoU per sequence.
        - 'ID_per_seq': A list of sequence ids to map list index to sequence.
        - 'Length_per_seq': A list of the length of each sequence.
    """
    # Compute association quality (AQ).
    num_tubes_per_seq = [0] * len(self._ground_truth)
    aq_per_seq = [0] * len(self._ground_truth)
    iou_per_seq = [0] * len(self._ground_truth)
    id_per_seq = [''] * len(self._ground_truth)
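    # The AQ of a sequence averages, over its ground-truth tubes g, the term
    #   1/|g| * sum_p TPA(p, g) * IoU(p, g),
    # where p ranges over prediction tubes, TPA (true positive associations)
    # is the overlap |p & g|, and IoU(p, g) = TPA / (TPA + FPA + FNA).
    # For example (hypothetical values): a ground-truth tube of 100 pixels
    # overlapped in 80 pixels by a single prediction tube of 100 pixels
    # contributes 1/100 * 80 * (80 / 120) ~= 0.53.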
    for index, sequence_id in enumerate(self._ground_truth):
      outer_sum = 0.0
      predictions = self._predictions[sequence_id]
      ground_truth = self._ground_truth[sequence_id]
      intersections = self._intersections[sequence_id]
      num_tubes_per_seq[index] = len(ground_truth)
      id_per_seq[index] = sequence_id

      for gt_id, gt_size in ground_truth.items():
        inner_sum = 0.0
        for pr_id, pr_size in predictions.items():
          tpa_key = self._offset * gt_id + pr_id
          if tpa_key in intersections:
            tpa = intersections[tpa_key].numpy()
            fpa = pr_size.numpy() - tpa
            fna = gt_size.numpy() - tpa
            inner_sum += tpa * (tpa / (tpa + fpa + fna))

        outer_sum += 1.0 / gt_size.numpy() * inner_sum
      aq_per_seq[index] = outer_sum

    aq_mean = np.sum(aq_per_seq) / np.maximum(np.sum(num_tubes_per_seq), 1e-15)
    aq_per_seq = aq_per_seq / np.maximum(num_tubes_per_seq, 1e-15)

    # Compute IoU scores.
    # The rows correspond to ground-truth and the columns to predictions.
    # Remove fp from confusion matrix for the void/ignore class.
    total_confusion = np.zeros(
        (self._confusion_matrix_size, self._confusion_matrix_size),
        dtype=np.int64)
    for index, confusion in enumerate(
        self._iou_confusion_matrix_per_sequence.values()):
      confusion = confusion.numpy()
      removal_matrix = np.zeros_like(confusion)
      removal_matrix[self._include_indices, :] = 1.0
      confusion *= removal_matrix
      total_confusion += confusion

      # `intersections` corresponds to true positives.
      intersections = confusion.diagonal()
      fps = confusion.sum(axis=0) - intersections
      fns = confusion.sum(axis=1) - intersections
      unions = intersections + fps + fns

      num_classes = np.count_nonzero(unions)
      ious = (intersections.astype(np.double) /
              np.maximum(unions, 1e-15).astype(np.double))
      iou_per_seq[index] = np.sum(ious) / num_classes

    # `intersections` corresponds to true positives.
    intersections = total_confusion.diagonal()
    fps = total_confusion.sum(axis=0) - intersections
    fns = total_confusion.sum(axis=1) - intersections
    unions = intersections + fps + fns

    num_classes = np.count_nonzero(unions)
    ious = (intersections.astype(np.double) /
            np.maximum(unions, 1e-15).astype(np.double))
    iou_mean = np.sum(ious) / num_classes
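    # STQ is the geometric mean of AQ and IoU. For example (hypothetical
    # values): AQ = 0.64 and IoU = 0.49 yield STQ = sqrt(0.64 * 0.49) = 0.56.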
    st_quality = np.sqrt(aq_mean * iou_mean)
    st_quality_per_seq = np.sqrt(aq_per_seq * iou_per_seq)
    return {'STQ': st_quality,
            'AQ': aq_mean,
            'IoU': float(iou_mean),
            'STQ_per_seq': st_quality_per_seq,
            'AQ_per_seq': aq_per_seq,
            'IoU_per_seq': iou_per_seq,
            'ID_per_seq': id_per_seq,
            'Length_per_seq': list(self._sequence_length.values()),
           }

  def reset_states(self):
    """Resets all states that accumulated data."""
    self._iou_confusion_matrix_per_sequence = collections.OrderedDict()
    self._predictions = collections.OrderedDict()
    self._ground_truth = collections.OrderedDict()
    self._intersections = collections.OrderedDict()
    self._sequence_length = collections.OrderedDict()
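
if __name__ == '__main__':
  # Minimal smoke-test sketch (hypothetical values, not part of the original
  # module): two semantic classes (0 = stuff, 1 = things), ignore_label=2,
  # and at most 10 instances per class, so panoptic ids are encoded as
  # semantic_id * 10 + instance_id.
  stq = STQuality(num_classes=2, things_list=[1], ignore_label=2,
                  max_instances_per_category=10, offset=100)
  gt = tf.constant([[11, 11], [0, 0]], dtype=tf.int64)   # one `thing` tube
  pred = tf.constant([[11, 0], [0, 0]], dtype=tf.int64)  # partial overlap
  stq.update_state(gt, pred, sequence_id=0)
  print(stq.result())  # dict with 'STQ', 'AQ', 'IoU', per-sequence lists, ...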