Spaces:
Sleeping
Sleeping
| # Code copied and modified from https://huggingface.co/spaces/BAAI/SegVol/blob/main/utils.py | |
| from pathlib import Path | |
| import matplotlib as mpl | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import SimpleITK as sitk | |
| import torch | |
| from mrsegmentator import inference | |
| from mrsegmentator.utils import add_postfix | |
| from PIL import Image | |
| from scipy import ndimage | |
| import streamlit as st | |
# The single rectangle object placed in the initial drawing state.
# NOTE(review): the key set looks like a fabric.js-style canvas object
# (presumably consumed by a drawable-canvas widget) -- confirm against caller.
_initial_rect_object = {
    # identity and geometry: a 100x100 rectangle anchored at (50, 50)
    "type": "rect",
    "version": "4.4.0",
    "originX": "left",
    "originY": "top",
    "left": 50,
    "top": 50,
    "width": 100,
    "height": 100,
    # fill and outline
    "fill": "rgba(255, 165, 0, 0.3)",
    "stroke": "#2909F1",
    "strokeWidth": 3,
    "strokeDashArray": None,
    "strokeLineCap": "butt",
    "strokeDashOffset": 0,
    "strokeLineJoin": "miter",
    "strokeUniform": True,
    "strokeMiterLimit": 4,
    # transform: unscaled, unrotated, unflipped
    "scaleX": 1,
    "scaleY": 1,
    "angle": 0,
    "flipX": False,
    "flipY": False,
    # compositing
    "opacity": 1,
    "shadow": None,
    "visible": True,
    "backgroundColor": "",
    "fillRule": "nonzero",
    "paintFirst": "fill",
    "globalCompositeOperation": "source-over",
    # skew and corner radii
    "skewX": 0,
    "skewY": 0,
    "rx": 0,
    "ry": 0,
}

# Initial drawing state: one versioned document holding the rectangle above.
initial_rectangle = {
    "version": "4.4.0",
    "objects": [_initial_rect_object],
}
def run(tmpdirname):
    """Run inference on the selected image and store the result in session state.

    Reads ``st.session_state.option`` (file name relative to this module's
    directory) and ``st.session_state.folds``, calls
    ``mrsegmentator.inference.infer`` with ``tmpdirname`` as output location,
    and then loads the file ``<tmpdirname>/<add_postfix(image name, "seg")>``.

    Writes two session-state entries:
        preds_3D: numpy array of the prediction, reoriented to "LPS" and
            resampled to isotropic spacing.
        preds_3D_ori: the prediction exactly as read from disk
            (``SimpleITK.Image``).

    Returns immediately, without touching session state, when no option is
    selected.

    Args:
        tmpdirname (str | os.PathLike): Directory handed to the inference call
            and searched for the resulting prediction file.
    """
    option = st.session_state.option
    if option is None:
        # No image selected: there is nothing to segment.
        return

    image = Path(__file__).parent / str(option)
    inference.infer([image], tmpdirname, st.session_state.folds, split_level=1)

    # pathlib join works for both str and Path-like temp directories.
    preds_path = str(Path(tmpdirname) / add_postfix(image.name, "seg"))

    # Read the prediction once and derive both representations from it.
    preds_image = sitk.ReadImage(preds_path)

    # The prediction is a label map, so it must be resampled with nearest
    # neighbour; linear interpolation would blend neighbouring label values
    # into labels that do not exist.
    isotropic_preds = make_isotropic(
        sitk.DICOMOrient(preds_image, "LPS"),
        interpolator=sitk.sitkNearestNeighbor,
    )
    st.session_state.preds_3D = sitk.GetArrayFromImage(isotropic_preds)
    st.session_state.preds_3D_ori = preds_image
def reflect_box_into_model(box_3d):
    """Rescale a ``(z1, y1, x1, z2, y2, x2)`` box into integer prompt coordinates.

    Every y/x coordinate is multiplied by 256/325 and every z coordinate by
    32/325; the result is truncated with ``int()``.
    NOTE(review): 325 is presumably the canvas extent and 32x256x256 the model
    input size -- confirm against the caller.

    Args:
        box_3d: Iterable of exactly six numbers ordered z1, y1, x1, z2, y2, x2.

    Returns:
        torch.Tensor: The six rescaled integers, in the same order.
    """
    z1, y1, x1, z2, y2, x2 = box_3d
    depth_extent, plane_extent, source_extent = 32.0, 256.0, 325.0

    def _rescale(coord, target_extent):
        # Multiply first, then divide, so rounding matches the plain expression.
        return int(coord * target_extent / source_extent)

    prompt = [
        _rescale(z1, depth_extent),
        _rescale(y1, plane_extent),
        _rescale(x1, plane_extent),
        _rescale(z2, depth_extent),
        _rescale(y2, plane_extent),
        _rescale(x2, plane_extent),
    ]
    return torch.tensor(np.array(prompt))
def reflect_json_data_to_3D_box(json_data, view):
    """Copy the first drawn rectangle into ``st.session_state.rectangle_3Dbox``.

    Only ``view == "xy"`` is handled. In that case the rectangle's top/left
    corner is written to box indices 1 and 2, and the opposite corner
    (``top + height * scaleY``, ``left + width * scaleX``) to indices 4 and 5.
    Indices 0 and 3 are never modified here. The box is printed afterwards.

    Args:
        json_data: Mapping with an ``"objects"`` list whose first entry has the
            keys ``top``, ``left``, ``width``, ``height``, ``scaleX``, ``scaleY``.
        view: View identifier; any value other than ``"xy"`` leaves the box
            unchanged.
    """
    if view == "xy":
        rect = json_data["objects"][0]
        box = st.session_state.rectangle_3Dbox

        # Upper-left corner.
        box[1] = rect["top"]
        box[2] = rect["left"]
        # Lower-right corner, accounting for any interactive scaling.
        box[4] = rect["top"] + rect["height"] * rect["scaleY"]
        box[5] = rect["left"] + rect["width"] * rect["scaleX"]

    print(st.session_state.rectangle_3Dbox)
def _draw_label_overlay(ax, labels, transparency):
    """Draw a translucent label fill and darker label outlines onto *ax*."""
    # Filled regions: only pixels with a label (> 0.1) are made visible.
    fill_alpha = np.zeros(labels.shape)
    fill_alpha[labels > 0.1] = transparency
    ax.imshow(labels, cmap="jet", alpha=fill_alpha, vmin=0, vmax=40)

    # Outlines: isolate each label value and mark the pixels where the
    # Laplacian of that single-label mask is non-zero.
    edge_map = np.zeros(labels.shape, dtype=int)
    for label in np.unique(labels):
        single_label = labels.copy()
        single_label[single_label != label] = 0
        edge_map[ndimage.laplace(single_label) != 0] = label

    # Darkened jet palette for the outlines. Use at least one colour so that a
    # label map containing only background does not yield an empty colormap.
    n_colors = max(int(labels.max()), 1)
    edge_colors = mpl.cm.jet(np.linspace(0, 1, n_colors))
    edge_colors -= 0.4
    edge_cmap = mpl.colors.ListedColormap(edge_colors.clip(0, 1))

    edge_alpha = np.zeros(edge_map.shape)
    edge_alpha[edge_map > 0.01] = 0.9
    ax.imshow(edge_map, alpha=edge_alpha, cmap=edge_cmap, vmin=0, vmax=40)


def make_fig(image, preds, px_range=(10, 400), transparency=0.5):
    """Render an image slice, optionally overlaid with labels, as a PIL image.

    Args:
        image: Array to display; it must provide ``.clip`` and be accepted by
            ``imshow`` (presumably a 2D numpy slice -- confirm with caller).
        preds: Label array of the same slice, or ``None`` to draw no overlay.
        px_range: ``(low, high)`` intensity window; the image is clipped to it
            and it is used as the grey-scale display range.
        transparency: Alpha applied to the filled label regions.

    Returns:
        PIL.Image.Image: RGB rendering of the 4x4 inch figure.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        ax.imshow(
            image.clip(*px_range),
            cmap="Greys_r",
            vmin=px_range[0],
            vmax=px_range[1],
        )

        if preds is not None:
            _draw_label_overlay(ax, np.array(preds), transparency)

        ax.axis("off")
        ax.set_xticks([])
        ax.set_yticks([])
        fig.canvas.draw()

        # buffer_rgba replaces tostring_rgb, which newer Matplotlib releases
        # no longer provide. Copy the RGB channels so the returned image does
        # not reference the canvas buffer once the figure is closed.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        return Image.fromarray(rgba[..., :3].copy())
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
| ####################################### | |
def make_isotropic(image, interpolator=sitk.sitkLinear, spacing=None):
    """Resample *image* so that all axes share the same spacing.

    Formats such as jpg or png assume square pixels, so writing anisotropic
    data to them produces distorted pictures. If the spacing already matches
    on every axis, no resampling is done and a copy is returned.

    Args:
        image (SimpleITK.Image): Input image.
        interpolator: Interpolator passed to the resampler; linear by default.
            Use sitkNearestNeighbor for label images so that no non-existent
            labels are introduced.
        spacing (float): Target spacing. When None, the smallest spacing of
            the input image is used.

    Returns:
        SimpleITK.Image with isotropic spacing, sharing origin, direction and
        pixel type with the input image.
    """
    axis_spacings = image.GetSpacing()
    reference = axis_spacings[0]

    # Nothing to resample when no axis deviates from the first one.
    if not any(axis_spacing != reference for axis_spacing in axis_spacings):
        return sitk.Image(image)

    axis_sizes = image.GetSize()
    if spacing is None:
        spacing = min(axis_spacings)

    # Scale the voxel count per axis so the physical extent is preserved.
    resampled_size = [
        int(round(n_voxels * axis_spacing / spacing))
        for n_voxels, axis_spacing in zip(axis_sizes, axis_spacings)
    ]
    resampled_spacing = [spacing] * image.GetDimension()

    return sitk.Resample(
        image,
        resampled_size,
        sitk.Transform(),  # default-constructed transform
        interpolator,
        image.GetOrigin(),
        resampled_spacing,
        image.GetDirection(),
        0,  # default pixel value
        image.GetPixelID(),
    )
def read_image(path, interpolator=sitk.sitkLinear):
    """Load an image file as an isotropic, LPS-oriented numpy array.

    The file is read with SimpleITK, reoriented to "LPS", resampled to
    isotropic spacing via ``make_isotropic`` and converted to an array.

    Args:
        path: Location of the image file, as accepted by ``sitk.ReadImage``.
        interpolator: Interpolator used if resampling is required. Defaults to
            linear; pass ``sitk.sitkNearestNeighbor`` when reading label
            images so that resampling cannot create labels that do not exist.

    Returns:
        numpy.ndarray: Voxel data of the reoriented, resampled image.
    """
    image = sitk.ReadImage(path)
    image = sitk.DICOMOrient(image, "LPS")
    image = make_isotropic(image, interpolator=interpolator)
    return sitk.GetArrayFromImage(image)