import streamlit as st
import numpy as np
import plotly.express as px
import cv2
from src.error_analysis import ErrorAnalysis, transform_gt_bbox_format
import yaml
import os
from src.confusion_matrix import ConfusionMatrix
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd

def amend_cm_df(cm_df, labels_dict):
    """Helper function to amend the index and column names for readability.

    Example: index values 0, 1, ... become "GT - person", ... and columns
    0, 1, ... become "Pred - person", ..., with a trailing "background"
    class appended to both axes.

    Args:
        cm_df (pd.DataFrame): confusion matrix with integer axis labels.
        labels_dict (dict): mapping of class index to class name.

    Returns:
        pd.DataFrame: confusion matrix with renamed axes, cast to int.
    """
    index_list = list(labels_dict.values())
    index_list.append("background")
    cm_df = cm_df.set_axis([f"GT - {elem}" for elem in index_list])
    cm_df = cm_df.set_axis([f"Pred - {elem}" for elem in index_list], axis=1)
    cm_df = cm_df.astype(int)
    return cm_df
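
# Illustrative behaviour of amend_cm_df, assuming labels_dict = {0: "person", 1: "car"}
# (the class names here are made up for the example):
#   before: index/columns are 0, 1, 2
#   after:  index   -> ["GT - person", "GT - car", "GT - background"]
#           columns -> ["Pred - person", "Pred - car", "Pred - background"]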

class ImageTool:
    """Visualisation helper that runs inference, draws bounding boxes and
    builds per-image confusion matrices for the Streamlit app."""

    def __init__(self, cfg_path="cfg/cfg.yml"):
        # initialise the error analysis object and load the annotations
        self.ea_obj = ErrorAnalysis(cfg_path)
        with open(cfg_path) as cfg_file:
            self.cfg_obj = yaml.load(cfg_file, Loader=yaml.FullLoader)
        self.inference_folder = self.ea_obj.inference_folder
        self.ea_obj.get_annots()
        self.gt_annots = self.ea_obj.gt_dict
        self.all_img = os.listdir(self.inference_folder)

        # for labels: invert the config mapping so that class index -> class name
        self.labels_dict = self.cfg_obj["error_analysis"]["labels_dict"]
        self.labels_dict = {v: k for k, v in self.labels_dict.items()}
        self.idx_base = self.cfg_obj["error_analysis"]["idx_base"]

        # for visualisation
        self.bbox_thickness = self.cfg_obj["visual_tool"]["bbox_thickness"]
        self.font_scale = self.cfg_obj["visual_tool"]["font_scale"]
        self.font_thickness = self.cfg_obj["visual_tool"]["font_thickness"]
        self.pred_colour = tuple(self.cfg_obj["visual_tool"]["pred_colour"])
        self.gt_colour = tuple(self.cfg_obj["visual_tool"]["gt_colour"])
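
    # Illustrative sketch of the cfg/cfg.yml keys read directly above; the
    # concrete values (and any keys consumed by ErrorAnalysis itself, such as
    # the inference folder) are assumptions, not taken from the repository:
    #
    #   error_analysis:
    #     labels_dict:            # class name -> class index
    #       person: 0
    #     idx_base: 0             # 1 if the class indices are 1-based
    #   visual_tool:
    #     bbox_thickness: 2
    #     font_scale: 1.0
    #     font_thickness: 2
    #     pred_colour: [255, 0, 0]   # BGR tuple used by OpenCV (blue)
    #     gt_colour: [0, 255, 0]     # BGR tuple used by OpenCV (green)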

    def show_img(self, img_fname="000000011149.jpg", show_preds=False, show_gt=False):
        """Plot a single image, optionally overlaying predictions and ground truth.

        Args:
            img_fname (str, optional): image filename. Defaults to "000000011149.jpg".
            show_preds (bool, optional): draw predicted bounding boxes. Defaults to False.
            show_gt (bool, optional): draw ground truth bounding boxes. Defaults to False.

        Returns:
            Plotly figure, or [figure, confusion matrix df, TP/FP/FN df] when
            both show_preds and show_gt are True.
        """
        img = cv2.imread(f"{self.inference_folder}{img_fname}")
        labels = {"x": "X", "y": "Y", "color": "Colour"}

        if show_preds:
            preds = self.get_preds(img_fname)
            img = self.draw_pred_bboxes(img, preds)

        if show_gt:
            gt_annots = self.get_gt_annot(img_fname)
            img = self.draw_gt_bboxes(img, gt_annots)

        # reverse the channel order (BGR -> RGB) before handing over to plotly
        fig = px.imshow(img[..., ::-1], aspect="equal", labels=labels)

        if show_gt and show_preds:
            cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)
            return [fig, cm_df, cm_tpfpfn_df]

        return fig

    def show_img_sbs(self, img_fname="000000011149.jpg"):
        """Plot the ground truth and prediction overlays side by side.

        Args:
            img_fname (str, optional): image filename. Defaults to "000000011149.jpg".

        Returns:
            list: [ground truth figure, prediction figure, confusion matrix df,
            TP/FP/FN df].
        """
        img = cv2.imread(f"{self.inference_folder}{img_fname}")
        labels = {"x": "X", "y": "Y", "color": "Colour"}
        img_pred = img.copy()
        img_gt = img.copy()

        preds = self.get_preds(img_fname)
        img_pred = self.draw_pred_bboxes(img_pred, preds)

        gt_annots = self.get_gt_annot(img_fname)
        img_gt = self.draw_gt_bboxes(img_gt, gt_annots)

        fig1 = px.imshow(img_gt[..., ::-1], aspect="equal", labels=labels)
        fig2 = px.imshow(img_pred[..., ::-1], aspect="equal", labels=labels)
        fig2.update_yaxes(visible=False)

        cm_df, cm_tpfpfn_df = self.generate_cm_one_image(preds, gt_annots)

        return [fig1, fig2, cm_df, cm_tpfpfn_df]

    def generate_cm_one_image(self, preds, gt_annots):
        """Generate the confusion matrix and TP/FP/FN counts for one image.

        Args:
            preds (np.ndarray): model predictions for the image.
            gt_annots (np.ndarray): ground truth annotations for the image.

        Returns:
            tuple: (confusion matrix DataFrame, TP/FP/FN DataFrame).
        """
        num_classes = len(list(self.cfg_obj["error_analysis"]["labels_dict"].keys()))
        idx_base = self.cfg_obj["error_analysis"]["idx_base"]
        conf_threshold, iou_threshold = (
            self.ea_obj.model.score_threshold,
            self.ea_obj.model.iou_threshold,
        )
        cm = ConfusionMatrix(
            num_classes=num_classes,
            CONF_THRESHOLD=conf_threshold,
            IOU_THRESHOLD=iou_threshold,
        )

        # shift the class indices so that they are zero-based before matching
        gt_annots[:, 0] -= idx_base
        preds[:, -1] -= idx_base

        cm.process_batch(preds, gt_annots)
        confusion_matrix_df = cm.return_as_df()
        cm.get_tpfpfn()

        cm_tpfpfn_dict = {
            "True Positive": cm.tp,
            "False Positive": cm.fp,
            "False Negative": cm.fn,
        }
        cm_tpfpfn_df = pd.DataFrame(cm_tpfpfn_dict, index=[0])
        cm_tpfpfn_df = cm_tpfpfn_df.set_axis(["Values"], axis=0)
        cm_tpfpfn_df = cm_tpfpfn_df.astype(int)

        # amend the confusion matrix axes for readability
        confusion_matrix_df = amend_cm_df(confusion_matrix_df, self.labels_dict)

        return confusion_matrix_df, cm_tpfpfn_df
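
    # Illustrative shape of the two frames returned above (class names and
    # counts are made up for the example):
    #
    #                    Pred - person  Pred - background
    #   GT - person                  3                  2
    #   GT - background              1                  0
    #
    #            True Positive  False Positive  False Negative
    #   Values               3               1               2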

    def get_preds(self, img_fname="000000011149.jpg"):
        """Run inference on one image and scale the predictions to pixel coordinates.

        Args:
            img_fname (str, optional): image filename. Defaults to "000000011149.jpg".

        Returns:
            np.ndarray: predictions with bounding box coordinates in pixels.
        """
        # run inference using the error analysis object per image
        outputs, img_shape = self.ea_obj.generate_inference(img_fname)

        # convert the bbox coordinates from normalised to pixel values;
        # image shape is [Y, X, C] (because rows are Y), so don't get confused!
        outputs[:, 0] *= img_shape[1]
        outputs[:, 1] *= img_shape[0]
        outputs[:, 2] *= img_shape[1]
        outputs[:, 3] *= img_shape[0]

        return outputs
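
    # Assumed layout of a prediction row (deduced from how the columns are used
    # in draw_pred_bboxes and generate_cm_one_image, not from documentation):
    #   [x1, y1, x2, y2, confidence, class_id]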

    def get_gt_annot(self, img_fname):
        """Fetch the ground truth annotations for one image in pixel coordinates.

        Args:
            img_fname (str): image filename.

        Returns:
            np.ndarray: ground truth annotations with bounding box coordinates
            in pixels.
        """
        ground_truth = self.gt_annots[img_fname].copy()
        img = cv2.imread(f"{self.inference_folder}{img_fname}")
        img_shape = img.shape
        ground_truth = transform_gt_bbox_format(ground_truth, img_shape, format="coco")

        # convert the bbox coordinates from normalised to pixel values;
        # image shape is [Y, X, C] (because rows are Y), so don't get confused!
        ground_truth[:, 1] *= img_shape[1]
        ground_truth[:, 2] *= img_shape[0]
        ground_truth[:, 3] *= img_shape[1]
        ground_truth[:, 4] *= img_shape[0]

        return ground_truth
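
    # Assumed layout of a ground truth row (deduced from how the columns are
    # used in draw_gt_bboxes and generate_cm_one_image):
    #   [class_id, x1, y1, x2, y2]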

    def draw_pred_bboxes(self, img_pred, preds):
        """Draw the predicted bounding boxes and class labels on an image.

        Args:
            img_pred (np.ndarray): image to draw on.
            preds (np.ndarray): predictions in pixel coordinates.

        Returns:
            np.ndarray: image with the predicted bounding boxes drawn.
        """
        for pred in preds:
            pred = pred.astype(int)
            img_pred = cv2.rectangle(
                img_pred,
                (pred[0], pred[1]),
                (pred[2], pred[3]),
                color=self.pred_colour,
                thickness=self.bbox_thickness,
            )
            img_pred = cv2.putText(
                img_pred,
                self.labels_dict[pred[5]],
                (pred[0] + 5, pred[1] + 25),
                color=self.pred_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_pred

    def draw_gt_bboxes(self, img_gt, gt_annots, **kwargs):
        """Draw the ground truth bounding boxes and class labels on an image.

        Args:
            img_gt (np.ndarray): image to draw on.
            gt_annots (np.ndarray): ground truth annotations in pixel coordinates.

        Returns:
            np.ndarray: image with the ground truth bounding boxes drawn.
        """
        for annot in gt_annots:
            annot = annot.astype(int)
            img_gt = cv2.rectangle(
                img_gt,
                (annot[1], annot[2]),
                (annot[3], annot[4]),
                color=self.gt_colour,
                thickness=self.bbox_thickness,
            )
            img_gt = cv2.putText(
                img_gt,
                self.labels_dict[annot[0]],
                (annot[1] + 5, annot[2] + 25),
                color=self.gt_colour,
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.font_scale,
                thickness=self.font_thickness,
            )
        return img_gt

    def plot_with_preds_gt(self, option, side_by_side=False, plot_type=None):
        """Rules on which plot to generate.

        Args:
            option (str): image filename. Toggled on the app itself. See app.py.
            side_by_side (bool, optional): whether to show two plots side by side.
                Only takes effect when plot_type is None. Defaults to False.
            plot_type (str, optional): "all" - both GT and predictions are plotted,
                "pred" - only predictions,
                "gt" - only ground truth,
                None - only the image is shown (or the side-by-side view
                if side_by_side is True).
                Defaults to None.
        """
        if plot_type == "all":
            plot, df, cm_tpfpfn_df = self.show_img(
                option, show_preds=True, show_gt=True
            )
            st.plotly_chart(plot, use_container_width=True)
            st.caption("Blue: Model BBox, Green: GT BBox")
            st.table(df)
            st.table(cm_tpfpfn_df)
        elif plot_type == "pred":
            st.plotly_chart(
                self.show_img(option, show_preds=True), use_container_width=True
            )
        elif plot_type == "gt":
            st.plotly_chart(
                self.show_img(option, show_gt=True), use_container_width=True
            )
        elif side_by_side:
            plot1, plot2, df, cm_tpfpfn_df = self.show_img_sbs(option)
            col1, col2 = st.columns(2)
            with col1:
                col1.subheader("Ground Truth")
                st.plotly_chart(plot1, use_container_width=True)
            with col2:
                col2.subheader("Prediction")
                st.plotly_chart(plot2, use_container_width=True)
            st.table(df)
            st.table(cm_tpfpfn_df)
        else:
            st.plotly_chart(self.show_img(option), use_container_width=True)
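

# Minimal usage sketch; not part of the original app (app.py is the actual
# entry point). It assumes cfg/cfg.yml exists and the inference folder
# contains images. Run with `streamlit run` on this file.
if __name__ == "__main__":
    tool = ImageTool()
    fname = st.selectbox("Image", tool.all_img)
    view = st.radio("View", ["image only", "pred", "gt", "all", "side by side"])
    if view == "side by side":
        tool.plot_with_preds_gt(fname, side_by_side=True)
    elif view == "image only":
        tool.plot_with_preds_gt(fname)
    else:
        tool.plot_with_preds_gt(fname, plot_type=view)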