Spaces:
Running
Running
| #from openvino.runtime import Core | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from torchvision import models,transforms | |
| from typing import Iterable | |
| import gradio as gr | |
| from torch import nn | |
| from gradio.themes.base import Base | |
| from gradio.themes.utils import colors, fonts, sizes | |
| import time | |
| import torch | |
| import intel_extension_for_pytorch as ipex | |
| #core = Core() | |
def conv(in_channels, out_channels):
    """Build a size-preserving 3x3 convolution unit.

    The unit is: reflection pad of 1, ``Conv2d`` (3x3, stride 1), ``BatchNorm2d``,
    and an in-place ``ReLU``, returned as an ``nn.Sequential`` with children
    indexed 0..3 in that order. Because of the padding, height and width of the
    output equal those of the input.

    NOTE(review): appears unused in this file (``ResnUnet`` does not call it).
    """
    layers = [
        nn.ReflectionPad2d(1),
        nn.Conv2d(in_channels, out_channels, kernel_size=3),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
class resconv(nn.Module):
    """Residual block: ``x + block(x)``.

    ``block`` is two identical stages of reflection pad 1 -> 3x3 ``Conv2d`` ->
    ``InstanceNorm2d`` -> in-place ``ReLU``. The first conv maps
    ``in_features -> out_features``, and the second maps
    ``out_features -> out_features``.

    The residual addition only works when ``in_features == out_features``,
    because ``block(x)`` has ``out_features`` channels.
    The module order inside ``self.block`` (indices 0..7) is what pretrained
    checkpoints are keyed on, so it must stay as is.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        stages = []
        # Stage 1 reads in_features channels, stage 2 reads out_features.
        for stage_in in (in_features, out_features):
            stages += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(stage_in, out_features, 3),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the block's response to ``x``."""
        residual = self.block(x)
        return x + residual
def up_conv(in_channels, out_channels):
    """Build a 2x up-sampling unit.

    The unit is a ``ConvTranspose2d`` (kernel 2, stride 2, which doubles height
    and width), then ``BatchNorm2d`` and an in-place ``ReLU``, returned as an
    ``nn.Sequential``.

    NOTE(review): appears unused in this file (``ResnUnet`` does not call it).
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(upsample, norm, activation)
class ResnUnet(nn.Module):
    """Encoder/decoder segmentation network producing per-pixel score maps.

    Encoder:
    - A reflect-padded 7x7 stem conv (3 -> 64 channels).
    - A 3x3 stride-2 max-pool.
    - Four ``resconv`` residual blocks. Between consecutive blocks, a 3x3
      stride-2 conv doubles the channel count (64 -> 128 -> 256 -> 512).

    Decoder:
    - Four ``Upsample(scale_factor=2)`` + 3x3 conv stages that halve the
      channels back down to 32.
    - A reflect-padded 7x7 conv producing ``out_channels`` maps.

    A 224x224 input comes back as 224x224. Other sizes may not round-trip
    exactly because the pooling and strided convs floor odd sizes.

    Args:
        out_channels: output channels of the final conv (default 32, the value
            the original implementation hard-coded regardless of this argument).
        number_of_block: accepted for backward compatibility; not used.

    Do not change the order of modules in ``self.model``. ``nn.Sequential``
    names its children by position, and the pretrained state_dict is keyed by
    those names.
    """

    def __init__(self, out_channels=32, number_of_block=9):
        super().__init__()
        in_channels = 3
        features = 64

        # Stem, down-sampling max-pool, and first residual stage.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, features, 7),
            nn.InstanceNorm2d(features),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2),
            resconv(features, features),
        ]

        # Three more encoder stages: a stride-2 conv doubling the channels,
        # then a residual block at the new width.
        for _ in range(3):
            model += [
                nn.Conv2d(features, features * 2, 3, stride=2, padding=1),
                # num_features now matches the conv's output channels. The
                # original passed 64 here, which recent PyTorch versions warn
                # about on every forward. InstanceNorm2d defaults to
                # affine=False, so it owns no parameters and checkpoints are
                # unaffected by this fix.
                nn.InstanceNorm2d(features * 2),
                nn.ReLU(inplace=True),
                resconv(features * 2, features * 2),
            ]
            features *= 2

        # Decoder: four 2x upsample + 3x3 conv stages halving the channels
        # (512 -> 256 -> 128 -> 64 -> 32).
        for _ in range(4):
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(features, features // 2, 3, stride=1, padding=1),
                nn.InstanceNorm2d(features // 2),
                nn.ReLU(inplace=True),
            ]
            features //= 2

        # Output head. `features` is 32 at this point; the output width is now
        # driven by `out_channels` instead of a hard-coded 32.
        model += [nn.ReflectionPad2d(3), nn.Conv2d(features, out_channels, 7)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the score maps for ``x``.

        ``x`` must have 3 channels, because that is what the stem conv expects.
        """
        return self.model(x)
def _strip_module_prefix(state_dict, prefix="module."):
    """Return a new dict whose keys have one leading *prefix* removed.

    Only a leading occurrence is stripped (``str.removeprefix``). The previous
    ``key.replace("module.", "")`` also deleted the substring anywhere inside a
    key. Keys that do not start with *prefix* are kept unchanged.
    """
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


# Build the network on CPU and load the pretrained weights into it.
model = ResnUnet().to('cpu')
# weights_only=True restricts unpickling to tensors and plain containers.
# That is all a state_dict needs, and it is PyTorch's default from 2.6 onward.
state_dict = torch.load('real_model2_onnx_compat.pt', map_location='cpu', weights_only=True)
# The keys presumably carry a 'module.' prefix from a wrapped training model
# (e.g. DataParallel) -- TODO confirm. Strip it so they match ResnUnet's own names.
new_state_dict = _strip_module_prefix(state_dict)
model.load_state_dict(new_state_dict)
model.eval()
# Apply Intel Extension for PyTorch optimisations, then compile with the IPEX backend.
model = ipex.optimize(model, weights_prepack=False)
model = torch.compile(model, backend="ipex")
| # Read model to OpenVINO Runtime | |
| #model_ir = core.read_model(model="Davinci_eye.onnx") | |
| #compiled_model_ir = core.compile_model(model=model_ir, device_name='CPU') | |
# Input preprocessing: an HWC image array becomes a CHW tensor, which is then
# normalised per channel with the ImageNet mean and standard deviation.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
tfms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# Single source of truth: display colour (R, G, B) for each class id 1..31.
# Id 0 has no entry, so convert_mask_to_rgb leaves those pixels black.
# NOTE(review): ids 30 and 31 share the colour (20, 20, 0) in the original
# table -- confirm whether distinct colours were intended.
_CLASS_RGB = {
    1: (251, 244, 5),
    2: (37, 250, 5),
    3: (0, 21, 209),
    4: (172, 21, 2),
    5: (172, 21, 229),
    6: (6, 254, 249),
    7: (141, 216, 23),
    8: (96, 13, 13),
    9: (65, 214, 24),
    10: (124, 3, 252),
    11: (214, 55, 153),
    12: (48, 61, 173),
    13: (110, 31, 254),
    14: (249, 37, 14),
    15: (249, 137, 254),
    16: (34, 255, 113),
    17: (169, 52, 14),
    18: (124, 49, 176),
    19: (4, 88, 238),
    20: (115, 214, 178),
    21: (115, 63, 178),
    22: (115, 214, 235),
    23: (63, 63, 178),
    24: (130, 34, 26),
    25: (220, 158, 161),
    26: (201, 117, 56),
    27: (121, 16, 40),
    28: (15, 126, 0),
    29: (0, 50, 70),
    30: (20, 20, 0),
    31: (20, 20, 0),
}
# RGB colour -> class id. Ids 30 and 31 share a colour, so the later id (31)
# wins. This is exactly the value the original dict literal produced.
color_map = {rgb: class_id for class_id, rgb in _CLASS_RGB.items()}
# Class id -> [R, G, B]. This is built from the id table rather than by
# inverting color_map, which silently dropped id 30 and rendered it black.
colormap = {class_id: list(rgb) for class_id, rgb in _CLASS_RGB.items()}
# Display name for each predicted class id. Ids are assigned 1..31 in the order
# listed (the trailing comment gives each id); id 0 has no name.
items = dict(
    enumerate(
        (
            "HarmonicAce_Head",               # 1
            "HarmonicAce_Body",               # 2
            "MarylandBipolarForceps_Head",    # 3
            "MarylandBipolarForceps_Wrist",   # 4
            "MarylandBipolarForceps_Body",    # 5
            "CadiereForceps_Head",            # 6
            "CadiereForceps_Wrist",           # 7
            "CadiereForceps_Body",            # 8
            "CurvedAtraumaticGrasper_Head",   # 9
            "CurvedAtraumaticGrasper_Body",   # 10
            "Stapler_Head",                   # 11
            "Stapler_Body",                   # 12
            "MediumLargeClipApplier_Head",    # 13
            "MediumLargeClipApplier_Wrist",   # 14
            "MediumLargeClipApplier_Body",    # 15
            "SmallClipApplier_Head",          # 16
            "SmallClipApplier_Wrist",         # 17
            "SmallClipApplier_Body",          # 18
            "SuctionIrrigation",              # 19
            "Needle",                         # 20
            "Endotip",                        # 21
            "Specimenbag",                    # 22
            "DrainTube",                      # 23
            "Liver",                          # 24
            "Stomach",                        # 25
            "Pancreas",                       # 26
            "Spleen",                         # 27
            "Gallbladder",                    # 28
            "Gauze",                          # 29
            "TheOther_Instruments",           # 30
            "TheOther_Tissues",               # 31
        ),
        start=1,
    )
)
class Davinci_Eye(Base):
    """Gradio theme for the Davinci Eye demo.

    Only the constructor defaults differ from ``Base``:
    - stone / blue / gray hues;
    - medium spacing and corner radius;
    - large text;
    - "IBM Plex Mono" (via ``fonts.GoogleFont``) as the preferred face for both
      ``font`` and ``font_mono``.

    Every argument is keyword-only and forwarded unchanged to ``Base.__init__``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.stone,
        secondary_hue: colors.Color | str = colors.blue,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Gather the options first, then hand them to the base theme in one call.
        theme_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "spacing_size": spacing_size,
            "radius_size": radius_size,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**theme_options)


# Module-level theme instance for the demo UI.
davincieye = Davinci_Eye()
def convert_mask_to_rgb(pred_mask, palette=None):
    """Colourise a class-id mask.

    Args:
        pred_mask: array of class ids indexed as ``[row, col]``; only its first
            two dimensions size the output.
        palette: optional mapping ``class_id -> [R, G, B]``. Defaults to the
            module-level ``colormap``; pass another mapping to reuse this helper
            with a different colour scheme.

    Returns:
        ``uint8`` array of shape ``(pred_mask.shape[0], pred_mask.shape[1], 3)``.
        Pixels whose id has no palette entry stay ``(0, 0, 0)``.
    """
    if palette is None:
        palette = colormap
    rgb_mask = np.zeros((pred_mask.shape[0], pred_mask.shape[1], 3), dtype=np.uint8)
    for class_id, rgb in palette.items():
        rgb_mask[pred_mask == class_id] = rgb
    return rgb_mask
def segment_image(filepath, tag=None):
    """Segment one image file and return an overlay plus the detected class names.

    Args:
        filepath: path to an image readable by ``cv2.imread``.
        tag: unused. It is optional so the single-input Gradio interface in this
            file can call the function with just the path, while existing
            two-argument callers keep working.

    Returns:
        A tuple ``(overlay, names)``:
        - ``overlay``: a ``PIL.Image`` of the input resized to 224x224 and
          blended with the colourised prediction.
        - ``names``: the distinct non-zero class names found, sorted
          alphabetically and joined with ", " (empty string if none).

    Raises:
        ValueError: if ``cv2.imread`` cannot read the file.
    """
    image = cv2.imread(filepath)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {filepath!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (224, 224))
    # Scale to [0, 1] by hand: ToTensor only rescales uint8 input, and the
    # division already makes this a float array.
    x = tfms(image / 255.)
    with torch.no_grad():
        mask = model(x.unsqueeze(0).float())
    # Per-pixel index of the highest score along dim 1 (the class dimension).
    _, pred_mask = torch.max(mask, dim=1)
    pred_mask = pred_mask[0].numpy().astype(np.uint8)

    color_mask = convert_mask_to_rgb(pred_mask)
    masked_image = cv2.addWeighted(image, 0.3, color_mask, 0.8, 0.2)

    # Look up each distinct id once (not once per pixel). Id 0 is skipped, as
    # the previous np.nonzero filter did; sorting matches np.unique's order.
    names = sorted(items[int(class_id)] for class_id in np.unique(pred_mask) if class_id != 0)
    surg = ", ".join(names)
    return Image.fromarray(masked_image), surg
# Sample frames offered as one-click examples in the UI.
_EXAMPLE_IMAGES = [
    "R001_ch1_video_03_00-29-13-03.jpg",
    "R002_ch1_video_01_01-07-25-19.jpg",
    "R003_ch1_video_05_00-22-42-23.jpg",
    "R004_ch1_video_01_01-12-22-00.jpg",
    "R005_ch1_video_03_00-19-10-11.jpg",
    "R006_ch1_video_01_00-45-02-10.jpg",
    "R013_ch1_video_03_00-40-17-11.jpg",
]

# One image-path input, an overlay image and a text summary as outputs.
# NOTE(review): the title says "Quantized", but no quantization step is
# visible in this file; the title text is kept unchanged.
demo = gr.Interface(
    fn=segment_image,
    inputs=[gr.Image(type='filepath')],
    outputs=[gr.Image(type="pil"), gr.Text()],
    examples=_EXAMPLE_IMAGES,
    theme=davincieye.set(loader_color='#65aab1'),
    title="Davinci Eye(Quantized for CPU)",
)
demo.launch()