Spaces:
Sleeping
Sleeping
| # Code copied and modified from: https://huggingface.co/spaces/BAAI/SegVol | |
| import tempfile | |
| from pathlib import Path | |
| import SimpleITK as sitk | |
| from mrsegmentator.utils import add_postfix | |
| import streamlit as st | |
| import utils | |
print("script run")
st.title("MRSegmentator")
st.write("(On-site segmentation is currently disabled, because we lack access to GPUs)")
#############################################
# init session_state
# Streamlit re-executes this whole script on every interaction, so only keys
# that are still missing (first run of a browser session) receive defaults.
# The table is rebuilt on every run, so the mutable list below is never shared
# between sessions.
_SESSION_DEFAULTS = {
    "option": None,  # path string of the selected demo image or the saved upload
    "reset_demo_case": False,  # set by the reset callbacks to force a reload
    "preds_3D": None,  # segmentation array, indexed like the image volume
    "preds_3D_ori": None,  # SimpleITK image of the segmentation, used for download
    "preds_path": None,
    "data_item": None,  # image volume returned by utils.read_image
    "rectangle_3Dbox": [0, 0, 0, 0, 0, 0],
    "running": False,
    "transparency": 0.25,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Bundled demo volumes. Each name is looked up under images/, and its
# precomputed mask under segmentations/ (with the "seg" postfix added).
case_list = [
    "amos_0517_MRI.nii.gz",
    "amos_0541_MRI.nii.gz",
    "amos_0571_MRI.nii.gz",
]
| ############################################# | |
| ############################################# | |
| # reset functions | |
def clear_prompts():
    """Reset the stored 3D box prompt to six zero coordinates."""
    zero_box = [0] * 6
    st.session_state.rectangle_3Dbox = zero_box
def reset_demo_case():
    """Forget the loaded case and its predictions so the next run reloads them.

    Registered as the ``on_change`` callback of the demo-case selector and the
    file uploader, and also called from ``clear_file``.
    """
    st.session_state.data_item = None
    # Consumed by the load step, which re-reads the volume and then clears the flag.
    st.session_state.reset_demo_case = True
    st.session_state.preds_3D = None
    st.session_state.preds_3D_ori = None
    # Reset together with preds_3D, as the "Clear" button does, so no stale
    # prediction path survives a case change.
    st.session_state.preds_path = None
    clear_prompts()
def clear_file():
    """Drop the current input selection and everything derived from it.

    Registered as the ``on_change`` callback of the demo-source radio.
    """
    st.session_state.option = None
    # reset_demo_case() already calls clear_prompts(); no second call needed.
    reset_demo_case()
| ############################################# | |
# Project links, rendered side by side in two equal columns.
github_col, arxive_col = st.columns(2)
github_col.write("Git: https://github.com/hhaentze/mrsegmentator")
arxive_col.write("Paper: https://arxiv.org/abs/2405.06463")
# Input source: a bundled, already-segmented case or a user upload.
# Changing the source discards the current case via clear_file.
_demo_sources = ["Select (presegmented)", "Upload"]
demo_type = st.radio(
    "Demo case source",
    _demo_sources,
    on_change=clear_file,
)
with tempfile.TemporaryDirectory() as tmpdirname:
    # tmpdirname is scratch space that exists only for this script run. Uploads
    # are written into it so utils.read_image can open them by path.

    # ---- pick the input volume ----------------------------------------------
    seg_path = None  # precomputed mask of a bundled demo case, when one is selected
    if demo_type == "Select (presegmented)":
        selection = st.selectbox(
            "Select a demo case",
            case_list,
            index=None,
            placeholder="Select a demo case...",
            on_change=reset_demo_case,
        )
        if selection:
            uploaded_file = "images/" + selection
            seg_path = Path(__file__).parent / ("segmentations/" + add_postfix(selection, "seg"))
        else:
            uploaded_file = None
    else:
        uploaded_file = st.file_uploader("Upload demo case(nii.gz)", type="nii.gz", on_change=reset_demo_case)
        if uploaded_file is not None:
            # The browser supplies this name, so keep only its final component.
            # A name containing "../" must not be able to write outside tmpdirname.
            upload_path = Path(tmpdirname) / Path(uploaded_file.name).name
            upload_path.write_bytes(uploaded_file.getvalue())
            uploaded_file = str(upload_path)
    st.session_state.option = uploaded_file

    # ---- (re)load the case ---------------------------------------------------
    # Load only when a reset callback asked for it or nothing is loaded yet.
    # Every other widget interaction reuses what session_state already holds.
    if st.session_state.option is not None and (
        st.session_state.reset_demo_case or st.session_state.data_item is None
    ):
        st.session_state.data_item = utils.read_image(Path(__file__).parent / str(uploaded_file))
        if seg_path is not None:
            # Load the bundled mask together with its image, not on every rerun.
            # Re-reading it each run would immediately undo the "Clear" button.
            st.session_state.preds_3D = utils.read_image(seg_path)
            st.session_state.preds_3D_ori = sitk.ReadImage(seg_path)
        st.session_state.reset_demo_case = False

    if st.session_state.option is None:
        st.write("please select demo case first")
    else:
        image_3D = st.session_state.data_item
        # Reduce the volume once; min() and max() each scan the whole array.
        px_min, px_max = int(image_3D.min()), int(image_3D.max())
        px_range = st.slider("Select intensity range", px_min, px_max, (px_min, px_max))

        # Slice pickers: axis 0 is shown as "Axial", axis 1 as "Coronal".
        col_control1, col_control2 = st.columns(2)
        with col_control1:
            selected_index_z = st.slider(
                "Axial view",
                0,
                image_3D.shape[0] - 1,
                image_3D.shape[0] // 2,
                key="xy",
                disabled=st.session_state.running,
            )
        with col_control2:
            selected_index_y = st.slider(
                "Coronal view",
                0,
                image_3D.shape[1] - 1,
                image_3D.shape[1] // 2,
                key="xz",
                disabled=st.session_state.running,
            )

        col_image1, col_image2 = st.columns(2)
        if st.session_state.preds_3D is not None:
            st.session_state.transparency = st.slider(
                "Mask opacity", 0.0, 1.0, 0.35, disabled=st.session_state.running
            )

        preds_3D = st.session_state.preds_3D
        with col_image1:
            preds_z_array = None if preds_3D is None else preds_3D[selected_index_z]
            image_z = utils.make_fig(
                image_3D[selected_index_z], preds_z_array, px_range, st.session_state.transparency
            )
            st.image(image_z, use_column_width=False)
        with col_image2:
            preds_y_array = None if preds_3D is None else preds_3D[:, selected_index_y, :]
            image_y = utils.make_fig(
                image_3D[:, selected_index_y, :], preds_y_array, px_range, st.session_state.transparency
            )
            st.image(image_y, use_column_width=False)

        ######################################################
        col1, col2, col3 = st.columns(3)
        with col1:
            # Empty headings push the button down, presumably to line it up
            # with the radio in col3.
            st.markdown("#")
            st.markdown("####")
            st.markdown("####")
            if st.button(
                "Clear",
                use_container_width=True,
                disabled=(st.session_state.option is None or st.session_state.preds_3D is None),
            ):
                clear_prompts()
                st.session_state.preds_3D = None
                st.session_state.preds_path = None
                st.rerun()
        with col2:
            st.markdown("#")
            st.markdown("####")
            st.markdown("####")
            if st.session_state.preds_3D is not None and st.session_state.data_item is not None:
                # Serialize into a fresh directory inside this run's scratch dir.
                # It cannot collide with an uploaded file, and it avoids reopening
                # a still-open NamedTemporaryFile by name, which fails on Windows.
                download_path = Path(tempfile.mkdtemp(dir=tmpdirname)) / "segmentation.nii.gz"
                sitk.WriteImage(st.session_state.preds_3D_ori, str(download_path))
                st.download_button(
                    label="Download result (.nii.gz)",
                    data=download_path.read_bytes(),
                    file_name="segmentation.nii.gz",
                    mime="application/octet-stream",
                    disabled=False,
                )
        with col3:
            fold1_label = "Model of Fold 1 (fast)"
            # A hidden (not collapsed) label avoids Streamlit's empty-label
            # warning and keeps the label's vertical space for the layout.
            model_choice = st.radio(
                "Model",
                [fold1_label, "Ensemble Segmentation"],
                label_visibility="hidden",
            )
            # Compare with the exact option label so the single-fold choice takes effect.
            st.session_state.folds = (0,) if model_choice == fold1_label else (0, 1, 2, 3, 4)

            run_button_name = "Running" if st.session_state.running else "Run"
            # Always disabled: on-site segmentation is switched off (no GPUs).
            # To re-enable, pass instead:
            #   disabled=(st.session_state.data_item is None or st.session_state.running)
            if st.button(run_button_name, type="primary", use_container_width=True, disabled=True):
                st.session_state.running = True
                st.rerun()

        if st.session_state.running:
            # Second pass of the Run flow. The button set `running` and reran, so
            # the sliders above rendered disabled; now do the work and reset.
            st.session_state.running = False
            with st.status("Running...", expanded=False):
                # NOTE(review): utils.run only receives the scratch directory;
                # confirm how it locates the input volume and st.session_state.folds.
                utils.run(tmpdirname)
            st.rerun()