Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| def img2label(left, right): | |
| left = Image.fromarray(left.astype('uint8'), 'RGB') | |
| right = Image.fromarray(right.astype('uint8'), 'RGB') | |
| # 将右眼底镜像反转 | |
| r2l = transforms.RandomHorizontalFlip(p=1) | |
| right = r2l(right) | |
| # 调整图片 | |
| left_img = my_transforms(left).to(device) | |
| right_img = my_transforms(right).to(device) | |
| # 读取模型 | |
| model = torch.load('densenet_FD_e4_l5e-4_b32.pkl', map_location='cpu').to(device) | |
| # 模型推理 | |
| with torch.no_grad(): | |
| output = model(left=left_img.unsqueeze(0), right=right_img.unsqueeze(0)) | |
| # 结果处理 | |
| output = torch.sigmoid(output.squeeze(0)) | |
| output_ = output.cpu().numpy().tolist() | |
| res_dict = {LABELS[i]: output_[i] for i in range(len(output_))} | |
| pred = torch.nonzero(output > 0.4).view(-1) | |
| pred = pred.cpu().numpy().tolist() | |
| if len(pred) == 0 or (len(pred) == 1 and pred[0] == 0): | |
| return res_dict, LABELS[0] | |
| res = '' | |
| for i in pred: | |
| if i == 0: | |
| continue | |
| res += ', ' + LABELS[i] | |
| return res_dict, '目前的身体状态:' + res[2:] | |
| if __name__ == '__main__': | |
| device = torch.device("cpu") | |
| # 标题 | |
| title = "基于眼底图像的智能健康诊断分析系统" | |
| # 标题下的描述,支持md格式 | |
| description = "上传并输入左右眼底图像后,点击 submit 按钮,可根据双目眼底图像智能分析出可能有的疾病!" \ | |
| "包含的疾病种类有:糖尿病、青光眼、白内障、年龄性黄斑变性、高血压、病理性近视、其他疾病以及正常共计8类" | |
| # transforms设置 | |
| norm_mean = [0.485, 0.456, 0.406] | |
| norm_std = [0.229, 0.224, 0.225] | |
| my_transforms = transforms.Compose([ | |
| transforms.Resize((224, 224)), | |
| transforms.ToTensor(), | |
| transforms.Normalize(norm_mean, norm_std) | |
| ]) | |
| # 标签设置 | |
| LABELS = {0: '正常', | |
| 1: '糖尿病', | |
| 2: '青光眼', | |
| 3: '白内障', | |
| 4: '年龄性黄斑变性', | |
| 5: '高血压', | |
| 6: '病理性近视', | |
| 7: '其他疾病'} | |
| left_img_dir = 'left.jpg' | |
| right_img_dir = 'right.jpg' | |
| examples = [[left_img_dir, right_img_dir]] | |
| # r = img2label(left_img_dir, right_img_dir) | |
| demo = gr.Interface(fn=img2label, | |
| inputs=[gr.inputs.Image(), gr.inputs.Image()], | |
| outputs=["label", "text"], | |
| examples=examples, title=title, description=description) | |
| demo.launch() | |