Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import clip | |
| import tempfile | |
| import cv2 | |
| from tqdm import tqdm | |
| from transformers import GPT2Tokenizer | |
| from model import * | |
| from inference import * | |
# Browser-tab metadata for the app. Kept ahead of every other st.* call, since
# (older) Streamlit releases reject set_page_config once other commands have run.
# NOTE(review): page_icon looks like an emoji that was mis-decoded as Thai text
# somewhere along the way — confirm the intended icon before changing it.
st.set_page_config(page_title="Video Analysis AI", page_icon="๐ถ๏ธ")
def _load_model_uncached():
    """Build the networks used by the app; this is the slow, uncached part.

    Returns:
        Tuple ``(model, clip_model, preprocess, tokenizer)`` where ``model`` is
        the ``ClipCaptionModel`` with weights from ``transformer_clip_gpt-007.pt``
        (moved to the selected device and put in eval mode), ``clip_model`` /
        ``preprocess`` come from ``clip.load("ViT-L/14@336px")`` and
        ``tokenizer`` is a ``GPT2Tokenizer``.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    clip_model, preprocess = clip.load("ViT-L/14@336px", device=device, jit=False)
    # NOTE(review): the tokenizer comes from the *large* ruGPT3 checkpoint while the
    # caption model below is built on the *small* one; presumably they share a
    # vocabulary — confirm against how transformer_clip_gpt-007.pt was trained.
    tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3large_based_on_gpt2")
    prefix_length = 50  # main() uses the same value when calling get_ans
    model_path = "transformer_clip_gpt-007.pt"
    model = ClipCaptionModel("sberbank-ai/rugpt3small_based_on_gpt2", prefix_length=prefix_length)
    # Deserialize on CPU first so the checkpoint also loads on hosts without CUDA;
    # the model is moved to `device` right after.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.to(device)
    model.eval()
    return model, clip_model, preprocess, tokenizer


def load_model():
    """Return ``(model, clip_model, preprocess, tokenizer)``, loading them at most once.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached loader would rebuild CLIP ViT-L/14@336px and the caption model on
    every click. The loading is therefore routed through Streamlit's resource
    cache (``st.cache_resource``, or ``st.experimental_singleton`` on releases
    that predate it); on a Streamlit with neither, it simply loads eagerly.

    Note that the cached objects are shared by all sessions of the server.
    """
    cache = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
    if cache is None:
        return _load_model_uncached()
    # Streamlit identifies a cached function by its name and source code, so
    # wrapping it here on each call reuses the same cache entry — exactly what a
    # decorator does, since the decorator is also re-applied on every rerun.
    return cache(_load_model_uncached)()
def _max_width_(max_width_px=1400):
    """Cap the width of Streamlit's main content column at ``max_width_px`` pixels.

    Injects a one-rule ``<style>`` tag through ``st.markdown`` with
    ``unsafe_allow_html=True``.

    Args:
        max_width_px: Maximum width of the main block container in CSS pixels.
            Defaults to 1400, the value this app always used.
    """
    # NOTE(review): `.reportview-container` is a class name from older Streamlit
    # releases; newer ones renamed their containers, so this rule may silently
    # stop matching — confirm against the deployed Streamlit version.
    rule = f".reportview-container .main .block-container {{ max-width: {max_width_px}px; }}"
    # Single-line markup: no indentation that markdown could mistake for a code block.
    st.markdown(f"<style>{rule}</style>", unsafe_allow_html=True)


_max_width_()
def _answer_question(model, clip_model, preprocess, tokenizer, video_path, question, prefix_length=50):
    """Run the CLIP -> caption-model pipeline on one video file.

    Args:
        model, clip_model, preprocess, tokenizer: the tuple from ``load_model()``.
        video_path: path of a video file on local disk.
        question: the user's question text.
        prefix_length: prefix length passed through to ``get_ans``; should match
            the value the caption model was built with.

    Returns:
        The raw ``get_ans(...)["answer"]`` string, or ``None`` when
        ``read_video`` returned no frames.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # read_video / image_grid / get_ans come from the star imports of model/inference.
    # NOTE(review): assumes read_video decodes eagerly and releases the file before
    # returning (its result is len()-ed and indexed) — confirm in its definition.
    frames = read_video(video_path, transform=None, frames_num=4)
    if len(frames) == 0:  # len() rather than truthiness: frames may be an array
        return None
    # Presumably tiles the 4 sampled frames into a 2x2 mosaic so CLIP embeds the
    # whole clip as one image — confirm against image_grid.
    grid = image_grid(frames, 2, 2)
    image = preprocess(grid).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
    # The prompt template appends "?" itself, so drop any the user typed to avoid "??".
    prompt = f"Question: {question.strip().rstrip('?')}? Answer:"
    ans = get_ans(model, tokenizer, prefix, prefix_length, prompt)
    return ans["answer"]


def main():
    """Render the video question-answering page and answer when "Send" is clicked.

    The uploaded .mp4 is written to a temporary directory so ``read_video`` can
    open it by path; the directory is removed once the answer is computed.
    Missing input and analysis failures are reported in the UI instead of
    being silently ignored.
    """
    import os  # local import: only needed to build the temporary video path

    model, clip_model, preprocess, tokenizer = load_model()
    prefix_length = 50  # same value load_model() builds ClipCaptionModel with

    # NOTE(review): several labels below contain emoji mis-decoded as Thai
    # characters; left byte-for-byte until the intended emoji are confirmed.
    st.title("๐ฆพ Video Analysis for Education")
    st.header("")
    with st.sidebar.expander("โน๏ธ - About application", expanded=True):
        st.write(
            "- Upload the video\n"
            "- Make a question about the content of the video\n"
            "- Receive an answer according to your question prompt"
        )

    uploaded_file = st.file_uploader("๐ Upload video: ", ['.mp4'])
    st.write("---")
    question_col, button_col = st.columns([4, 1])
    question = question_col.text_input(
        label="โ Enter question prompt: ",
        placeholder="",
    )
    send_clicked = button_col.button("Send", use_container_width=True)
    if not send_clicked:
        return
    if uploaded_file is None:
        st.warning("Please upload an .mp4 video first.")
        return
    if not question.strip():
        st.warning("Please enter a question.")
        return

    try:
        # The upload lives in memory; uploaded_file.name is only the client-side
        # filename, so persist the bytes to a real file that read_video can open.
        with tempfile.TemporaryDirectory() as tmp_dir:
            video_path = os.path.join(tmp_dir, "upload.mp4")
            with open(video_path, "wb") as video_file:
                video_file.write(uploaded_file.getvalue())
            raw_answer = _answer_question(
                model, clip_model, preprocess, tokenizer, video_path, question, prefix_length
            )
    except Exception as exc:  # UI boundary: surface the failure to the user
        st.error("Could not analyse the video.")
        st.exception(exc)
        return

    if raw_answer is None:
        st.warning("No frames could be read from the uploaded video.")
        return

    # Raw generation, shown as a one-element list exactly as before.
    st.write([raw_answer])
    # Keep only the text before any " A: " continuation the model produced.
    result = raw_answer.split(' A: ')[0]
    st.text_input('โ Answer to the question', result, disabled=False)
if __name__ == "__main__":
    # `streamlit run <script>` executes this file as __main__, which renders the page.
    main()