import streamlit as st
import json
from collections import Counter
import contractions
import csv
import random
import pandas as pd
import altair as alt
from typing import Tuple, List, Optional

from my_model.dataset.dataset_processor import process_okvqa_dataset
from my_model.config import dataset_config as config


class OKVQADatasetAnalyzer:
    """
    Provides tools for analyzing and visualizing the distribution of question types within the given question
    datasets. It supports data loading, categorization of questions based on keywords, visualization of the
    question distribution, and exporting the results to CSV files.

    Attributes:
        train_file_path (str): Path to the training dataset file.
        test_file_path (str): Path to the testing dataset file.
        data_choice (str): Choice of dataset(s) to analyze; options are 'train', 'test', or 'train_test'.
        questions (List[str]): List of questions aggregated based on the dataset choice.
        question_types (Counter): Counter object tracking the frequency of each question type.
        Qs (Dict[str, List[str]]): Dictionary mapping question types to lists of corresponding questions.
    """

    def __init__(self, train_file_path: str, test_file_path: str, data_choice: str):
        """
        Initializes the OKVQADatasetAnalyzer with paths to the dataset files and a choice of which datasets to
        analyze.

        Parameters:
            train_file_path (str): Path to the training dataset JSON file. This file should contain a list of
                questions.
            test_file_path (str): Path to the testing dataset JSON file. This file should also contain a list of
                questions.
            data_choice (str): Specifies which dataset(s) to load and analyze. Valid options are 'train', 'test', or
                'train_test', indicating whether to load training data, testing data, or both.

        The constructor stores the paths, selects the dataset(s) based on the choice, and loads the initial data by
        calling the `load_data` method. It also prepares the structures for categorizing questions and storing the
        results.
        """

        self.train_file_path = train_file_path
        self.test_file_path = test_file_path
        self.data_choice = data_choice
        self.questions = []
        self.question_types = Counter()
        self.Qs = {keyword: [] for keyword in config.QUESTION_KEYWORDS + ['others']}
        self.load_data()

    def load_data(self) -> None:
        """
        Loads the dataset(s) from the specified JSON file(s) based on the user's choice of 'train', 'test', or
        'train_test'.
        This method updates the internal list of questions depending on the chosen dataset.
        """

        if self.data_choice in ['train', 'train_test']:
            with open(self.train_file_path, 'r') as file:
                train_data = json.load(file)
                self.questions += [q['question'] for q in train_data['questions']]

        if self.data_choice in ['test', 'train_test']:
            with open(self.test_file_path, 'r') as file:
                test_data = json.load(file)
                self.questions += [q['question'] for q in test_data['questions']]

    def categorize_questions(self) -> None:
        """
        Categorizes each question in the loaded data into predefined categories based on keywords.
        This method updates the internal dictionary `self.Qs` and the Counter `self.question_types` with the
        categorized questions.
        """

        question_keywords = config.QUESTION_KEYWORDS

        for question in self.questions:
            # Expand contractions (e.g. "what's" -> "what is") before keyword matching.
            question = contractions.fix(question)
            words = question.lower().split()
            question_keyword = None
            # Treat questions starting with "name the" as their own category; otherwise use the first
            # keyword found anywhere in the question.
            if words[:2] == ['name', 'the']:
                question_keyword = 'name the'
            else:
                for word in words:
                    if word in question_keywords:
                        question_keyword = word
                        break
            if question_keyword:
                self.question_types[question_keyword] += 1
                self.Qs[question_keyword].append(question)
            else:
                self.question_types["others"] += 1
                self.Qs["others"].append(question)

    def plot_question_distribution(self) -> None:
        """
        Plots an interactive bar chart of question types using Altair and Streamlit, displaying the count and
        percentage of each type.
        The chart sorts question types by count in descending order and includes detailed tooltips for interaction.
        This method is intended for visualization in a Streamlit application.
        """

        # Prepare the data: count and percentage per question keyword.
        total_questions = sum(self.question_types.values())
        items = [(key, value, (value / total_questions) * 100) for key, value in self.question_types.items()]
        df = pd.DataFrame(items, columns=['Question Keyword', 'Count', 'Percentage'])

        # Sort the data and handle the 'others' category separately, if present, so it always appears last.
        df = df[df['Question Keyword'] != 'others'].sort_values('Count', ascending=False)
        if 'others' in self.question_types:
            others_df = pd.DataFrame([('others', self.question_types['others'],
                                       (self.question_types['others'] / total_questions) * 100)],
                                     columns=['Question Keyword', 'Count', 'Percentage'])
            df = pd.concat([df, others_df], ignore_index=True)

        # Explicitly set the order of the x-axis based on the sorted DataFrame.
        order = df['Question Keyword'].tolist()

        # Create the bar chart.
        bars = alt.Chart(df).mark_bar().encode(
            x=alt.X('Question Keyword:N', sort=order, title='Question Keyword', axis=alt.Axis(labelAngle=-45)),
            y=alt.Y('Count:Q', title='Question Count'),
            color=alt.Color('Question Keyword:N', scale=alt.Scale(scheme='category20'), legend=None),
            tooltip=[alt.Tooltip('Question Keyword:N', title='Type'),
                     alt.Tooltip('Count:Q', title='Count'),
                     alt.Tooltip('Percentage:Q', title='Percentage', format='.1f')]
        )

        # Create text labels above the bars showing count and percentage.
        text = bars.mark_text(
            align='center',
            baseline='bottom',
            dy=-5  # Nudges the text up so it appears above the bar.
        ).encode(
            text=alt.Text('PercentageText:N')
        ).transform_calculate(
            PercentageText="datum.Count + ' (' + format(datum.Percentage, '.1f') + '%)'"
        )
        # Combine the bar and text layers and style the chart.
        chart = (bars + text).properties(
            width=800,
            height=600,
        ).configure_axis(
            labelFontSize=12,
            titleFontSize=16,
            labelFontWeight='bold',
            titleFontWeight='bold',
            grid=False
        ).configure_text(
            fontWeight='bold'
        ).configure_title(
            fontSize=20,
            fontWeight='bold',
            anchor='middle'
        )

        # Display the chart in Streamlit.
        st.altair_chart(chart, use_container_width=True)

    def plot_bar_chart(self, df: pd.DataFrame, category_col: str, value_col: str, chart_title: str) -> None:
        """
        Plots an interactive bar chart using Altair and Streamlit.

        Args:
            df (pd.DataFrame): DataFrame containing the data for the bar chart.
            category_col (str): Name of the column containing the categories.
            value_col (str): Name of the column containing the values.
            chart_title (str): Title of the chart.

        Returns:
            None
        """

        # Calculate the percentage share of each category.
        df['Percentage'] = (df[value_col] / df[value_col].sum()) * 100
        df['PercentageText'] = df['Percentage'].round(1).astype(str) + '%'

        # Create the bar chart.
        bars = alt.Chart(df).mark_bar().encode(
            x=alt.X(field=category_col, title='Category', sort='-y', axis=alt.Axis(labelAngle=-45)),
            y=alt.Y(field=value_col, type='quantitative', title='Percentage'),
            color=alt.Color(field=category_col, type='nominal', legend=None),
            tooltip=[
                alt.Tooltip(field=category_col, type='nominal', title='Category'),
                alt.Tooltip(field=value_col, type='quantitative', title='Value'),
                alt.Tooltip(field='Percentage', type='quantitative', title='Percentage', format='.1f')
            ]
        ).properties(
            title=chart_title,
            width=800,
            height=600
        )

        # Add text labels above the bars.
        text = bars.mark_text(
            align='center',
            baseline='bottom',
            dy=-10  # Nudges the text up so it appears above the bar.
        ).encode(
            text=alt.Text('PercentageText:N')
        )

        # Combine the bar chart and text labels and style the chart.
        chart = (bars + text).configure_title(
            fontSize=20
        ).configure_axis(
            labelFontSize=12,
            titleFontSize=16,
            labelFontWeight='bold',
            titleFontWeight='bold',
            grid=False
        ).configure_text(
            fontWeight='bold'
        )

        # Display the chart in Streamlit.
        st.altair_chart(chart, use_container_width=True)

    def export_to_csv(self, qs_filename: str, question_types_filename: str) -> None:
        """
        Exports the categorized questions and their counts to two separate CSV files.

        Parameters:
            qs_filename (str): The filename or path for exporting the `self.Qs` dictionary data.
            question_types_filename (str): The filename or path for exporting the `self.question_types` Counter data.

        This method writes the contents of `self.Qs` and `self.question_types` to the specified files in CSV format.
        Each CSV file includes headers for better understanding and use of the exported data.
        """

        # Export the self.Qs dictionary.
        with open(qs_filename, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(['Question Type', 'Questions'])
            for q_type, questions in self.Qs.items():
                for question in questions:
                    writer.writerow([q_type, question])

        # Export the self.question_types Counter.
        with open(question_types_filename, mode='w', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(['Question Type', 'Count'])
            for q_type, count in self.question_types.items():
                writer.writerow([q_type, count])
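
# A minimal usage sketch of OKVQADatasetAnalyzer on its own (an illustration, not part of the app flow).
# It assumes the OK-VQA question files referenced in `dataset_config` exist locally and that
# `QUESTION_KEYWORDS` is defined there; the output CSV filenames below are hypothetical.
#
#     analyzer = OKVQADatasetAnalyzer(config.DATASET_TRAIN_QUESTIONS_PATH,
#                                     config.DATASET_VAL_QUESTIONS_PATH,
#                                     'train_test')
#     analyzer.categorize_questions()
#     analyzer.export_to_csv('categorized_questions.csv', 'question_type_counts.csv')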


def run_dataset_analyzer() -> None:
    """
    Executes the dataset analysis process and displays the results using Streamlit.
    This function provides an overview of the KB-VQA datasets and uses the OKVQADatasetAnalyzer to visualize the
    OK-VQA data.
    """

    # Load the dataset analysis tables from Excel.
    datasets_comparison_table = pd.read_excel(config.DATASET_ANALYSES_PATH, sheet_name="VQA Datasets Comparison")
    okvqa_dataset_characteristics = pd.read_excel(config.DATASET_ANALYSES_PATH,
                                                  sheet_name="OK-VQA Dataset Characteristics")

    # Process the OK-VQA validation and training datasets.
    val_data = process_okvqa_dataset(config.DATASET_VAL_QUESTIONS_PATH, config.DATASET_VAL_ANNOTATIONS_PATH,
                                     save_to_csv=False)
    train_data = process_okvqa_dataset(config.DATASET_TRAIN_QUESTIONS_PATH, config.DATASET_TRAIN_ANNOTATIONS_PATH,
                                       save_to_csv=False)

    # Initialize the dataset analyzer.
    dataset_analyzer = OKVQADatasetAnalyzer(config.DATASET_TRAIN_QUESTIONS_PATH, config.DATASET_VAL_QUESTIONS_PATH,
                                            'train_test')

    # Display the KB-VQA datasets overview.
    with st.container():
        st.markdown("## Overview of KB-VQA Datasets")
        col1, col2 = st.columns([2, 1])

        with col1:
            st.write(" ")
            with st.expander("1 - Knowledge-Based VQA (KB-VQA)"):
                st.markdown("""[Knowledge-Based VQA (KB-VQA)](https://arxiv.org/abs/1511.02570): One of the earliest
                datasets in this domain, KB-VQA comprises 700 images and 2,402 questions, with each question
                associated with both an image and a knowledge base (KB). The KB encapsulates facts about the world,
                including object names, properties, and relationships, aiming to foster models capable of answering
                questions through reasoning over both the image and the KB.\n""")

            with st.expander("2 - Factual VQA (FVQA)"):
                st.markdown("""[Factual VQA (FVQA)](https://arxiv.org/abs/1606.05433): This dataset includes 2,190
                images and 5,826 questions, accompanied by a knowledge base containing 193,449 facts. FVQA's
                questions are predominantly factual and less open-ended than those in KB-VQA, offering a different
                challenge in knowledge-based reasoning.\n""")

            with st.expander("3 - Outside-Knowledge VQA (OK-VQA)"):
                st.markdown("""[Outside-Knowledge VQA (OK-VQA)](https://arxiv.org/abs/1906.00067): OK-VQA poses a
                more demanding challenge than KB-VQA, featuring an open-ended knowledge base that can be updated
                during model training. This dataset contains 14,055 questions and 14,031 images. Questions are
                carefully curated to ensure they require reasoning beyond the image content alone.\n""")

            with st.expander("4 - Augmented OK-VQA (A-OKVQA)"):
                st.markdown("""[Augmented OK-VQA (A-OKVQA)](https://arxiv.org/abs/2206.01718): The augmented
                successor of the OK-VQA dataset, focused on common-sense knowledge and reasoning rather than purely
                factual knowledge, A-OKVQA offers approximately 24,903 questions across 23,692 images. Questions in
                this dataset demand commonsense reasoning about the scenes depicted in the images, moving beyond
                straightforward knowledge base queries. It also provides rationales for answers, aiming to be a
                significant testbed for the development of AI models that integrate visual and natural language
                reasoning.\n""")

        with col2:
            st.markdown("#### KB-VQA Datasets Comparison")
            st.write(datasets_comparison_table, use_column_width=True)

    st.write("-----------------------")

    # Display the OK-VQA dataset details.
    with st.container():
        st.write("\n" * 10)
        st.markdown("## OK-VQA Dataset")
        st.write("This model was fine-tuned and evaluated using the OK-VQA dataset.\n")

        with st.expander("OK-VQA Dataset Characteristics"):
            st.markdown("#### OK-VQA Dataset Characteristics")
            st.write(okvqa_dataset_characteristics)

        with st.expander("Questions Distribution over Knowledge Category"):
            df = pd.read_excel(config.DATASET_ANALYSES_PATH, sheet_name="Question Category Dist")
            st.markdown("#### Questions Distribution over Knowledge Category")
            dataset_analyzer.plot_bar_chart(df, "Knowledge Category", "Percentage",
                                            "Questions Distribution over Knowledge Category")

        with st.expander("Distribution of Question Keywords"):
            dataset_analyzer.categorize_questions()
            st.markdown("#### Distribution of Question Keywords")
            dataset_analyzer.plot_question_distribution()

    # Display sample data.
    with st.container():
        with st.expander("Show Dataset Samples"):
            # Display 10 random samples from the training data.
            n = random.randint(1, len(train_data) - 10)
            st.write(train_data[n:n + 10])
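

# A minimal entry-point sketch (an assumption, not part of the original module): it lets this page be
# previewed on its own with `streamlit run`, whereas in the full app `run_dataset_analyzer()` is
# presumably imported and called from the main navigation script.
if __name__ == "__main__":
    run_dataset_analyzer()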