| | """Visualization utilities leveraging the Strategy Pattern for the BI dashboard.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from abc import ABC, abstractmethod |
| | from io import BytesIO |
| | from typing import Any, Dict, Iterable, Optional |
| |
|
| | import matplotlib |
| | import matplotlib.pyplot as plt |
| | from matplotlib.figure import Figure |
| | import pandas as pd |
| | import numpy as np |
| |
|
| | |
# Select Matplotlib's non-interactive Agg backend so charts can be rendered
# without a display (e.g. inside a server process).
matplotlib.use("Agg")
| |
|
# Whitelist of aggregation names accepted by the strategies below; each key
# maps to the pandas ``agg`` function name of the same spelling.
AGGREGATIONS = {name: name for name in ("sum", "mean", "median", "count")}
| |
|
| |
|
class VisualizationStrategy(ABC):
    """Common interface and shared helpers for every chart strategy.

    Subclasses implement :meth:`generate`, turning a DataFrame plus
    chart-specific keyword arguments into a Matplotlib ``Figure``.
    """

    @abstractmethod
    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Build and return a Matplotlib figure for ``df`` using ``kwargs``."""

    def validate_columns(self, df: pd.DataFrame, columns: Iterable[str]) -> None:
        """Raise ``ValueError`` listing every entry of ``columns`` absent from ``df``."""
        available = df.columns
        absent = [name for name in columns if name not in available]
        if not absent:
            return
        raise ValueError(f"Column(s) not found in dataset: {', '.join(absent)}")

    def _create_figure(self) -> Figure:
        """Return a fresh 10x6-inch figure that uses the tight layout engine."""
        figure = Figure(figsize=(10, 6))
        figure.set_layout_engine("tight")
        return figure
| |
|
| |
|
class TimeSeriesStrategy(VisualizationStrategy):
    """Strategy for generating time-series plots.

    Keyword arguments read by :meth:`generate`:
        date_column: Column parsed with ``pd.to_datetime``; rows whose value
            does not parse are dropped.
        value_column: Column aggregated per calendar day.
        aggregation: Key of ``AGGREGATIONS`` (default ``"sum"``).
    """

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Plot ``value_column`` aggregated per calendar day of ``date_column``.

        For every aggregation except ``"count"`` the value column is coerced to
        numbers; entries that cannot be parsed are treated as missing.

        Raises:
            ValueError: If a required column is not given or absent from ``df``,
                the aggregation is unsupported, no date parses, or (for
                non-count aggregations) the value column holds no numeric data.
        """
        date_column = kwargs.get("date_column")
        value_column = kwargs.get("value_column")
        aggregation = kwargs.get("aggregation", "sum")

        if not date_column or not value_column:
            raise ValueError("Date and value columns are required for Time Series.")

        self.validate_columns(df, [date_column, value_column])

        if aggregation not in AGGREGATIONS:
            raise ValueError("Unsupported aggregation method.")

        # Parse once: unparseable entries become NaT and their rows are dropped.
        dates = pd.to_datetime(df[date_column], errors="coerce")
        has_date = dates.notna()
        if not has_date.any():
            raise ValueError("Date column does not contain any parseable dates.")

        subset = df.loc[has_date, [date_column, value_column]].copy()
        # Reuse the already-parsed values (a second parse without
        # errors="coerce" is redundant and can disagree with the first).
        subset[date_column] = dates[has_date].dt.date
        if aggregation != "count":
            # sum/mean/median need numbers: numeric strings are parsed and
            # anything else becomes NaN, which the aggregation skips.
            subset[value_column] = pd.to_numeric(subset[value_column], errors="coerce")
            if subset[value_column].isna().all():
                raise ValueError("Selected value column does not contain numeric data.")

        # groupby sorts its keys by default, so days come out chronologically.
        grouped = (
            subset.groupby(date_column)[value_column]
            .agg(AGGREGATIONS[aggregation])
            .reset_index()
        )

        fig = self._create_figure()
        ax = fig.add_subplot(111)

        ax.plot(grouped[date_column], grouped[value_column], marker='o', linestyle='-')
        ax.set_title(f"{value_column} over time ({aggregation})")
        ax.set_xlabel(date_column)
        ax.set_ylabel(value_column)
        ax.grid(True, linestyle='--', alpha=0.7)

        # Rotate/right-align date tick labels so they do not overlap.
        fig.autofmt_xdate()

        return fig
| |
|
| |
|
class DistributionStrategy(VisualizationStrategy):
    """Strategy for generating distribution plots (histogram/box).

    Keyword arguments read by :meth:`generate`:
        column: Column whose values are plotted (coerced to numbers).
        plot_type: ``"box"`` (case-insensitive) for a box plot; anything else
            draws a histogram (default ``"histogram"``).
        bins: Histogram ``bins`` argument (default ``30``).
    """

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Plot the distribution of ``column`` as a histogram or a box plot.

        Non-numeric, missing and infinite values are ignored.

        Raises:
            ValueError: If ``column`` is not given, is absent from ``df``, or
                has no finite numeric values.
        """
        column = kwargs.get("column")
        # ``or`` maps an explicit None to the default; lower() matches the
        # case-insensitive chart_type handling of CategoryStrategy.
        plot_type = str(kwargs.get("plot_type") or "histogram").lower()
        bins = kwargs.get("bins", 30)

        if not column:
            raise ValueError("Numeric column is required for Distribution plot.")

        self.validate_columns(df, [column])

        # Cast to float so bools and nullable integer dtypes become plain
        # floats (missing -> NaN) that isfinite/hist handle uniformly.
        numeric_series = pd.to_numeric(df[column], errors="coerce").astype(float)
        # Drop NaN and +/-inf: histogram range autodetection fails on infinities.
        numeric_series = numeric_series[np.isfinite(numeric_series)]
        if numeric_series.empty:
            raise ValueError("Selected column does not contain numeric data.")

        fig = self._create_figure()
        ax = fig.add_subplot(111)
        ax.set_title(f"Distribution of {column}")

        if plot_type == "box":
            # Vertical is matplotlib's default orientation; ``vert`` itself is
            # pending deprecation (3.10+) in favour of ``orientation``.
            ax.boxplot(numeric_series, patch_artist=True)
            ax.set_ylabel(column)
            ax.set_xticks([])
        else:
            ax.hist(numeric_series, bins=bins, edgecolor='black', alpha=0.7)
            ax.set_xlabel(column)
            ax.set_ylabel("Frequency")
            ax.grid(axis='y', linestyle='--', alpha=0.7)

        return fig
| |
|
| |
|
class CategoryStrategy(VisualizationStrategy):
    """Strategy for generating categorical charts (bar/pie).

    Keyword arguments read by :meth:`generate`:
        category_column: Column whose distinct values become bars/wedges.
        value_column: Column aggregated within each category.
        aggregation: Key of ``AGGREGATIONS`` (default ``"sum"``).
        chart_type: ``"pie"`` (case-insensitive) for a pie chart; anything else
            draws a bar chart (default ``"bar"``).
    """

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Aggregate ``value_column`` per category and draw a bar or pie chart.

        Categories are ordered by aggregated value, largest first. For every
        aggregation except ``"count"`` the value column is coerced to numbers.

        Raises:
            ValueError: If a required column is not given or absent from ``df``,
                the aggregation is unsupported, the value column has no numeric
                data (non-count aggregations), nothing is left to plot, or a pie
                chart would need negative or zero-total wedge sizes.
        """
        category_column = kwargs.get("category_column")
        value_column = kwargs.get("value_column")
        aggregation = kwargs.get("aggregation", "sum")
        # ``or`` maps an explicit None to the default instead of failing on .lower().
        chart_type = str(kwargs.get("chart_type") or "bar").lower()

        if not category_column or not value_column:
            raise ValueError("Category and value columns are required for Category plot.")

        self.validate_columns(df, [category_column, value_column])
        if aggregation not in AGGREGATIONS:
            raise ValueError("Unsupported aggregation method.")

        values = df[value_column]
        if aggregation != "count":
            # sum/mean/median need numbers: numeric strings are parsed and
            # anything else becomes NaN, which the aggregation skips.
            values = pd.to_numeric(values, errors="coerce")
            if values.isna().all():
                raise ValueError("Selected value column does not contain numeric data.")

        grouped = (
            # observed=True: only categories that actually occur (relevant for
            # categorical dtypes; silences pandas' observed-default FutureWarning).
            values.groupby(df[category_column], observed=True)
            .agg(AGGREGATIONS[aggregation])
            .dropna()  # e.g. the mean of a category whose values are all NaN
            .sort_values(ascending=False)
        )
        if grouped.empty:
            raise ValueError("No data available for the selected category column.")

        # Stringify labels so numeric category codes get an evenly spaced
        # categorical axis in the sorted order, not a numeric x-axis.
        labels = [str(label) for label in grouped.index]
        heights = grouped.to_numpy(dtype=float)

        if chart_type == "pie" and ((heights < 0).any() or heights.sum() <= 0):
            raise ValueError("Pie charts require non-negative values with a positive total.")

        fig = self._create_figure()
        ax = fig.add_subplot(111)
        ax.set_title(f"{value_column} by {category_column}")

        if chart_type == "pie":
            ax.pie(heights, labels=labels, autopct='%1.1f%%', startangle=90)
        else:
            ax.bar(labels, heights, alpha=0.7, edgecolor='black')
            ax.set_xlabel(category_column)
            ax.set_ylabel(f"{aggregation} of {value_column}")
            ax.grid(axis='y', linestyle='--', alpha=0.7)
            # Slant labels once there are enough bars for them to collide.
            if len(labels) > 5:
                plt.setp(ax.get_xticklabels(), rotation=45, ha="right")

        return fig
| |
|
| |
|
class ScatterStrategy(VisualizationStrategy):
    """Strategy for generating scatter plots.

    Keyword arguments read by :meth:`generate`:
        x_column, y_column: Columns coerced to numbers for the two axes.
        color_column: Optional column used to colour points - a colour bar
            for numeric dtypes, a legend of categories otherwise.
    """

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Scatter ``y_column`` against ``x_column``, optionally coloured.

        Rows where either coordinate is not numeric are skipped. For a
        non-numeric colour column, missing values form their own "missing"
        legend group instead of being dropped.

        Raises:
            ValueError: If X/Y are not given, a column is absent from ``df``,
                or no row has numeric values in both X and Y.
        """
        x_column = kwargs.get("x_column")
        y_column = kwargs.get("y_column")
        color_column = kwargs.get("color_column")

        if not x_column or not y_column:
            raise ValueError("X and Y columns are required for Scatter plot.")

        columns = [x_column, y_column]
        if color_column:
            columns.append(color_column)
        self.validate_columns(df, columns)

        x = pd.to_numeric(df[x_column], errors="coerce")
        y = pd.to_numeric(df[y_column], errors="coerce")

        valid_mask = x.notna() & y.notna()
        if not valid_mask.any():
            raise ValueError("Scatter plot requires numeric data in both X and Y columns.")

        # Copy only the referenced columns (deduplicated, e.g. when the colour
        # column is also an axis) instead of the whole frame.
        plot_df = df.loc[valid_mask, list(dict.fromkeys(columns))].copy()
        plot_df[x_column] = x[valid_mask]
        plot_df[y_column] = y[valid_mask]

        fig = self._create_figure()
        ax = fig.add_subplot(111)

        if color_column:
            c_data = plot_df[color_column]
            if pd.api.types.is_numeric_dtype(c_data):
                sc = ax.scatter(plot_df[x_column], plot_df[y_column], c=c_data, cmap='viridis', alpha=0.7)
                fig.colorbar(sc, ax=ax, label=color_column)
            else:
                # One series per category. ``c_data == nan`` is never True, so
                # missing values need their own isna() mask to be drawn at all.
                groups = [(str(cat), c_data == cat) for cat in c_data.dropna().unique()]
                missing = c_data.isna()
                if missing.any():
                    groups.append(("missing", missing))
                # Consecutive entries of a qualitative palette; switch to tab20
                # once tab10 runs out (colours only repeat past 20 groups).
                palette = plt.cm.tab10 if len(groups) <= 10 else plt.cm.tab20
                for index, (label, mask) in enumerate(groups):
                    ax.scatter(
                        plot_df.loc[mask, x_column],
                        plot_df.loc[mask, y_column],
                        label=label,
                        color=palette(index % palette.N),
                        alpha=0.7,
                    )
                ax.legend(title=color_column)
        else:
            ax.scatter(plot_df[x_column], plot_df[y_column], alpha=0.7)

        ax.set_title(f"{y_column} vs {x_column}")
        ax.set_xlabel(x_column)
        ax.set_ylabel(y_column)
        ax.grid(True, linestyle='--', alpha=0.7)

        return fig
| |
|
| |
|
class CorrelationHeatmapStrategy(VisualizationStrategy):
    """Strategy for generating correlation heatmaps.

    :meth:`generate` ignores its keyword arguments and correlates every
    numeric column of the DataFrame (``DataFrame.corr`` defaults).
    """

    def generate(self, df: pd.DataFrame, **kwargs: Any) -> Figure:
        """Draw an annotated heatmap of the numeric columns' correlation matrix.

        Rows and columns that are entirely missing are dropped first.

        Raises:
            ValueError: If fewer than two numeric columns (with any data) are
                available, or no row holds any numeric value.
        """
        numeric_df = df.select_dtypes(include=["number"])
        if numeric_df.shape[1] < 2:
            raise ValueError("At least two numeric columns are required for a correlation heatmap.")

        numeric_df = numeric_df.dropna(how="all")
        if numeric_df.empty:
            raise ValueError("No valid numeric data available for correlation heatmap.")

        # A column with no values at all would only add an all-NaN row/column.
        numeric_df = numeric_df.dropna(axis=1, how="all")
        if numeric_df.shape[1] < 2:
            raise ValueError("At least two numeric columns are required for a correlation heatmap.")

        corr = numeric_df.corr()
        labels = list(corr.columns)
        positions = range(len(labels))

        fig = self._create_figure()
        ax = fig.add_subplot(111)

        # Fixed [-1, 1] range so colours are comparable across datasets.
        image = ax.imshow(corr, cmap='RdBu', vmin=-1, vmax=1)
        fig.colorbar(image, ax=ax)

        ax.set_xticks(positions)
        ax.set_yticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_yticklabels(labels)

        # Annotate each cell; white text on the darkest (strongest) cells keeps
        # it readable. NaN compares False, so undefined cells stay black.
        for (row, col), value in np.ndenumerate(corr.to_numpy()):
            text_color = "white" if abs(value) > 0.6 else "black"
            ax.text(col, row, f"{value:.2f}", ha="center", va="center", color=text_color)

        ax.set_title("Correlation Heatmap")

        return fig
| |
|
| |
|
def figure_to_png_bytes(fig: Figure) -> BytesIO:
    """Render ``fig`` as PNG into a new in-memory buffer.

    The returned buffer is rewound to position 0, ready to be read or streamed.
    """
    png_buffer = BytesIO()
    fig.savefig(png_buffer, format="png")
    png_buffer.seek(0)  # rewind so callers read from the first byte
    return png_buffer
| |
|
| |
|
def create_time_series_plot(df: pd.DataFrame, date_column: str, value_column: str, aggregation: str = "sum") -> Figure:
    """Build a time-series figure via a fresh ``TimeSeriesStrategy``.

    Errors raised by the strategy (e.g. ``ValueError`` for unknown columns or
    aggregations) propagate unchanged.
    """
    options = {
        "date_column": date_column,
        "value_column": value_column,
        "aggregation": aggregation,
    }
    return TimeSeriesStrategy().generate(df, **options)
| |
|
| |
|
def create_distribution_plot(df: pd.DataFrame, column: str, plot_type: str = "histogram") -> Figure:
    """Build a distribution figure (histogram or box) via ``DistributionStrategy``.

    Errors raised by the strategy propagate unchanged.
    """
    options = {"column": column, "plot_type": plot_type}
    return DistributionStrategy().generate(df, **options)
| |
|
| |
|
def create_category_plot(
    df: pd.DataFrame, category_column: str, value_column: str, aggregation: str = "sum", chart_type: str = "bar"
) -> Figure:
    """Build a bar or pie chart of aggregated values via ``CategoryStrategy``.

    Errors raised by the strategy propagate unchanged.
    """
    options = dict(
        category_column=category_column,
        value_column=value_column,
        aggregation=aggregation,
        chart_type=chart_type,
    )
    return CategoryStrategy().generate(df, **options)
| |
|
| |
|
def create_scatter_plot(
    df: pd.DataFrame, x_column: str, y_column: str, color_column: Optional[str] = None
) -> Figure:
    """Build a scatter plot (optionally coloured by ``color_column``) via ``ScatterStrategy``.

    Errors raised by the strategy propagate unchanged.
    """
    options = {"x_column": x_column, "y_column": y_column, "color_column": color_column}
    return ScatterStrategy().generate(df, **options)
| |
|
| |
|
def create_correlation_heatmap(df: pd.DataFrame) -> Figure:
    """Build a heatmap of the numeric columns' correlations via ``CorrelationHeatmapStrategy``.

    Errors raised by the strategy propagate unchanged.
    """
    heatmap = CorrelationHeatmapStrategy()
    figure = heatmap.generate(df)
    return figure
| |
|