from transformers import OwlViTProcessor, OwlViTForObjectDetection
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import warnings
import torch
import os
import io
# settings
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
warnings.filterwarnings('ignore')
os.makedirs('images', exist_ok=True)  # upload folder used by process(); create it up front so writes don't fail
st.set_page_config()
class owl_vit:
    def __init__(self, image_path, text, threshold):
        self.image_path = image_path
        self.text = text
        self.threshold = threshold
    def process(self, processor, model):
        # load the image and make sure it is 3-channel RGB before preprocessing
        image = Image.open(self.image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        inputs = processor(text=[self.text], images=[image], return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # rescale predicted boxes back to the original image size
        target_sizes = torch.tensor([[image.height, image.width]])
        self.results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
        self.image = image
        return self.result_image()
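    # Assumption: newer transformers releases also expose
    # processor.post_process_object_detection(outputs=..., threshold=..., target_sizes=...), which applies
    # the score threshold for you; the manual filtering in result_image() keeps this version-agnostic.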
    def result_image(self):
        boxes, scores, labels = self.results[0]["boxes"], self.results[0]["scores"], self.results[0]["labels"]
        # draw the predicted boxes on a fresh figure so boxes from earlier reruns don't accumulate
        fig = plt.figure()
        plt.imshow(self.image)
        ax = plt.gca()
        for box, score, label in zip(boxes, scores, labels):
            if score >= self.threshold:
                box = box.detach().numpy()
                # pick a CSS color keyed by the label index so each query gets a consistent color
                color = list(mcolors.CSS4_COLORS.keys())[label]
                ax.add_patch(plt.Rectangle(box[:2], box[2] - box[0], box[3] - box[1], fill=False, color=color, linewidth=3))
                ax.text(box[0], box[1], f"{self.text[label]}: {round(score.item(), 2)}", fontsize=15, color=color)
        plt.tight_layout()
        # render the figure to an in-memory PNG and return it as a PIL image
        img_buf = io.BytesIO()
        plt.savefig(img_buf, format='png')
        plt.close(fig)
        image = Image.open(img_buf)
        return image
def load_model():
    with st.spinner('Getting Neurons in Order ...'):
        processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch16")
        model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch16")
    return processor, model
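# Assumption: if the installed Streamlit version provides st.cache_resource, decorating load_model
# with it would keep the processor and model in memory across reruns instead of reloading them on
# every interaction.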
def show_detects(image):
    st.title("Results")
    st.image(image, use_column_width=True, caption="Object Detection Results", clamp=True)
def process(upload, text, threshold):
    # save the upload to disk so the detector can open it by path
    filetype = upload.name.split('.')[-1]
    name = len(os.listdir("images")) + 1
    file_path = os.path.join('images', f'{name}.{filetype}')
    with open(file_path, "wb") as f:
        f.write(upload.getbuffer())
    # predict detections and show results
    detector = owl_vit(file_path, text, threshold)
    results = detector.process(processor, model)
    show_detects(results)
    # clean up - if over 1000 images in the folder, delete the oldest one
    if len(os.listdir("images")) > 1000:
        oldest = min(os.listdir("images"), key=lambda f: os.path.getctime(os.path.join("images", f)))
        os.remove(os.path.join("images", oldest))
def main(processor, model):
    # splash image
    st.image(os.path.join('refs', 'baseball_labeled.png'), use_column_width=True)
    # title and project description
    st.title("OWL-ViT")
    st.markdown("**OWL-ViT** is a zero-shot text-conditioned object detection model. OWL-ViT uses CLIP as its multi-modal \
        backbone, with a ViT-like Transformer to get visual features and a causal language model to get the text features. \
        To use CLIP for detection, OWL-ViT removes the final token pooling layer of the vision model and attaches a \
        lightweight classification and box head to each transformer output token. Open-vocabulary classification \
        is enabled by replacing the fixed classification layer weights with the class-name embeddings obtained \
        from the text model. The authors first train CLIP from scratch and fine-tune it end-to-end with the classification \
        and box heads on standard detection datasets using a bipartite matching loss. One or multiple text queries per image \
        can be used to perform zero-shot text-conditioned object detection.", unsafe_allow_html=True)
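    # For reference, the multi-query usage described above corresponds to a call along these lines
    # (a sketch only, not executed here; the queries and image are illustrative):
    #   inputs = processor(text=[["batter", "umpire", "catcher"]], images=[image], return_tensors="pt")
    #   outputs = model(**inputs)
    #   results = processor.post_process(outputs=outputs, target_sizes=torch.tensor([[image.height, image.width]]))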
    # example
    if st.button("Run the Example Image/Text"):
        with st.spinner('Detecting Objects and Comparing Vocab...'):
            info = owl_vit(os.path.join('refs', 'baseball.jpg'), ["batter", "umpire", "catcher"], 0.50)
            results = info.process(processor, model)
            show_detects(results)
    if st.button("Clear Example"):
        st.markdown("")
    # upload
    col1, col2 = st.columns(2)
    threshold = st.slider('Confidence Threshold', min_value=0.0, max_value=1.0, value=0.1)
    with col1:
        upload = st.file_uploader('Image:', type=['jpg', 'jpeg', 'png'])
    with col2:
        text = st.text_area('Objects to Detect: (comma separated)', "batter, umpire, catcher")
        text = [x.strip() for x in text.split(',')]
    # process
    if upload is not None and text is not None:
        filetype = upload.name.split('.')[-1]
        if filetype in ['jpg', 'jpeg', 'png']:
            with st.spinner('Detecting Objects in Uploaded Image...'):
                process(upload, text, threshold)
        else:
            st.warning('Unsupported file type.')
if __name__ == '__main__':
    processor, model = load_model()
    main(processor, model)