Spaces:
Sleeping
Sleeping
| import os | |
| import altair as alt | |
| from my_model.config import evaluation_config as config | |
| import streamlit as st | |
| from PIL import Image | |
| import pandas as pd | |
| import random | |
class ResultDemonstrator:
    """
    A class to demonstrate the results of the Knowledge-Based Visual Question Answering (KB-VQA) model.

    Attributes:
        main_data (pd.DataFrame): Data loaded from the "Main Data" sheet of the evaluation Excel file.
        sample_img_pool (list[str]): File names found in the demo images directory.
        model_names (list[str]): List of model names as defined in the configuration.
        model_configs (list[str]): List of model configurations as defined in the configuration.
        demo_images_path (str): Path to the demo images directory.
    """

    # Column-name templates for each score type; ``{config}`` is replaced by "<model_name>_<configuration>".
    SCORE_COLUMN_TEMPLATES = {
        'vqa': 'vqa_score_{config}',
        'vqa_gpt4': 'gpt4_vqa_score_{config}',
        'em': 'exact_match_score_{config}',
        'em_gpt4': 'gpt4_em_score_{config}',
    }

    # Column headers used for each score type in the per-sample summary table.
    SCORE_DISPLAY_NAMES = {
        'vqa': 'VQA Score',
        'vqa_gpt4': 'VQA Score (GPT-4)',
        'em': 'EM Score',
        'em_gpt4': 'EM Score (GPT-4)',
    }

    def __init__(self) -> None:
        """
        Initializes the ResultDemonstrator class by loading the data from an Excel file
        and reading the demo image directory listing.
        """
        # Load data
        self.main_data = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Main Data")
        self.sample_img_pool = list(os.listdir(config.DEMO_IMAGES_PATH))
        self.model_names = config.MODEL_NAMES
        self.model_configs = config.MODEL_CONFIGURATIONS
        self.demo_images_path = config.DEMO_IMAGES_PATH

    @staticmethod
    def display_table(data: pd.DataFrame) -> None:
        """
        Displays a DataFrame using Streamlit's dataframe display function.

        This is a static method because callers invoke it as ``self.display_table(...)``.
        Without the decorator the instance would be bound to ``data`` and the call would fail.

        Args:
            data (pd.DataFrame): The data to display.
        """
        st.dataframe(data)

    @staticmethod
    def _elementwise(obj, func, **kwargs):
        """
        Applies ``func`` element-wise to a DataFrame or Styler.

        Uses ``.map`` when the object provides it (pandas >= 2.1, where ``applymap`` is
        deprecated). On older pandas versions it falls back to ``.applymap``.

        Args:
            obj: A ``pd.DataFrame`` or ``pandas.io.formats.style.Styler``.
            func: Callable applied to every (selected) cell.
            **kwargs: Forwarded to the underlying method (e.g. ``subset`` for a Styler).

        Returns:
            The result of the underlying ``map``/``applymap`` call.
        """
        mapper = getattr(obj, 'map', None)
        if mapper is None:
            mapper = obj.applymap
        return mapper(func, **kwargs)

    def calculate_and_append_data(self, data_list: list, score_column: str, model_config: str) -> None:
        """
        Calculates mean scores by question category and appends them to ``data_list``.

        Nothing is appended if ``score_column`` is not a column of ``main_data``.

        Args:
            data_list (list): List to append new data rows (dicts) to; mutated in place.
            score_column (str): Name of the column to calculate mean scores for.
            model_config (str): Configuration of the model, stored under "Configuration".
        """
        if score_column in self.main_data.columns:
            category_means = self.main_data.groupby('question_category')[score_column].mean()
            for category, mean_value in category_means.items():
                data_list.append({
                    "Category": category,
                    "Configuration": model_config,
                    # Scores are stored as fractions; show them as percentages.
                    "Mean Value": round(mean_value * 100, 2)
                })

    def display_ablation_results_per_question_category(self) -> None:
        """
        Displays ablation results per question category for each model configuration.

        For every score type, the per-category mean (as a percentage) of each
        ``<model_name>_<configuration>`` is pivoted into a Category x Configuration table.
        Each table is shown inside its own expander. Score types for which no matching
        columns exist are skipped.
        """
        data_lists = {score_type: [] for score_type in self.SCORE_COLUMN_TEMPLATES}
        for model_name in self.model_names:
            for conf in self.model_configs:
                model_config = f"{model_name}_{conf}"
                for score_type, col_template in self.SCORE_COLUMN_TEMPLATES.items():
                    self.calculate_and_append_data(data_lists[score_type],
                                                   col_template.format(config=model_config),
                                                   model_config)
        # Process and display results for each score type
        for score_type, data_list in data_lists.items():
            if not data_list:
                # No score columns of this type exist; pivoting an empty frame would raise KeyError.
                continue
            df = pd.DataFrame(data_list)
            results_df = df.pivot(index='Category', columns='Configuration', values='Mean Value')
            results_df = self._elementwise(results_df, lambda x: f"{x:.2f}%")
            with st.expander(f"{score_type.upper()} Scores per Question Category and Model Configuration"):
                self.display_table(results_df)

    def display_main_results(self) -> None:
        """Displays the main model results from the Scores sheet, these are displayed from the file directly."""
        # The first sheet column is used as the index, so st.dataframe shows it as row labels.
        main_scores = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Scores", index_col=0)
        st.markdown("### Main Model Results (Inclusive of Ablation Experiments)")
        self.display_table(main_scores)

    def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
        """
        Plots an interactive scatter plot comparing token count to VQA or EM scores using Altair.

        Args:
            conf (str): The configuration name (also selects the ``tokens_count_<conf>`` column).
            model_name (str): The name of the model.
            score_name (str): 'VQA Score' plots the VQA score; any other value plots the exact-match score.
        """
        # Construct the full model configuration name
        model_configuration = f"{model_name}_{conf}"
        # Determine the score column name and legend mapping based on the score type
        if score_name == 'VQA Score':
            score_column_name = f"vqa_score_{model_configuration}"
            scores = self.main_data[score_column_name]
            # 1 -> Correct, ~0.67 -> Partially Correct, anything else -> Incorrect
            legend_map = ['Correct' if score == 1 else 'Partially Correct' if round(score, 2) == 0.67 else 'Incorrect'
                          for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Partially Correct', 'Incorrect'],
                                    range=['green', 'orange', 'red'])
        else:
            score_column_name = f"exact_match_score_{model_configuration}"
            scores = self.main_data[score_column_name]
            # Map scores to categories for the legend
            legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])
        # Retrieve token count from the data
        token_count = self.main_data[f'tokens_count_{conf}']
        # Create a DataFrame for the scatter plot
        scatter_data = pd.DataFrame({
            'Index': range(len(token_count)),
            'Token Count': token_count,
            score_name: legend_map
        })
        # Create an interactive scatter plot using Altair
        chart = alt.Chart(scatter_data).mark_circle(
            size=60,
            fillOpacity=1,  # Sets the fill opacity to maximum
            strokeWidth=1,  # Adjusts the border width making the circles bolder
            stroke='black'  # Sets the border color to black
        ).encode(
            # NOTE(review): the x-axis range is hard-coded to 0-1020 rather than derived from the data.
            x=alt.X('Index', scale=alt.Scale(domain=[0, 1020])),
            y=alt.Y('Token Count', scale=alt.Scale(domain=[token_count.min() - 200, token_count.max() + 200])),
            color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
            tooltip=['Index', 'Token Count', score_name]
        ).interactive()  # Enables zoom & pan
        chart = chart.properties(
            title={
                "text": f"Token Count vs {score_name} ({model_configuration.replace('_', '-')})",
                "color": "black",
                "fontSize": 20,
                "anchor": "middle",
                "offset": 0
            },
            width=700,
            height=500
        )
        # Display the interactive plot in Streamlit
        st.altair_chart(chart, use_container_width=True)

    @staticmethod
    def color_scores(value: float) -> str:
        """
        Applies color coding based on the score value.

        This is a static method because it is passed as ``self.color_scores`` to a Styler
        element-wise mapper, which calls it with a single cell value.

        Args:
            value (float): The score value (a number or a numeric string such as "0.67").

        Returns:
            str: CSS color style: green for 1.0, red for 0.0, orange for 0.67,
            and black for anything else (including non-numeric values such as '-' or None).
        """
        try:
            value = float(value)  # Convert to float to handle numerical comparisons
        except (TypeError, ValueError):
            return 'color: black;'  # Not a number (e.g. '-' placeholder or None)
        if value == 1.0:
            return 'color: green;'
        if value == 0.0:
            return 'color: red;'
        if value == 0.67:
            return 'color: orange;'
        return 'color: black;'

    @classmethod
    def _build_summary_table(cls, image_row: pd.Series, model_configs: list) -> pd.DataFrame:
        """
        Builds the per-sample table of answers and scores for each model configuration.

        Args:
            image_row (pd.Series): The ``main_data`` row of a single sample.
            model_configs (list): "<model_name>_<configuration>" strings. Each one names
                the answer column and fills the score-column templates.

        Returns:
            pd.DataFrame: One row per configuration with the columns 'Model Configuration',
            'Answer' and the score display names. Missing values are shown as '-', and
            numeric scores are formatted with two decimals.
        """
        columns = ['Model Configuration', 'Answer', *cls.SCORE_DISPLAY_NAMES.values()]
        rows = []
        for model_config in model_configs:
            row = {
                'Model Configuration': model_config,
                'Answer': image_row.get(model_config, '-'),
            }
            for score_type, score_template in cls.SCORE_COLUMN_TEMPLATES.items():
                score_value = image_row.get(score_template.format(config=model_config), '-')
                if pd.notna(score_value) and not isinstance(score_value, str):
                    # Format score to two decimals if it's a valid number
                    score_value = f"{float(score_value):.2f}"
                row[cls.SCORE_DISPLAY_NAMES[score_type]] = score_value
            rows.append(row)
        return pd.DataFrame(rows, columns=columns)

    def show_samples(self, num_samples: int = 3) -> None:
        """
        Displays random sample images and their associated models answers and evaluations.

        Args:
            num_samples (int): Number of sample images to display. The value is capped at the
                number of available images.
        """
        # random.sample raises ValueError if asked for more items than the pool holds.
        num_samples = min(num_samples, len(self.sample_img_pool))
        target_imgs = random.sample(self.sample_img_pool, num_samples)
        # Generate model configurations
        model_configs = [f"{model_name}_{conf}" for model_name in self.model_names for conf in self.model_configs]
        score_columns = list(self.SCORE_DISPLAY_NAMES.values())
        for img_filename in target_imgs:
            image_data = self.main_data[self.main_data['image_filename'] == img_filename]
            # First matching row, or None when the image has no entry in the evaluation data.
            image_row = None if image_data.empty else image_data.iloc[0]
            col1, col2 = st.columns([1, 2])  # to display images side by side with their data.
            # Create a container for each image
            with st.container():
                st.write("-------------------------------")
                with col1:
                    # Close the file handle once Streamlit has consumed the image.
                    with Image.open(os.path.join(self.demo_images_path, img_filename)) as im:
                        st.image(im, use_column_width=True)
                    if image_row is not None:
                        with st.expander('Show Caption'):
                            st.text(image_row['caption'])
                        with st.expander('Show DETIC Objects'):
                            st.text(image_row['objects_detic_trimmed'])
                        with st.expander('Show YOLOv5 Objects'):
                            st.text(image_row['objects_yolov5'])
                with col2:
                    if image_row is None:
                        st.write("No data available for this image.")
                    else:
                        st.write(f"**Question:** {image_row['question']}")
                        st.write(f"**Ground Truth Answers:** {image_row['raw_answers']}")
                        summary_data = self._build_summary_table(image_row, model_configs)
                        # Apply styling to DataFrame for score coloring
                        styled_summary = self._elementwise(summary_data.style, self.color_scores,
                                                           subset=score_columns)
                        # Styler.to_html has no `index` parameter (unknown kwargs are silently passed
                        # to the template), so the index must be hidden explicitly.
                        styled_summary = styled_summary.hide(axis="index")
                        st.markdown(styled_summary.to_html(), unsafe_allow_html=True)