Spaces:
Configuration error
Configuration error
| import os | |
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from open_clip import create_model, get_tokenizer | |
| from torchvision import transforms | |
| from templates import openai_imagenet_template | |
# --- Flag logging -----------------------------------------------------------
# Token is read from the environment; it is None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
# Callback handed to gr.Interface(flagging_callback=...) so flagged samples are
# saved to the "bioclip-demo" dataset.
hf_writer = gr.HuggingFaceDatasetSaver(hf_token, "bioclip-demo")

# --- Model selection --------------------------------------------------------
model_str = "hf-hub:imageomics/bioclip"
tokenizer_str = "ViT-B-16"

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Image preprocessing ----------------------------------------------------
# Only converts to a tensor and normalizes per channel; there is no resize or
# crop step here, so callers must supply images at the size the model expects.
# NOTE(review): mean/std look like the standard OpenAI CLIP normalization
# statistics -- confirm against the checkpoint's preprocessing config.
preprocess_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)
def get_txt_features(classnames, templates):
    """Build one normalized text embedding per class by averaging over prompts.

    For every class name, each template is called as ``template(classname)`` to
    produce a prompt; the prompts are tokenized, encoded with the module-level
    ``model``, L2-normalized, averaged, and re-normalized.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of callables mapping a class name to a prompt
            string.

    Returns:
        A tensor with the per-class embeddings stacked along ``dim=1`` (one
        column per class). Presumably ``(embed_dim, n_classes)``, assuming
        ``model.encode_text`` returns ``(n_prompts, embed_dim)``.

    Raises:
        RuntimeError: From ``torch.stack`` when ``classnames`` is empty.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.
    """
    all_features = []
    # Inference only: without no_grad, autograd would keep a graph alive for
    # every (class x template) batch encoded below.
    with torch.no_grad():
        for classname in classnames:
            txts = [template(classname) for template in templates]
            txts = tokenizer(txts).to(device)
            txt_features = model.encode_text(txts)
            # Normalize each prompt embedding, average them, then re-normalize
            # so the class embedding is unit-length again.
            txt_features = F.normalize(txt_features, dim=-1).mean(dim=0)
            txt_features /= txt_features.norm()
            all_features.append(txt_features)
    # Stack as columns so that `img_features @ result` yields per-class scores.
    return torch.stack(all_features, dim=1)
def predict(img, classes: list[str]) -> dict[str, float]:
    """Zero-shot classify ``img`` against a user-supplied list of class names.

    Args:
        img: Image accepted by the module-level ``preprocess_img`` pipeline
            (which does not resize; the caller supplies the final size).
        classes: Candidate class names. Entries are stripped and blank entries
            are dropped before scoring.

    Returns:
        Mapping from each remaining class name to its softmax probability.

    Note:
        An empty class list (after dropping blanks) is not handled here and
        fails inside ``get_txt_features``.

    Relies on the module-level ``model``, ``device``, ``preprocess_img`` and
    ``openai_imagenet_template``.
    """
    classes = [name.strip() for name in classes if name.strip()]

    # Inference only; avoid building an autograd graph.
    with torch.no_grad():
        txt_features = get_txt_features(classes, openai_imagenet_template)

        img = preprocess_img(img).to(device)
        # unsqueeze(0) adds the batch dimension of size 1.
        img_features = model.encode_image(img.unsqueeze(0))
        img_features = F.normalize(img_features, dim=-1)

        # Drop only the batch dimension. A bare .squeeze() would also collapse
        # the class dimension when a single class is given, turning
        # .tolist() into a float and breaking the zip below.
        logits = (model.logit_scale.exp() * img_features @ txt_features).squeeze(0)
        probs = F.softmax(logits, dim=0).to("cpu").tolist()

    return dict(zip(classes, probs))
def hierarchical_predict(img) -> list[str]:
    """
    Predicts from the top of the tree of life down to the species.

    The hierarchical search itself is not implemented yet: only the image
    embedding is computed. The function raises instead of dropping into a
    debugger, because a ``breakpoint()`` inside a served request blocks the
    worker waiting on stdin.

    Args:
        img: Image accepted by the module-level ``preprocess_img`` pipeline.

    Raises:
        NotImplementedError: Always, until the tree-of-life search is written.
    """
    # Inference only; avoid building an autograd graph.
    with torch.no_grad():
        img = preprocess_img(img).to(device)
        img_features = model.encode_image(img.unsqueeze(0))
        img_features = F.normalize(img_features, dim=-1)

    # TODO: walk the taxonomy from the top rank down to species, scoring
    # `img_features` against the candidate names at each level.
    raise NotImplementedError(
        "Prediction over the entire tree of life is not implemented yet; "
        "please provide a list of classes."
    )
def run(img, cls_str: str) -> dict[str, float]:
    """Gradio entry point: dispatch to class-list or tree-of-life prediction.

    Args:
        img: Image passed through unchanged to the prediction function.
        cls_str: Newline-separated class names from the textbox. May be empty,
            whitespace-only, or ``None``.

    Returns:
        The result of ``predict`` when at least one non-blank class name was
        given, otherwise the result of ``hierarchical_predict``.
    """
    # Parse first and branch on the parsed list: a whitespace-only string is
    # truthy but contains no classes, and must not reach predict() with an
    # empty list.
    classes = [line.strip() for line in (cls_str or "").split("\n") if line.strip()]
    if classes:
        return predict(img, classes)
    return hierarchical_predict(img)
if __name__ == "__main__":
    # Script entry point: load the model, then build and launch the Gradio UI.
    print("Starting.")

    # `model` and `tokenizer` are deliberately bound at module level: the
    # prediction functions above read them as globals.
    model = create_model(model_str, output_dict=True, require_pretrained=True)
    model = model.to(device)
    # open_clip hands back models in training mode; switch to eval so any
    # train-only layers (dropout, stochastic depth, BatchNorm updates) are
    # inactive while serving predictions.
    model.eval()
    print("Created model.")

    model = torch.compile(model)
    print("Compiled model.")

    tokenizer = get_tokenizer(tokenizer_str)

    demo = gr.Interface(
        fn=run,
        inputs=[
            # preprocess_img performs no resizing, so the input component is
            # what fixes the image size fed to the model.
            gr.Image(shape=(224, 224)),
            gr.Textbox(
                placeholder="dog\ncat\n...",
                lines=3,
                label="Classes",
                show_label=True,
                info="If empty, will predict from the entire tree of life.",
            ),
        ],
        outputs=gr.Label(num_top_classes=20, label="Predictions", show_label=True),
        # Manual flagging; flagged samples go through hf_writer.
        allow_flagging="manual",
        flagging_options=["Incorrect", "Other"],
        flagging_callback=hf_writer,
    )

    demo.launch()