Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| # Plot Function Definitions | |
def create_gender_pie_chart(df):
    """Render a donut-style pie chart of the gender distribution.

    Counts the distinct values of the ``gender`` column of *df* and draws
    them with Plotly inside the current Streamlit container.
    """
    # One row per gender with its frequency, under display-friendly headers.
    counts_by_gender = (
        df['gender']
        .value_counts()
        .rename_axis('Gender')
        .reset_index(name='Count')
    )
    donut = px.pie(
        counts_by_gender,
        values='Count',
        names='Gender',
        hole=0.3,  # a non-zero hole turns the pie into a donut
        hover_data=['Count'],
    )
    st.plotly_chart(donut, use_container_width=True)
def create_race_pie_chart(df):
    """Render a donut-style pie chart of the race distribution.

    Counts the distinct values of the ``race`` column of *df* and draws
    them with Plotly inside the current Streamlit container.
    """
    # Frequency table: one row per race value.
    counts_by_race = (
        df['race']
        .value_counts()
        .rename_axis('Race Type')
        .reset_index(name='Count')
    )
    donut = px.pie(
        counts_by_race,
        values='Count',
        names='Race Type',
        hole=0.3,
        hover_data=['Count'],
    )
    st.plotly_chart(donut, use_container_width=True)
def create_insurance_pie_chart(df):
    """Render a donut-style pie chart of the insurance-type distribution.

    Counts the distinct values of the ``insurance`` column of *df* and draws
    them with Plotly inside the current Streamlit container.
    """
    # Frequency table: one row per insurance value.
    counts_by_insurance = (
        df['insurance']
        .value_counts()
        .rename_axis('Insurance Type')
        .reset_index(name='Count')
    )
    donut = px.pie(
        counts_by_insurance,
        values='Count',
        names='Insurance Type',
        hole=0.3,
        hover_data=['Count'],
    )
    st.plotly_chart(donut, use_container_width=True)
def create_mortality_pie_chart(df):
    """Render a pie chart of survived versus died admissions.

    Each row of *df* is counted once; ``hospital_expire_flag`` is summed to
    get the "Died" wedge and the remainder becomes the "Survived" wedge.

    The chart is drawn on its own figure instead of pyplot's implicit
    "current" figure, so it cannot be painted on top of a figure left behind
    by another plot or by an earlier run of the script. The figure is closed
    once Streamlit has rendered it.
    """
    total_admissions = df.shape[0]
    if total_admissions == 0:
        # Both wedges would be zero and the pie could not be normalised.
        st.info("No admissions match the current filters.")
        return

    deaths = df['hospital_expire_flag'].sum()
    sizes = [total_admissions - deaths, deaths]
    labels = ['Survived', 'Died']
    colors = ['#66b3ff', '#ff6666']
    explode = (0.1, 0)  # pull the first ("Survived") wedge out slightly

    fig, ax = plt.subplots()
    ax.pie(sizes, explode=explode, labels=labels, colors=colors,
           autopct='%1.1f%%', startangle=140, textprops={'fontsize': 14})
    ax.axis('equal')  # equal aspect ratio keeps the pie circular
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_admission_type_bar_chart(df):
    """Render a bar chart of admission counts per admission type.

    Admission types are placed on the y-axis and their counts on the x-axis;
    each type gets its own colour.
    """
    counts_by_type = (
        df['admission_type']
        .value_counts()
        .rename_axis('Admission Type')
        .reset_index(name='Count')
    )
    axis_titles = {
        'Count': 'Number of Admissions',
        'Admission Type': 'Admission Type',
    }
    bars = px.bar(
        counts_by_type,
        x='Count',
        y='Admission Type',
        color='Admission Type',
        hover_data=['Count'],
        labels=axis_titles,
    )
    st.plotly_chart(bars, use_container_width=True)
def create_time_series_heatmap(df):
    """Render a year-by-month heatmap of admission counts.

    Reads the ``admission_year`` and ``admission_month`` columns of *df*.
    Month values are matched against full English month names so that the
    x-axis follows calendar order. The caller's frame is not modified: the
    categorical conversion happens on a private two-column copy.
    """
    month_order = ['January', 'February', 'March', 'April', 'May', 'June',
                   'July', 'August', 'September', 'October', 'November', 'December']

    # Copy only the columns needed so the dtype change below neither leaks
    # into the caller's DataFrame nor writes into a filtered slice.
    calendar = df[['admission_year', 'admission_month']].copy()
    calendar['admission_month'] = pd.Categorical(
        calendar['admission_month'], categories=month_order, ordered=True
    )

    # observed=False is spelled out so that months without admissions stay
    # in the result as zero-count cells regardless of the pandas default.
    heatmap_df = (
        calendar
        .groupby(['admission_year', 'admission_month'], observed=False)
        .size()
        .reset_index(name='counts')
    )

    fig = px.density_heatmap(
        heatmap_df,
        x='admission_month',
        y='admission_year',
        z='counts',
        histfunc='sum',
        labels={'counts': 'Number of Admissions', 'admission_month': 'Admission Month', 'admission_year': 'Admission Year'},
        color_continuous_scale='rdbu'
    )
    fig.update_xaxes(categoryorder='array', categoryarray=month_order)
    fig.update_layout(yaxis=dict(autorange='reversed'))
    fig.update_traces(colorbar=dict(title='Admissions'))
    st.plotly_chart(fig, use_container_width=True)
| # def create_stacked_bar_admission_race(df): | |
| # """Creates a stacked bar chart for Admission Types by Race.""" | |
| # admission_race = df.groupby(['race', 'admission_type']).size().unstack(fill_value=0) | |
| # admission_race_percent = admission_race.div(admission_race.sum(axis=1), axis=0) * 100 | |
| # admission_race_percent.plot(kind='bar', stacked=True, figsize=(8, 6), colormap='tab20') | |
| # plt.xlabel("Race") | |
| # plt.ylabel("Percentage of Admission Types") | |
| # plt.legend(title='Admission Type', bbox_to_anchor=(1.05, 1), loc='upper left') | |
| # plt.tight_layout() | |
| # st.pyplot(plt.gcf()) | |
| # def create_los_by_race(df): | |
| # """Creates a box plot for Length of Stay by Race.""" | |
| # fig, ax = plt.subplots(figsize=(6, 4)) | |
| # sns.boxplot(data=df, x='race', y='los', palette='Pastel1', ax=ax) | |
| # ax.set_xlabel("Race") | |
| # ax.set_ylabel("Length of Stay (Days)") | |
| # ax.set_xticklabels(ax.get_xticklabels(), rotation=45) | |
| # plt.tight_layout() | |
| # st.pyplot(fig) | |
| # def create_correlation_heatmap(df): | |
| # """Creates a correlation heatmap for numerical features.""" | |
| # numerical_features = df[['anchor_age', 'los']] | |
| # corr_matrix = numerical_features.corr() | |
| # fig, ax = plt.subplots(figsize=(3.5, 3)) | |
| # sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt=".2f", ax=ax) | |
| # plt.tight_layout() | |
| # st.pyplot(fig) | |
def create_age_distribution_by_gender(df):
    """Render a 30-bin age histogram with KDE curves, split by gender.

    Uses the ``anchor_age`` column for the x-axis and ``gender`` for the
    hue. The plot is drawn on an explicit figure that is closed after
    Streamlit has rendered it, so figures do not pile up across reruns.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.histplot(data=df, x='anchor_age', bins=30,
                 kde=True, palette='bright', hue='gender', ax=ax)
    ax.set_xlabel('Age', fontsize=16)
    ax.set_ylabel('Number of Admissions', fontsize=16)
    ax.tick_params(axis='both', labelsize=16)
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_age_distribution_by_admission_type(df):
    """Render a boxen plot of age for each admission type.

    Uses ``admission_type`` on the x-axis and ``anchor_age`` on the y-axis.
    The plot is drawn on an explicit figure that is closed after Streamlit
    has rendered it, so figures do not pile up across reruns.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.boxenplot(data=df, x='admission_type',
                  y='anchor_age', palette='Set3', ax=ax)
    ax.set_xlabel('Admission Type', fontsize=16)
    ax.set_ylabel('Age', fontsize=16)
    # Rotate the category labels on the x-axis.
    ax.tick_params(axis='x', labelsize=16, labelrotation=45)
    ax.tick_params(axis='y', labelsize=16)
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_mortality_by_race(df):
    """Render a bar chart of the mortality rate (%) for each race group.

    The rate is the mean of ``hospital_expire_flag`` within each ``race``
    value, multiplied by 100. The figure is closed after rendering so that
    figures do not pile up across reruns.
    """
    mortality_race = df.groupby('race')['hospital_expire_flag'].mean().reset_index()
    mortality_race['mortality_rate'] = mortality_race['hospital_expire_flag'] * 100

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_race, x='race', y='mortality_rate', palette='Set2', ax=ax)
    ax.set_xlabel("Race")
    ax.set_ylabel("Mortality Rate (%)")
    ax.tick_params(axis='x', labelrotation=45)  # rotate the race labels
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_mortality_by_gender(df):
    """Render a bar chart of the mortality rate (%) for each gender.

    The rate is the mean of ``hospital_expire_flag`` within each ``gender``
    value, multiplied by 100. The figure is closed after rendering so that
    figures do not pile up across reruns.
    """
    mortality_gender = df.groupby('gender')['hospital_expire_flag'].mean().reset_index()
    mortality_gender['mortality_rate'] = mortality_gender['hospital_expire_flag'] * 100

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_gender, x='gender', y='mortality_rate', palette='Set3', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_mortality_by_age_group(df):
    """Render a bar chart of the mortality rate (%) for each age bracket.

    ``anchor_age`` is bucketed into 0-30, 31-50, 51-70, 71-90 and 91-120.
    Intervals are closed on the right (and the first one also on the left)
    so every bracket contains exactly the ages its label names; e.g. age 30
    is counted in "0-30", not in "31-50".

    The bucketing is done on a private copy, so no ``age_group`` column is
    added to the caller's frame, and the figure is closed after rendering.
    """
    bins = [0, 30, 50, 70, 90, 120]
    labels = ['0-30', '31-50', '51-70', '71-90', '91-120']

    # Only the two columns involved are copied; the caller's frame is untouched.
    ages = df[['anchor_age', 'hospital_expire_flag']].copy()
    ages['age_group'] = pd.cut(
        ages['anchor_age'], bins=bins, labels=labels,
        right=True, include_lowest=True,
    )

    # observed=False keeps brackets with no admissions on the axis
    # (their mean is NaN, so no bar is drawn for them).
    mortality_age = (
        ages.groupby('age_group', observed=False)['hospital_expire_flag']
        .mean()
        .reset_index()
    )
    mortality_age['mortality_rate'] = mortality_age['hospital_expire_flag'] * 100

    fig, ax = plt.subplots(figsize=(6, 4))
    sns.barplot(data=mortality_age, x='age_group', y='mortality_rate', palette='coolwarm', ax=ax)
    ax.set_xlabel("Age Group")
    ax.set_ylabel("Mortality Rate (%)")
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_violin_age_race_mortality(df):
    """Render split violins of age per race group, halved by mortality flag.

    ``race`` is on the x-axis, ``anchor_age`` on the y-axis, and each violin
    is split by ``hospital_expire_flag``. The figure is closed after
    rendering so that figures do not pile up across reruns.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.violinplot(
        data=df,
        x='race',
        y='anchor_age',
        hue='hospital_expire_flag',
        split=True,
        palette='Set2',
        ax=ax
    )
    ax.set_xlabel("Race")
    ax.set_ylabel("Age")
    ax.legend(title='Mortality', loc='upper right')
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_heatmap_race_gender_mortality(df):
    """Render an annotated heatmap of mortality rate (%) by race and gender.

    Cells hold the mean of ``hospital_expire_flag`` for each race (rows) and
    gender (columns) combination, multiplied by 100. The figure is closed
    after rendering so that figures do not pile up across reruns.
    """
    pivot_table = df.pivot_table(
        index='race',
        columns='gender',
        values='hospital_expire_flag',
        aggfunc='mean'
    ) * 100

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(pivot_table, annot=True, fmt=".1f", cmap='YlOrRd', ax=ax)
    ax.set_xlabel("Gender")
    ax.set_ylabel("Race")
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)  # pyplot keeps figures registered until they are closed
def create_treemap_race_mortality(df):
    """Render a treemap of admission counts nested by race, then outcome.

    Rows are counted per (``race``, ``hospital_expire_flag``) pair; the flag
    values 0 and 1 are shown as "Survived" and "Died" respectively.
    """
    outcome_names = {0: 'Survived', 1: 'Died'}
    outcome_colors = {'Survived':'#66b3ff','Died':'#ff6666'}

    pair_sizes = df.groupby(['race', 'hospital_expire_flag']).size()
    treemap_df = pair_sizes.reset_index(name='counts')
    treemap_df['Mortality'] = treemap_df['hospital_expire_flag'].map(outcome_names)

    treemap = px.treemap(
        treemap_df,
        values='counts',
        path=['race', 'Mortality'],  # outer level: race, inner level: outcome
        color='Mortality',
        color_discrete_map=outcome_colors,
    )
    treemap.update_layout(margin={'t': 30, 'l': 0, 'r': 0, 'b': 0})
    st.plotly_chart(treemap, use_container_width=True)
# Streamlit Application
# Page configuration: wide layout with the sidebar open on load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="MIMIC-IV ICU Patient Data Dashboard",
)

# Page heading and the short introduction shown beneath it.
st.title("MIMIC-IV ICU Patient Data Dashboard")
st.markdown(
    "\n"
    "Explore the general feature distribution and demographics related bias "
    "in ICU patients from the MIMIC-IV dataset. "
    "Utilize the sidebar filters to customize the data view"
)

# Sidebar Filters
st.sidebar.header("Filter Data")
def load_data():
    """Load admissions and patients, harmonise race labels, derive features.

    Reads ``data/admissions.feather`` and ``data/patients.feather``, collapses
    the detailed race labels into coarse groups, left-joins patients onto
    admissions by ``subject_id``, drops rows missing age, gender, race or the
    hospital-expire flag, parses the timestamp columns, and adds
    ``los``, ``admission_year``, ``admission_month`` and ``admittime_date``.

    Returns:
        The merged and cleaned DataFrame.
    """
    # NOTE(review): nothing here caches the result, so both feather files are
    # re-read on every call — consider st.cache_data if reruns feel slow.
    admissions = pd.read_feather('data/admissions.feather')
    patients = pd.read_feather('data/patients.feather')
    # Other tables (diagnoses_icd, pharmacy, prescriptions, d_hcpcs,
    # poe_detail, provider) were once read from CSV here; they are not loaded.

    # Coarse group -> every raw label that collapses into it.
    # NOTE(review): "MULTI-ETHINIC" looks like a misspelling of "MULTI-ETHNIC";
    # it is kept verbatim because the label is shown in charts and filters.
    raw_labels_by_group = {
        "WHITE": (
            "WHITE",
            "WHITE - OTHER EUROPEAN",
            "WHITE - RUSSIAN",
            "PORTUGUESE",
            "WHITE - EASTERN EUROPEAN",
            "WHITE - BRAZILIAN",
        ),
        "BLACK": (
            "BLACK/AFRICAN AMERICAN",
            "BLACK/CAPE VERDEAN",
            "BLACK/CARIBBEAN ISLAND",
            "BLACK/AFRICAN",
        ),
        "OTHER": (
            "OTHER",
        ),
        "UNKNOWN": (
            "UNKNOWN",
            "PATIENT DECLINED TO ANSWER",
            "UNABLE TO OBTAIN",
        ),
        "HISPANIC": (
            "HISPANIC/LATINO - PUERTO RICAN",
            "HISPANIC OR LATINO",
            "HISPANIC/LATINO - DOMINICAN",
            "HISPANIC/LATINO - GUATEMALAN",
            "HISPANIC/LATINO - SALVADORAN",
            "HISPANIC/LATINO - MEXICAN",
            "HISPANIC/LATINO - COLUMBIAN",
            "HISPANIC/LATINO - HONDURAN",
            "SOUTH AMERICAN",
            "HISPANIC/LATINO - CUBAN",
            "HISPANIC/LATINO - CENTRAL AMERICAN",
        ),
        "ASIAN": (
            "ASIAN",
            "ASIAN - CHINESE",
            "ASIAN - SOUTH EAST ASIAN",
            "ASIAN - ASIAN INDIAN",
            "ASIAN - KOREAN",
        ),
        "NATIVES": (
            "AMERICAN INDIAN/ALASKA NATIVE",
            "NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER",
        ),
        "MULTI-ETHINIC": (
            "MULTIPLE RACE/ETHNICITY",
        ),
    }
    # Invert into the raw-label -> group lookup that Series.map expects.
    race_map = {
        raw_label: group
        for group, raw_labels in raw_labels_by_group.items()
        for raw_label in raw_labels
    }
    # Labels absent from the lookup become NaN and are dropped just below.
    admissions['race'] = admissions['race'].map(race_map)

    merged = admissions.merge(patients, on='subject_id', how='left')
    required_columns = ['anchor_age', 'gender', 'race', 'hospital_expire_flag']
    merged = merged.dropna(subset=required_columns)

    merged['admittime'] = pd.to_datetime(merged['admittime'])
    merged['dischtime'] = pd.to_datetime(merged['dischtime'])
    # Unparseable death times are coerced to NaT instead of raising.
    merged['deathtime'] = pd.to_datetime(merged['deathtime'], errors='coerce')

    # Create derived features
    admitted_at = merged['admittime']
    merged['los'] = (merged['dischtime'] - admitted_at).dt.days  # whole days
    merged['admission_year'] = admitted_at.dt.year
    merged['admission_month'] = admitted_at.dt.month_name()
    merged['admittime_date'] = admitted_at.dt.date
    return merged


merged_df = load_data()
# Sidebar Filters Function
def _sidebar_multiselect(label, column):
    """Show a sidebar multiselect over a column's distinct values, all pre-selected."""
    # Missing values are removed first: sorted() cannot order NaN against
    # strings and would raise TypeError on a column that contains both.
    options = sorted(column.dropna().unique())
    return st.sidebar.multiselect(label, options=options, default=options)


def add_sidebar_filters(df):
    """Draw the sidebar filter widgets and return the rows of *df* they keep.

    Widgets are created in this order: admission type, insurance type,
    gender, race (multiselects with every value selected by default), then
    an admission-year range slider spanning the years present in *df*.

    A row is kept only when its value in each of the four categorical
    columns is among the selected options and its ``admission_year`` lies
    within the chosen range (both ends inclusive). Rows with a missing value
    in a filtered column never match a selection and are therefore excluded.

    Returns:
        A filtered DataFrame; *df* itself is not modified.
    """
    selected_admission_types = _sidebar_multiselect(
        "Select Admission Type(s):", df['admission_type']
    )
    selected_insurance_types = _sidebar_multiselect(
        "Select Insurance Type(s):", df['insurance']
    )
    selected_genders = _sidebar_multiselect(
        "Select Gender(s):", df['gender']
    )
    selected_races = _sidebar_multiselect(
        "Select Race(s):", df['race']
    )

    # Year Range
    min_year = int(df['admission_year'].min())
    max_year = int(df['admission_year'].max())
    selected_years = st.sidebar.slider(
        "Select Admission Year Range:",
        min_value=min_year,
        max_value=max_year,
        value=(min_year, max_year)
    )
    first_year, last_year = selected_years

    # Apply Filters
    keep = (
        df['admission_type'].isin(selected_admission_types)
        & df['insurance'].isin(selected_insurance_types)
        & df['gender'].isin(selected_genders)
        & df['race'].isin(selected_races)
        & (df['admission_year'] >= first_year)
        & (df['admission_year'] <= last_year)
    )
    return df[keep]


filtered_df = add_sidebar_filters(merged_df)
# Display Summary Statistics for Q1
st.header("Summary Statistics")

# Create four columns for metrics
col1, col2, col3, col4 = st.columns(4)

with col1:
    total_admissions = filtered_df.shape[0]
    st.metric("Total Admissions", f"{total_admissions:,}")

with col2:
    # The mean is NaN when the filters leave no rows; show "N/A" rather
    # than the literal text "nan years".
    average_age = filtered_df['anchor_age'].mean()
    st.metric(
        "Average Age",
        f"{average_age:.2f} years" if pd.notna(average_age) else "N/A",
    )

with col3:
    gender_counts = filtered_df['gender'].value_counts()
    male_count = gender_counts.get('M', 0)
    female_count = gender_counts.get('F', 0)
    st.metric("Male Patients", f"{male_count:,}")
    st.metric("Female Patients", f"{female_count:,}")

with col4:
    # Mean of the 0/1 flag as a percentage; NaN (no rows) is shown as "N/A".
    mortality_rate = filtered_df['hospital_expire_flag'].mean() * 100  # Percentage
    st.metric(
        "Mortality Rate",
        f"{mortality_rate:.2f}%" if pd.notna(mortality_rate) else "N/A",
    )

st.markdown("---")
# Create Tabs for Q1 and Q2
def _render_plot_grid(plots, df, num_cols):
    """Lay out titled plots row by row, ``num_cols`` Streamlit columns per row.

    Args:
        plots: Sequence of ``(title, render)`` pairs; ``render`` is called
            with *df* and is expected to draw into the active container.
        df: DataFrame handed to every ``render`` callable.
        num_cols: Number of columns created for each row. A final row with
            fewer plots still gets ``num_cols`` columns; the extras stay empty.
    """
    for row_start in range(0, len(plots), num_cols):
        row_columns = st.columns(num_cols)
        row_plots = plots[row_start:row_start + num_cols]
        for column, (title, render) in zip(row_columns, row_plots):
            with column:
                st.subheader(title)
                render(df)


tabs = st.tabs(["General Overview", "Potential Biases"])

# Q1: General Overview
with tabs[0]:
    st.subheader("General Feature Distribution and Outcome Metrics")

    # Plots shown two per row.
    q1_plots_2_col = [
        ("Gender Distribution", create_gender_pie_chart),
        ("Race Distribution", create_race_pie_chart),
        ("Insurance Type Distribution", create_insurance_pie_chart),
        ("Mortality Rate of ICU Patients", create_mortality_pie_chart),
    ]
    _render_plot_grid(q1_plots_2_col, filtered_df, num_cols=2)

    # Plots shown one per row (full width).
    q1_plots_1_col = [
        ("Admission Type Count", create_admission_type_bar_chart),
        ("Admissions Over Time", create_time_series_heatmap),
    ]
    _render_plot_grid(q1_plots_1_col, filtered_df, num_cols=1)

# Q2: Potential Biases
with tabs[1]:
    st.subheader("Analyzing Potential Biases Across Demographics")

    # Plots shown two per row.
    q2_plots = [
        ("Age Distribution of ICU Patients", create_age_distribution_by_gender),
        ("Boxen Plot of Age Distribution by Admission Type", create_age_distribution_by_admission_type),
        ("Mortality Rate by Race", create_mortality_by_race),
        ("Mortality Rate by Gender", create_mortality_by_gender),
        ("Mortality Rate by Age Group", create_mortality_by_age_group),
        ("Age Distribution by Race and Mortality", create_violin_age_race_mortality),
        ("Heatmap: Race & Gender vs. Mortality", create_heatmap_race_gender_mortality),
        ("Treemap of Race and Mortality", create_treemap_race_mortality),
    ]
    _render_plot_grid(q2_plots, filtered_df, num_cols=2)
# Footer
# A horizontal rule followed by the credit lines, joined into one Markdown
# string that starts and ends with a newline.
st.markdown("\n".join([
    "",
    "---",
    "**Data Source:** MIMIC-IV Dataset",
    "**Project:** Fairness in EHR Data",
    "**Developed with:** Streamlit, Python",
    "**Q3 Visuals:** https://idyllic-cucurucho-672fc1.netlify.app/",
    "",
]))