Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| import gradio as gr | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModel, AutoTokenizer | |
| import torch | |
| import faiss | |
| import glob | |
| import numpy as np | |
# Hugging Face Hub checkpoint shared by the model, processor and tokenizer.
# Keeping it in one place guarantees the three are always loaded from the
# same checkpoint.
MODEL_NAME = "google/siglip-base-patch16-256-multilingual"

# Prefer the first CUDA GPU when present; otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
processor = AutoProcessor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Width of the vectors stored in the FAISS index. Assumed to equal the size of
# model.get_image_features()/get_text_features() output (768 per the original
# note) — TODO confirm if the checkpoint changes.
num_dimensions = model.vision_model.config.hidden_size  # 768
# Number of nearest neighbours returned per search.
num_k = 30
# Sample queries shown under the text box; they include non-English phrases
# to showcase the model's multilingual support.
text_examples = [
    "Frog waiting on a rock",
    "Bird with open mouth",
    "Bridge and a ship",
    "Bike for two people",
    "Biene auf der Blume",
    "Hesap makinesi"
]
def preprocess_images(pathname="images/*", index_file="index.faiss",
                      image_filenames_file="image_filenames.txt"):
    """Embed every image matching *pathname* and save a FAISS search index.

    Each image is converted to RGB, embedded with ``model.get_image_features``
    and L2-normalised, so inner-product search on the resulting
    ``faiss.IndexFlatIP`` ranks results by cosine similarity.

    Args:
        pathname: glob pattern selecting the images to index.
        index_file: path the FAISS index is written to.
        image_filenames_file: path of the text file listing the indexed image
            paths, one per line; line ``i`` corresponds to index row ``i``.

    Returns:
        ``(index, image_filenames)``: the populated index and the list of
        indexed paths in row order.

    Raises:
        ValueError: if *pathname* matches no files.
        SystemExit: with status 1 if any image fails to load or embed.
    """
    import sys

    print("Preprocessing images...")
    # glob.glob() order is filesystem-dependent; sort so the row order of the
    # index (and of the filename list) is reproducible.
    candidate_filenames = sorted(glob.glob(pathname))
    if not candidate_filenames:
        raise ValueError(f"No images matched {pathname!r}; nothing to index.")

    image_filenames = []
    image_features = []
    for image_filename in candidate_filenames:
        try:
            # convert() returns an independent image, so the source file can
            # be closed right away instead of leaking its handle.
            with Image.open(image_filename) as image_raw:
                image_rgb = image_raw.convert('RGB')
            inputs = processor(images=image_rgb, return_tensors="pt").to(device)
            with torch.no_grad():
                image_embedding = model.get_image_features(**inputs).to("cpu")
            # L2-normalise so that inner product equals cosine similarity.
            image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
            image_features.append(image_embedding_n.numpy())
        except Exception as e:
            # Plain f-string: the former extra .format() call crashed on
            # filenames containing '{' or '}'.
            print(f"Error processing {image_filename}")
            print(e)
            sys.exit(1)
        # Record the path only after its embedding succeeded, keeping the two
        # lists aligned row-for-row.
        image_filenames.append(image_filename)

    print("Indexing images...")
    # FAISS requires a C-contiguous float32 matrix.
    image_features = np.ascontiguousarray(
        np.concatenate(image_features, axis=0), dtype=np.float32
    )
    index = faiss.IndexFlatIP(num_dimensions)  # Inner Product (IP) similarity.
    index.add(image_features)
    print("Saving index...")
    faiss.write_index(index, index_file)
    with open(image_filenames_file, "w") as f:
        f.writelines(image_filename + "\n" for image_filename in image_filenames)
    print("Preprocessing complete.")
    return index, image_filenames
def load_processed_images(index_file="index.faiss", image_filenames_file="image_filenames.txt"):
    """Read back the FAISS index and image path list saved by preprocessing.

    Args:
        index_file: path to the serialized FAISS index.
        image_filenames_file: text file holding one image path per line.

    Returns:
        ``(index, image_filenames)`` where every path has surrounding
        whitespace (including the trailing newline) stripped.
    """
    print("Loading index...")
    loaded_index = faiss.read_index(index_file)
    with open(image_filenames_file) as listing:
        paths = [line.strip() for line in listing]
    return loaded_index, paths
def search_using_text(text):
    """Return the indexed images that best match a text query.

    Uses the module-level ``index`` and ``image_filenames`` loaded in
    ``__main__``.

    Args:
        text: free-text query.

    Returns:
        Up to ``num_k`` ``(PIL.Image, "NN.NN%")`` tuples, best match first.
        The percentage is ``sigmoid(similarity * exp(logit_scale) + logit_bias)``.
    """
    # padding="max_length" is kept as before; truncation keeps overly long
    # queries within the tokenizer's maximum length instead of failing.
    inputs = tokenizer(text, padding="max_length", truncation=True,
                       return_tensors="pt").to(device)
    # Inference only: without no_grad the output requires grad and the
    # .numpy() call below raises RuntimeError.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs).to("cpu")
    text_features_n = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
    text_features_n = text_features_n.numpy()
    D, I = index.search(text_features_n, num_k)
    # logit_scale / logit_bias are trainable Parameters; detach before .numpy().
    scale = model.logit_scale.detach().exp().cpu().numpy()
    bias = model.logit_bias.detach().cpu().numpy()
    result = []
    for dist, idx in zip(D[0], I[0]):
        if idx < 0:
            # FAISS pads with label -1 when the index holds fewer than num_k
            # vectors; don't let that wrap around to the last filename.
            continue
        score_logit = dist * scale + bias
        score_probability = torch.sigmoid(torch.tensor(score_logit)).item()
        found_image = Image.open(image_filenames[idx])
        found_image.load()
        result.append((found_image, "{:.2f}%".format(score_probability*100)))
    return result
def search_using_image(image):
    """Return the indexed images most similar to a query image.

    Uses the module-level ``index`` and ``image_filenames`` loaded in
    ``__main__``.

    Args:
        image: the uploaded query image as an array accepted by
            ``PIL.Image.fromarray``, or ``None`` when nothing was uploaded.

    Returns:
        Up to ``num_k`` PIL images, best match first; an empty list when
        *image* is ``None``.
    """
    if image is None:
        # The search button can be clicked before any image is uploaded.
        return []
    image_rgb = Image.fromarray(image).convert('RGB')
    inputs = processor(images=image_rgb, return_tensors="pt").to(device)
    # Inference only: without no_grad the output requires grad and the
    # .numpy() call below raises RuntimeError.
    with torch.no_grad():
        image_embedding = model.get_image_features(**inputs).to("cpu")
    image_embedding_n = image_embedding / image_embedding.norm(p=2, dim=-1, keepdim=True)
    image_embedding_n = image_embedding_n.numpy()
    _, I = index.search(image_embedding_n, num_k)
    result = []
    for idx in I[0]:
        if idx < 0:
            # FAISS pads with label -1 when the index holds fewer than num_k
            # vectors; don't let that wrap around to the last filename.
            continue
        found_image = Image.open(image_filenames[idx])
        found_image.load()
        result.append(found_image)
    return result
if __name__ == "__main__":
    import os

    # Reuse the saved index when both of its files exist; otherwise build it
    # from images/* (previously this required hand-editing a commented line,
    # and a missing index crashed the app at startup).
    if os.path.exists("index.faiss") and os.path.exists("image_filenames.txt"):
        index, image_filenames = load_processed_images()
    else:
        index, image_filenames = preprocess_images()

    with gr.Blocks() as demo:
        gr.Markdown("# Image Search Engine Demo")
        with gr.Row(equal_height=False):
            with gr.Column():
                gr.Markdown("This app is powered by [SigLIP](https://huggingface.co/google/siglip-base-patch16-256-multilingual) with multilingual support and [GPR1200 Dataset](https://www.kaggle.com/datasets/mathurinache/gpr1200-dataset) image contents. Enter your query in the text box or upload an image to search for similar images.")
                with gr.Tab("Text-Image Search"):
                    text_input = gr.Textbox(label="Type a word or a sentence")
                    search_using_text_btn = gr.Button("Search with text", scale=0)
                    gr.Examples(
                        examples=text_examples,
                        inputs=[text_input]
                    )
                with gr.Tab("Image-Image Search"):
                    image_input = gr.Image()
                    search_using_image_btn = gr.Button("Search with image", scale=0)
            # NOTE(review): recent Gradio versions expect an integer `scale`;
            # confirm 2.75 is accepted by the deployed version.
            gallery = gr.Gallery(label="Generated images", show_label=False,
                                 elem_id="gallery", columns=3,
                                 object_fit="contain", interactive=False, scale=2.75)
        # Both searches render into the same results gallery.
        search_using_text_btn.click(search_using_text, inputs=text_input, outputs=gallery)
        search_using_image_btn.click(search_using_image, inputs=image_input, outputs=gallery)
    demo.launch(share=False)