Spaces:
Runtime error
Runtime error
import numpy as np
import cv2
import nibabel as nib
from PIL import Image
import io
import matplotlib.pyplot as plt
import gradio as gr
from huggingface_hub import from_pretrained_keras

# Segmentation model downloaded from the Hugging Face Hub at import time
# (network access required on startup).
model = from_pretrained_keras("duzduran/NeuroNest3D")
# Constants
IMG_SIZE = 128          # model input resolution: slices are resized to IMG_SIZE x IMG_SIZE
VOLUME_SLICES = 100     # number of axial slices fed to the model per volume
VOLUME_START_AT = 22    # first slice index used from each NIfTI volume (skips empty edge slices — TODO confirm)
# Class labels indexed by the model's output channel; index 0 is background.
SEGMENT_CLASSES = ['NOT tumor', 'ENHANCING', 'CORE', 'WHOLE']
def predictByPath(flair, ce):
    """Build a slice batch from two MRI volumes and run the segmentation model.

    Args:
        flair: 3-D array-like FLAIR volume; slices are taken along axis 2
            starting at VOLUME_START_AT. Assumes at least
            VOLUME_START_AT + VOLUME_SLICES slices — TODO confirm for all inputs.
        ce: 3-D array-like contrast-enhanced (t1ce) volume, same layout.

    Returns:
        Model predictions for the batch of shape
        (VOLUME_SLICES, IMG_SIZE, IMG_SIZE, n_classes).
    """
    X = np.empty((VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2))
    for j in range(VOLUME_SLICES):
        # Channel 0: flair, channel 1: t1ce — both resized to the model's input size.
        X[j, :, :, 0] = cv2.resize(flair[:, :, j + VOLUME_START_AT], (IMG_SIZE, IMG_SIZE))
        X[j, :, :, 1] = cv2.resize(ce[:, :, j + VOLUME_START_AT], (IMG_SIZE, IMG_SIZE))
    # Normalize by the global max; guard against an all-zero volume, which
    # would otherwise turn the whole batch into NaNs via division by zero.
    peak = np.max(X)
    X_normalized = X / peak if peak > 0 else X
    return model.predict(X_normalized, verbose=1)
def create_subplot_image(origImage, gt, predictions, slice_index, start_at, img_size):
    """Render a 1x6 comparison panel for one slice and return it as a PIL image.

    Panels: original flair | ground truth | all predicted classes |
    one panel per tumor class (ENHANCING, CORE, WHOLE).

    Args:
        origImage: 3-D flair volume (background image in every panel).
        gt: 3-D ground-truth segmentation volume.
        predictions: model output of shape (slices, H, W, 4).
        slice_index: slice to display (may arrive as a float from a Gradio slider).
        start_at: offset added to slice_index when indexing the raw volumes.
        img_size: side length the background/gt slices are resized to.

    Returns:
        PIL.Image.Image holding the rendered figure (PNG-encoded buffer).
    """
    # Gradio sliders deliver floats; numpy integer indexing below requires int.
    slice_index = int(slice_index)
    # NOTE: the original code also called plt.figure() here, leaking one
    # unused figure per invocation — plt.subplots() already creates the figure.
    f, axarr = plt.subplots(1, 6, figsize=(18, 10))
    # Grayscale flair slice as the background of every panel; overlays go on top.
    for i in range(6):
        axarr[i].imshow(cv2.resize(origImage[:, :, slice_index + start_at], (img_size, img_size)),
                        cmap="gray", interpolation='none')
    axarr[0].title.set_text('Original image flair')
    # Ground truth overlay — nearest-neighbor resize to keep label values intact.
    curr_gt = cv2.resize(gt[:, :, slice_index + start_at], (img_size, img_size), interpolation=cv2.INTER_NEAREST)
    axarr[1].imshow(curr_gt, cmap="Reds", interpolation='none', alpha=0.3)
    axarr[1].title.set_text('Ground truth')
    # All tumor classes at once (channels 1-3; channel 0 is background).
    axarr[2].imshow(predictions[slice_index, :, :, 1:4], cmap="Reds", interpolation='none', alpha=0.3)
    axarr[2].title.set_text('All classes')
    # One panel per individual tumor class.
    for i in range(1, 4):
        axarr[i + 2].imshow(predictions[slice_index, :, :, i], cmap="OrRd", interpolation='none', alpha=0.3)
        axarr[i + 2].title.set_text(f'{SEGMENT_CLASSES[i]} predicted')
    # Serialize the figure into an in-memory PNG and hand it back as a PIL image.
    buf = io.BytesIO()
    f.savefig(buf, format='png')
    plt.close(f)
    buf.seek(0)
    img = Image.open(buf)
    return img
def _example_paths(folder, case):
    """Return the modality -> NIfTI path mapping for one bundled BraTS case."""
    base = f"examples/{folder}/BraTS20_Training_{case}"
    return {"flair": f"{base}_flair.nii",
            "t1ce": f"{base}_t1ce.nii",
            "seg": f"{base}_seg.nii"}


# Bundled BraTS 2020 sample cases selectable from the UI dropdown.
examples = {
    "Example 1": _example_paths("ex_1", "001"),
    "Example 2": _example_paths("ex_2", "002"),
}
def automatic_process(example_key):
    """Load one of the bundled example cases and render its default slice.

    Args:
        example_key: key into `examples` (e.g. "Example 1").

    Returns:
        PIL image produced by process_and_display_direct.

    Raises:
        KeyError: if example_key is not a known example.
    """
    # (Removed a leftover debug print of the flair path.)
    paths = examples[example_key]
    flair = nib.load(paths["flair"]).get_fdata()
    t1ce = nib.load(paths["t1ce"]).get_fdata()
    seg = nib.load(paths["seg"]).get_fdata()
    # Fixed mid-volume slice for the one-click example; the upload flow
    # lets users pick a slice via the slider instead.
    slice_index = 50
    return process_and_display_direct(flair, t1ce, seg, slice_index)
def process_and_display_direct(flair_data, t1ce_data, seg_data, slice_index):
    """Segment in-memory volumes and return the rendered comparison panel.

    Args:
        flair_data, t1ce_data, seg_data: array-likes holding the flair,
            t1ce and ground-truth segmentation volumes.
        slice_index: slice to visualize.

    Returns:
        PIL image produced by create_subplot_image.
    """
    flair_vol, t1ce_vol, seg_vol = (np.array(v) for v in (flair_data, t1ce_data, seg_data))
    predictions = predictByPath(flair_vol, t1ce_vol)
    return create_subplot_image(flair_vol, seg_vol, predictions, slice_index, VOLUME_START_AT, IMG_SIZE)
def process_and_display(flair_file, t1ce_file, seg_file, slice_index):
    """Segment user-uploaded NIfTI files and return the comparison panel.

    Args:
        flair_file, t1ce_file, seg_file: uploaded file objects (Gradio
            File components; `.name` is the temp-file path).
        slice_index: slice to visualize.

    Returns:
        PIL image, or None when any of the three uploads is missing.
    """
    uploads = (flair_file, t1ce_file, seg_file)
    if not all(uploads):
        return None  # all three modalities are required
    flair_vol, t1ce_vol, gt_vol = (nib.load(f.name).get_fdata() for f in uploads)
    predictions = predictByPath(flair_vol, t1ce_vol)
    return create_subplot_image(flair_vol, gt_vol, predictions, slice_index, VOLUME_START_AT, IMG_SIZE)
# NOTE(review): `title` is never referenced below (gr.Blocks is given
# title="Tumor Segmentation" directly) and names a different project
# ("Open-Vocabulary SAM") — looks like copy-paste leftover; confirm before removing.
title = "<center><strong><font size='8'>Open-Vocabulary SAM<font></strong></center>"
# CSS injected into the Blocks app: centers h1 and justifies .about text.
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# Gradio Interface
# Layout and event wiring only. Component creation order inside the context
# managers determines the rendered layout, so the structure is left as-is.
with gr.Blocks(css=css, title="Tumor Segmentation") as demo:
    # Page header / attribution.
    gr.Markdown(
        """
    <p style="text-align: center; font-size: 24px;">MRI Brain Tumor Segmentation</p>
    <p style="text-align: center;">made by Ahmet Duzduran</p>
    ### <p style="text-align: left;">Faculty: Faculty of Computer Science</p>
    ### <p style="text-align: left;">Specialization: Intelligent Systems and Data Science</p>
    ### <p style="text-align: left;">Supervisor: Wojciech Oronowicz, PhD, Prof. Of PJATK</p>
    """
    )
    # Manual-upload flow: three NIfTI files plus a slice selector.
    with gr.Row():
        flair_input = gr.File(label="Upload Flair NIfTI File")
        t1ce_input = gr.File(label="Upload T1ce NIfTI File")
        seg_input = gr.File(label="Upload Seg NIfTI File")
        slice_input = gr.Slider(minimum=0, maximum=VOLUME_SLICES - 1, label="Slice Index")
        # eval_class_input = gr.Dropdown(choices=list(range(len(SEGMENT_CLASSES))), label="Select Class")
        submit_button = gr.Button("Submit")
    # One-click flow: pick a bundled example case.
    with gr.Row():
        example_selector = gr.Dropdown(list(examples.keys()), label="Select Example")
        auto_button = gr.Button("Load Example")
    # Shared output panel for both flows.
    output_image = gr.Image(label="Visualization")
    submit_button.click(
        process_and_display,
        inputs=[flair_input, t1ce_input, seg_input, slice_input],
        outputs=output_image
    )
    auto_button.click(
        automatic_process,
        inputs=[example_selector],
        outputs=output_image
    )

demo.launch()