# Spaces: Runtime error
import itertools
import re
from collections import Counter

import easyocr
import nltk
import numpy as np
import streamlit as st
import torch
from nltk.corpus import stopwords
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer
# Set up the Streamlit page first: set_page_config must be the first Streamlit
# command on the page, and the cached loader below may render a spinner.
st.set_page_config(layout='wide', page_title='Image Hashtag Recommender')


@st.cache_resource
def _load_resources():
    """Download NLTK stop words and load the captioning and OCR models.

    Streamlit re-executes this script on every user interaction;
    ``st.cache_resource`` keeps the returned objects across reruns so the
    (large) models are loaded once per process instead of on every rerun.

    Returns:
        A tuple ``(processor, model, tokenizer, reader)``.
    """
    nltk.download('stopwords')
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
    # NOTE(review): this tokenizer is not used anywhere in this file; it is only
    # kept so the module-level ``tokenizer`` name remains available.
    tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    reader = easyocr.Reader(['en'])
    return processor, model, tokenizer, reader


# Module-level handles used by the hashtag generator below.
processor, model, tokenizer, reader = _load_resources()
# Characters that cannot appear in a hashtag (anything other than letters,
# digits and underscore); stripped from every candidate word.
_NON_WORD_RE = re.compile(r"\W+")


def _normalize_words(words, stop_words):
    """Return *words* lowercased, stripped of non-word characters, minus stop words.

    Words that become empty after stripping (e.g. pure punctuation) are dropped.
    """
    normalized = []
    for word in words:
        word = _NON_WORD_RE.sub("", word.lower())
        if word and word not in stop_words:
            normalized.append(word)
    return normalized


def _caption_image(image):
    """Generate a BLIP caption for the RGB PIL *image* and return it as text."""
    inputs = processor(image, return_tensors="pt")
    output_ids = model.generate(**inputs)
    return processor.decode(output_ids[0], skip_special_tokens=True)


def _ocr_text(image):
    """Return the text easyocr detects in *image*, joined with single spaces."""
    results = reader.readtext(np.array(image))
    # Element [1] of each readtext result is its recognized text string.
    return " ".join(item[1] for item in results)


def _rank_hashtags(words, max_words, top_k):
    """Build hashtags from every combination of 1..max_words words and rank them.

    Tags are ranked by how many times they are produced (a word repeated in
    *words* yields repeated tags). Ties keep first-seen order, so single-word
    tags precede two-word tags, and so on.
    """
    # Counter avoids the quadratic list.count ranking and, unlike iterating a
    # set, gives a deterministic order for tags with equal counts.
    counts = Counter(
        "#" + "".join(combo)
        for size in range(1, max_words + 1)
        for combo in itertools.combinations(words, size)
    )
    return [tag for tag, _ in counts.most_common(top_k)]


def generate_hashtags(image_file, *, top_k=10, max_words=3):
    """Suggest hashtags for an uploaded image.

    The image is captioned with BLIP and scanned for text with easyocr. Caption
    words and OCR words are lowercased, stripped of punctuation and filtered
    for English stop words. They are then joined into hashtags of one to
    ``max_words`` words and ranked by frequency.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (the app passes a
            Streamlit upload).
        top_k: Maximum number of hashtags to return (default 10).
        max_words: Largest number of words joined into one hashtag (default 3).

    Returns:
        A two-element list ``[hashtags, caption]``: the ranked hashtag strings
        (each starting with ``"#"``) and the raw generated caption.
    """
    # The context manager lets PIL release any file handle it opened itself.
    with Image.open(image_file) as source:
        image = source.convert('RGB')

    caption = _caption_image(image)
    stop_words = set(stopwords.words('english'))
    # Words the captioner already emitted as hashtags are skipped.
    caption_words = _normalize_words(
        (word for word in caption.split() if not word.startswith("#")), stop_words
    )
    # OCR words get the same normalization so tags are consistent and contain
    # no punctuation or mixed case.
    ocr_words = _normalize_words(_ocr_text(image).split(), stop_words)

    hashtags = _rank_hashtags(caption_words + ocr_words, max_words, top_k)
    return [hashtags, caption]
st.title("HashTag Recommender")
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

# When the user has uploaded an image, show it and list the suggested hashtags.
if image_file is not None:
    try:
        st.image(image_file, width=500, channels='RGB')
        with st.spinner("Generating hashtags..."):
            # generate_hashtags returns [hashtags, caption]; only the tags are shown.
            tags, _caption = generate_hashtags(image_file)
        # Check the tag list itself: the returned pair is never empty, so
        # testing its length would make the "no hashtags" branch unreachable.
        if tags:
            st.write(f"Top {len(tags)} hashtags for this image:")
            for tag in tags:
                st.write(tag)
        else:
            st.write("No hashtags found for this image.")
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing the app.
        st.error(f"Error: {e}")