Spaces:
Runtime error
Runtime error
| # coding=utf-8 | |
| # Copyright 2021 The Deeplab2 Authors. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Implementation of Depth-aware Segmentation and Tracking Quality (DSTQ) metric.""" | |
| import collections | |
| from typing import Sequence, List, Tuple | |
| import tensorflow as tf | |
| from deeplab2.evaluation import segmentation_and_tracking_quality as stq | |
class DSTQuality(stq.STQuality):
  """Metric class for Depth-aware Segmentation and Tracking Quality (DSTQ).

  This metric computes STQ and the inlier depth metric (or depth quality (DQ))
  under several thresholds. Then it returns the geometric mean of DQ's, AQ and
  IoU to get the final DSTQ, i.e.,

  DSTQ@{threshold_1} = pow(STQ ** 2 * DQ@{threshold_1}, 1/3)
  DSTQ@{threshold_2} = pow(STQ ** 2 * DQ@{threshold_2}, 1/3)
  ...
  DSTQ = pow(STQ ** 2 * DQ, 1/3)

  where DQ = pow(prod_i^n(DQ@{threshold_i}), 1/n) for n depth thresholds.

  The default choices for depth thresholds are 1.1 and 1.25, i.e.,
  max(pred/gt, gt/pred) <= 1.1 and max(pred/gt, gt/pred) <= 1.25.
  Commonly used thresholds for the inlier metrics are 1.25, 1.25**2, 1.25**3.
  These thresholds are so loose that many methods achieve > 99%.
  Therefore, we choose 1.25 and 1.1 to encourage high-precision predictions.

  Example usage:

  dstq_obj = depth_aware_segmentation_and_tracking_quality.DSTQuality(
    num_classes, things_list, ignore_label, max_instances_per_category,
    offset, depth_threshold)
  dstq_obj.update_state(y_true_1, y_pred_1, d_true_1, d_pred_1)
  dstq_obj.update_state(y_true_2, y_pred_2, d_true_2, d_pred_2)
  ...
  result = dstq_obj.result()
  """

  # Depth thresholds, in the order they were passed to the constructor.
  _depth_threshold: Tuple[float, ...] = (1.25, 1.1)
  # Maps sequence id -> number of pixels with a valid (positive) depth label.
  _depth_total_counts: collections.OrderedDict
  # One mapping per threshold (same order as `_depth_threshold`):
  # sequence id -> number of inlier pixels under that threshold.
  _depth_inlier_counts: List[collections.OrderedDict]

  def __init__(self,
               num_classes: int,
               things_list: Sequence[int],
               ignore_label: int,
               max_instances_per_category: int,
               offset: int,
               depth_threshold: Sequence[float] = (1.25, 1.1),
               name: str = 'dstq'):
    """Initialization of the DSTQ metric.

    Args:
      num_classes: Number of classes in the dataset as an integer.
      things_list: A sequence of class ids that belong to `things`.
      ignore_label: The class id to be ignored in evaluation as an integer or
        integer tensor.
      max_instances_per_category: The maximum number of instances for each class
        as an integer or integer tensor.
      offset: The maximum number of unique labels as an integer or integer
        tensor.
      depth_threshold: A non-empty tuple or list of depth thresholds for the
        depth quality. (default: (1.25, 1.1))
      name: An optional name. (default: 'dstq')

    Raises:
      TypeError: If `depth_threshold` is neither a tuple nor a list.
      ValueError: If `depth_threshold` is empty.
    """
    super().__init__(num_classes, things_list, ignore_label,
                     max_instances_per_category, offset, name)
    if not isinstance(depth_threshold, (tuple, list)):
      raise TypeError('The type of depth_threshold must be tuple or list.')
    if not depth_threshold:
      raise ValueError('depth_threshold must be non-empty.')
    self._depth_threshold = tuple(depth_threshold)
    self._reset_depth_states()

  def _reset_depth_states(self):
    """Re-creates the empty per-sequence depth accumulators."""
    self._depth_total_counts = collections.OrderedDict()
    self._depth_inlier_counts = [
        collections.OrderedDict() for _ in self._depth_threshold
    ]

  def update_state(self,
                   y_true: tf.Tensor,
                   y_pred: tf.Tensor,
                   d_true: tf.Tensor,
                   d_pred: tf.Tensor,
                   sequence_id: int = 0):
    """Accumulates the depth-aware segmentation and tracking quality statistics.

    Pixels with a positive ground-truth depth form the denominator of DQ. Only
    those of them that also have a positive predicted depth can be counted as
    inliers, so a non-positive prediction on a labelled pixel lowers DQ.

    The counts are read back with `.numpy()`, so the tensors must be eagerly
    evaluated.

    Args:
      y_true: The ground-truth panoptic label map for a particular video frame
        (defined as semantic_map * max_instances_per_category + instance_map).
      y_pred: The predicted panoptic label map for a particular video frame
        (defined as semantic_map * max_instances_per_category + instance_map).
      d_true: The ground-truth depth map for this video frame.
      d_pred: The predicted depth map for this video frame.
      sequence_id: The optional ID of the sequence the frames belong to. When no
        sequence is given, all frames are considered to belong to the same
        sequence (default: 0).
    """
    super().update_state(y_true, y_pred, sequence_id)

    # Valid depth labels contain positive values.
    label_mask = d_true > 0
    num_labelled = tf.reduce_sum(tf.cast(label_mask, tf.int32))

    # Valid depth prediction is expected to contain positive values.
    valid_mask = tf.logical_and(label_mask, d_pred > 0)
    valid_true = tf.boolean_mask(d_true, valid_mask)
    valid_pred = tf.boolean_mask(d_pred, valid_mask)
    # Symmetric relative error: max(pred / gt, gt / pred) >= 1.
    inlier_error = tf.maximum(valid_pred / valid_true,
                              valid_true / valid_pred)

    # For each threshold, count the number of inliers.
    for threshold, inlier_counts in zip(self._depth_threshold,
                                        self._depth_inlier_counts):
      num_inliers = tf.reduce_sum(tf.cast(inlier_error <= threshold, tf.int32))
      inlier_counts[sequence_id] = (
          inlier_counts.get(sequence_id, 0) + int(num_inliers.numpy()))

    # Update the total counts of the depth labels.
    self._depth_total_counts[sequence_id] = (
        self._depth_total_counts.get(sequence_id, 0) +
        int(num_labelled.numpy()))

  def result(self):
    """Computes the depth-aware segmentation and tracking quality.

    Returns:
      A dictionary containing:
        - 'STQ': The total STQ score.
        - 'AQ': The total association quality (AQ) score.
        - 'IoU': The total mean IoU.
        - 'STQ_per_seq': A list of the STQ score per sequence.
        - 'AQ_per_seq': A list of the AQ score per sequence.
        - 'IoU_per_seq': A list of mean IoU per sequence.
        - 'Id_per_seq': A list of sequence Ids to map list index to sequence.
        - 'Length_per_seq': A list of the length of each sequence.
        - 'DSTQ': The total DSTQ score.
        - 'DSTQ@thres': The total DSTQ score for threshold thres
        - 'DSTQ_per_seq@thres': A list of DSTQ score per sequence for thres.
        - 'DQ': The total DQ score.
        - 'DQ@thres': The total DQ score for threshold thres.
        - 'DQ_per_seq@thres': A list of DQ score per sequence for thres.
    """
    # Gather the results for STQ.
    stq_results = super().result()

    # Collect results for depth quality per sequence and threshold.
    num_sequences = len(self._ground_truth)
    dq_per_seq_at_threshold = {}
    dq_at_threshold = {}
    for threshold, inlier_counts in zip(self._depth_threshold,
                                        self._depth_inlier_counts):
      dq_per_seq = [0] * num_sequences
      total_count = 0
      inlier_count = 0
      # Follow the order of computing STQ by enumerating _ground_truth.
      for index, sequence_id in enumerate(self._ground_truth):
        sequence_inlier = inlier_counts[sequence_id]
        sequence_total = self._depth_total_counts[sequence_id]
        # A sequence without any valid depth label keeps a DQ of 0.
        if sequence_total > 0:
          dq_per_seq[index] = sequence_inlier / sequence_total
        total_count += sequence_total
        inlier_count += sequence_inlier
      dq_per_seq_at_threshold[threshold] = dq_per_seq
      if total_count == 0:
        dq_at_threshold[threshold] = 0
      else:
        dq_at_threshold[threshold] = inlier_count / total_count

    # Compute DQ as the geometric mean of DQ's at different thresholds.
    dq = 1
    for threshold in self._depth_threshold:
      dq *= dq_at_threshold[threshold]
    dq = dq ** (1 / len(self._depth_threshold))

    dq_results = {}
    dq_results['DQ'] = dq
    for threshold in self._depth_threshold:
      dq_results['DQ@{}'.format(threshold)] = dq_at_threshold[threshold]
      dq_results['DQ_per_seq@{}'.format(
          threshold)] = dq_per_seq_at_threshold[threshold]

    # Combine STQ and DQ to get DSTQ.
    dstq_results = {}
    dstq_results['DSTQ'] = (stq_results['STQ'] ** 2 * dq) ** (1 / 3)
    for threshold in self._depth_threshold:
      dstq_results['DSTQ@{}'.format(threshold)] = (
          stq_results['STQ'] ** 2 * dq_at_threshold[threshold]) ** (1 / 3)
      dstq_results['DSTQ_per_seq@{}'.format(threshold)] = [
          (stq_result ** 2 * dq_result) ** (1 / 3)
          for stq_result, dq_result in zip(
              stq_results['STQ_per_seq'], dq_per_seq_at_threshold[threshold])
      ]

    # Merge all the results.
    dstq_results.update(stq_results)
    dstq_results.update(dq_results)
    return dstq_results

  def reset_states(self):
    """Resets all states that accumulated data."""
    super().reset_states()
    self._reset_depth_states()