Spaces:
Running
Running
| import os | |
| import torch | |
| import modelscope | |
| import huggingface_hub | |
| import gradio as gr | |
| from PIL import Image | |
| from torchvision.transforms import transforms | |
| EN_US = os.getenv("LANG") != "zh_CN.UTF-8" | |
| ZH2EN = { | |
| "上传细胞图像": "Upload a cell picture", | |
| "状态栏": "Status", | |
| "图片名": "Picture name", | |
| "识别结果": "Recognition result", | |
| "请上传 PNG 格式的 HEp2 细胞图片": "It is recommended to upload HEp2 cell images in PNG format.", | |
| } | |
| def _L(zh_txt: str): | |
| return ZH2EN[zh_txt] if EN_US else zh_txt | |
# Local directory of the "Genius-Society/HEp2" model snapshot, cached under
# ./__pycache__. The English UI pulls it from the Hugging Face Hub; the
# Chinese UI pulls it from ModelScope.
if EN_US:
    MODEL_DIR = huggingface_hub.snapshot_download(
        "Genius-Society/HEp2", cache_dir="./__pycache__"
    )
else:
    MODEL_DIR = modelscope.snapshot_download(
        "Genius-Society/HEp2", cache_dir="./__pycache__"
    )
# English class label -> Chinese display name.
# The position of each key is used as the model's output index in infer()
# (CLASSES[predict]), so the insertion order here must not change.
TRANSLATE = {
    "Centromere": "着丝粒",
    "Golgi": "高尔基体",
    "Homogeneous": "同质",
    # NuMem is the nuclear-membrane pattern ("核膜"), not "memory" ("记忆体").
    "NuMem": "核膜",
    "Nucleolar": "核仁",
    "Speckled": "斑核",
}
# Class labels in model-output order.
CLASSES = list(TRANSLATE.keys())
def embeding(img_path: str):
    """Load an image file and convert it into the tensor the classifier expects.

    Pipeline: resize (shorter side to 224), center-crop to 224x224, convert to
    a tensor, then normalize each of the three RGB channels with mean 0.5 and
    std 0.5.

    Args:
        img_path: Path to the image file.

    Returns:
        The transformed image tensor, without a batch dimension.
    """
    compose = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            # No random augmentation (e.g. RandomAffine) here: inference must
            # give the same prediction every time for the same image.
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # The context manager closes the file handle. convert() returns a separate
    # in-memory image, so it stays usable after the file is closed.
    with Image.open(img_path) as img:
        rgb = img.convert("RGB")
    return compose(rgb)
| def infer(target: str): | |
| status = "Success" | |
| filename = result = None | |
| try: | |
| model = torch.load(f"{MODEL_DIR}/save.pt", map_location=torch.device("cpu")) | |
| if not target: | |
| raise ValueError("请上传细胞图片") | |
| torch.cuda.empty_cache() | |
| input: torch.Tensor = embeding(target) | |
| output: torch.Tensor = model(input.unsqueeze(0)) | |
| predict = torch.max(output.data, 1)[1] | |
| filename = os.path.basename(target) | |
| result = CLASSES[predict] if EN_US else TRANSLATE[CLASSES[predict]] | |
| except Exception as e: | |
| status = f"{e}" | |
| return status, filename, result | |
if __name__ == "__main__":
    # One example image per class, shipped inside the model snapshot.
    example_imgs = [f"{MODEL_DIR}/examples/{cls}.png" for cls in CLASSES]

    upload = gr.Image(type="filepath", label=_L("上传细胞图像"))
    # Status, picture name and recognition result, in infer()'s return order.
    result_boxes = [
        gr.Textbox(label=_L(zh_label), show_copy_button=True)
        for zh_label in ("状态栏", "图片名", "识别结果")
    ]

    demo = gr.Interface(
        fn=infer,
        inputs=upload,
        outputs=result_boxes,
        title=_L("请上传 PNG 格式的 HEp2 细胞图片"),
        examples=example_imgs,
        flagging_mode="never",
        cache_examples=False,
    )
    demo.launch(ssr_mode=False)