# app.py
import io
import json
import os
from typing import Dict, Any, Optional, Tuple, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st

# =========================
# Theme (your spec + paper knobs)
# =========================
plt.rcParams["font.family"] = "monospace"

# Base colors as RGB triples scaled to the 0..1 range.
PRIMARY = np.array([166, 0, 0]) / 255
CONTRARY = np.array([0, 166, 166]) / 255
NEUTRAL_MEDIUM_GREY = np.array([128, 128, 128]) / 255
NEUTRAL_DARK_GREY = np.array([64, 64, 64]) / 255

# Named figure sizes as (width, height) pairs handed to ``figsize``.
_FIGURE_PRESETS = {
    "single": (3.45, 2.60),
    "single_tall": (3.45, 3.20),
    "double": (7.10, 2.90),
    "double_tall": (7.10, 3.80),
    "square": (4.00, 4.00),
    "wide": (7.10, 2.40),
}


def _mix(c1, c2, t: float):
    """Blend two colors linearly: ``t=0`` yields ``c1`` and ``t=1`` yields ``c2``."""
    start = np.array(c1, dtype=float)
    end = np.array(c2, dtype=float)
    return (1 - t) * start + t * end


def palette():
    """Return the eight-color cycle: four base colors, then a white tint of each."""
    white = np.array([1.0, 1.0, 1.0])
    base_colors = [PRIMARY, CONTRARY, NEUTRAL_DARK_GREY, NEUTRAL_MEDIUM_GREY]
    # Amount of white blended into each base color, in the same order.
    tint_amounts = [0.35, 0.35, 0.45, 0.35]
    tints = [_mix(color, white, amount) for color, amount in zip(base_colors, tint_amounts)]
    return base_colors + tints


def set_paper_style(exaggerated: bool = True):
    """Update the global Matplotlib rcParams with the paper style.

    ``exaggerated`` selects the larger set of font sizes; all other settings
    (line widths, ticks, grid, dpi, savefig options) are the same either way.
    """
    if exaggerated:
        base, label, title, tick, legend = 18, 22, 24, 18, 18
    else:
        base, label, title, tick, legend = 12, 14, 16, 12, 12

    font_sizes = {
        "font.size": base,
        "axes.titlesize": title,
        "axes.labelsize": label,
        "xtick.labelsize": tick,
        "ytick.labelsize": tick,
        "legend.fontsize": legend,
    }
    fixed_settings = {
        "axes.linewidth": 1.6,
        "lines.linewidth": 2.8,
        "lines.markersize": 7.0,
        "grid.alpha": 0.25,
        "grid.linewidth": 1.0,
        "figure.dpi": 120,
        "savefig.dpi": 600,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.03,
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 6.0,
        "ytick.major.size": 6.0,
        "xtick.major.width": 1.4,
        "ytick.major.width": 1.4,
    }
    plt.rcParams.update({**font_sizes, **fixed_settings})


def clean_axes(ax):
    """Enable the major grid on both axes, hide the top/right spines, return ``ax``."""
    ax.grid(True, which="major", axis="both")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    return ax


def figure_size(preset: str) -> Tuple[float, float]:
    """Look up the (width, height) pair for ``preset``; unknown names raise ``KeyError``."""
    return _FIGURE_PRESETS[preset]


# =========================
# Loading: csv / json / npz / npy
# =========================
def load_to_df(uploaded_file) -> pd.DataFrame:
    """Convert an uploaded file into a DataFrame based on its extension.

    ``uploaded_file`` must expose ``.name`` and ``.getvalue()`` (bytes).

    Supported inputs:
      * ``.csv``  - parsed with ``pandas.read_csv``.
      * ``.json`` - a dict or a list, passed straight to ``pandas.DataFrame``.
      * ``.npz``  - every 1D array becomes a column (other arrays are ignored).
      * ``.npy``  - structured array (one column per field), 1D array
        (column ``y``) or 2D array (columns ``y0``, ``y1``, ...).

    Raises:
        ValueError: for an unsupported extension, JSON shape, or array shape,
            or when the 1D arrays of an ``.npz`` differ in length.
    """
    name = uploaded_file.name
    ext = os.path.splitext(name)[1].lower()
    data = uploaded_file.getvalue()

    if ext == ".csv":
        return pd.read_csv(io.BytesIO(data))

    if ext == ".json":
        obj = json.loads(data.decode("utf-8"))
        if isinstance(obj, (dict, list)):
            return pd.DataFrame(obj)
        raise ValueError("Unsupported JSON: use dict-of-lists or list-of-dicts.")

    if ext == ".npz":
        # SECURITY NOTE(review): allow_pickle=True unpickles object arrays, which
        # can execute arbitrary code from an uploaded file. Kept for backward
        # compatibility; switch to allow_pickle=False if uploads are untrusted.
        # The context manager closes the underlying zip archive once read.
        with np.load(io.BytesIO(data), allow_pickle=True) as z:
            arrays = {k: np.asarray(z[k]) for k in z.files}
        # Only 1D arrays can be treated as columns.
        columns = {k: v for k, v in arrays.items() if v.ndim == 1}
        if not columns:
            raise ValueError(".npz has no 1D arrays to treat as columns.")
        if len({len(v) for v in columns.values()}) > 1:
            raise ValueError(".npz 1D arrays have different lengths; cannot combine them into columns.")
        return pd.DataFrame(columns)

    if ext == ".npy":
        # SECURITY NOTE(review): same pickle caveat as for .npz above.
        arr = np.asarray(np.load(io.BytesIO(data), allow_pickle=True))
        if arr.dtype.names:
            return pd.DataFrame({n: arr[n] for n in arr.dtype.names})
        if arr.ndim == 1:
            return pd.DataFrame({"y": arr})
        if arr.ndim == 2:
            # columns: y0,y1,...
            return pd.DataFrame(arr, columns=[f"y{i}" for i in range(arr.shape[1])])
        raise ValueError("Unsupported .npy shape. Use 1D or 2D array or structured array.")

    raise ValueError(f"Unsupported file extension: {ext}")


# =========================
# Aggregation for error bars
# =========================
def aggregate_xy(x: np.ndarray, y: np.ndarray, mode: str):
    """Group ``y`` by exact ``x`` value and return ``(unique_x, mean_y, err_y)``.

    Rows with a missing x or y are dropped first. ``mode`` selects the error:
    ``"std"`` is the sample standard deviation (ddof=1), ``"sem"`` is that
    divided by sqrt(count), and anything else yields zeros. Groups with a
    single sample have an undefined deviation, which is reported as 0.
    """
    df = pd.DataFrame({"x": x, "y": y}).dropna()
    g = df.groupby("x")["y"]
    mean = g.mean()
    if mode == "std":
        err = g.std(ddof=1).fillna(0.0)
    elif mode == "sem":
        err = (g.std(ddof=1) / np.sqrt(g.count())).fillna(0.0)
    else:
        err = pd.Series(0.0, index=mean.index)
    xu = mean.index.to_numpy()
    return xu, mean.to_numpy(), err.to_numpy()


# =========================
# Plotting
# =========================
def _collect_series(df: pd.DataFrame, kind: str, xcol, ycols, hue, colors) -> List[Tuple[str, Any, np.ndarray, Any]]:
    """Return one ``(label, x, y, color)`` tuple per series to draw, cycling colors."""
    n_colors = len(colors)

    if kind == "hist":
        # Histograms have no x column; hue is ignored for them.
        return [(str(yc), None, df[yc].to_numpy(), colors[i % n_colors]) for i, yc in enumerate(ycols)]

    if hue and hue in df.columns:
        # Stringify the hue column once instead of once per group.
        hue_values = df[hue].astype(str)
        series: List[Tuple[str, Any, np.ndarray, Any]] = []
        for g in hue_values.unique().tolist():
            sub = df[hue_values == g]
            gx = sub[xcol].to_numpy()
            for yc in ycols:
                color = colors[len(series) % n_colors]
                series.append((f"{yc} | {hue}={g}", gx, sub[yc].to_numpy(), color))
        return series

    x = df[xcol].to_numpy()
    return [(str(yc), x, df[yc].to_numpy(), colors[i % n_colors]) for i, yc in enumerate(ycols)]


def _draw_grouped_bars(ax, bar_series) -> None:
    """Draw per-category mean bars for all series side by side on one category axis.

    ``bar_series`` is a non-empty list of ``(label, means, color)`` where
    ``means`` is a Series indexed by category. Categories are the union over
    all series so tick labels stay correct when series cover different
    categories; a series simply has no bar where it has no data.
    """
    categories = bar_series[0][1].index
    for _, means, _ in bar_series[1:]:
        categories = categories.union(means.index)

    pos = np.arange(len(categories))
    n_series = len(bar_series)
    # 0.8 is Matplotlib's default bar width, so a single series looks unchanged.
    width = 0.8 / n_series
    for i, (label, means, color) in enumerate(bar_series):
        heights = means.reindex(categories).to_numpy(dtype=float)
        present = np.isfinite(heights)
        offset = (i - (n_series - 1) / 2) * width
        ax.bar(pos[present] + offset, heights[present], width=width, label=label, color=color)
    ax.set_xticks(pos, [str(c) for c in categories])


def make_plot(
    df: pd.DataFrame,
    kind: str,
    xcol: Optional[str],
    ycols: List[str],
    hue: Optional[str],
    agg: str,
    fill_band: bool,
    title: str,
    xlabel: str,
    ylabel: str,
    logx: bool,
    logy: bool,
    legend_mode: str,
    size_preset: str,
    hist_bins: int,
    hist_density: bool,
    exaggerated_text: bool,
):
    """Build and return a Matplotlib figure for the selected columns.

    Args:
        df: source data.
        kind: ``"line"``, ``"scatter"``, ``"bar"`` or ``"hist"``.
        xcol: x column; required for every kind except ``"hist"``.
        ycols: y columns, one series each (times the hue groups, if any).
        hue: optional column whose stringified values split the data into groups.
        agg: ``"std"``/``"sem"`` aggregates repeated x for line plots; else raw data.
        fill_band: shade mean +/- error for aggregated line plots.
        title, xlabel, ylabel: text; blank labels fall back to column names
            (for histograms: column names on x, "density"/"count" on y).
        logx, logy: log scaling (``logx`` is ignored for histograms).
        legend_mode: ``"none"``, ``"outside"``, or anything else for "best".
        size_preset: key understood by ``figure_size``.
        hist_bins, hist_density: passed to ``Axes.hist``.
        exaggerated_text: forwarded to ``set_paper_style``.

    Raises:
        ValueError: if ``xcol`` is missing for a non-histogram plot.

    The caller owns the returned figure and should ``plt.close`` it when done.
    """
    # Validate before allocating a figure (an assert would vanish under -O).
    if kind != "hist" and xcol is None:
        raise ValueError(f"An x column is required for kind={kind!r}.")

    set_paper_style(exaggerated=exaggerated_text)
    w, h = figure_size(size_preset)
    fig, ax = plt.subplots(figsize=(w, h), constrained_layout=True)

    try:
        bar_series = []
        for label, x, y, color in _collect_series(df, kind, xcol, ycols, hue, palette()):
            if kind == "line":
                if agg in ("std", "sem"):
                    xu, ym, ye = aggregate_xy(x, y, agg)
                    ax.plot(xu, ym, marker="o", label=label, color=color)
                    if fill_band and np.any(ye > 0):
                        ax.fill_between(xu, ym - ye, ym + ye, alpha=0.18, color=color, linewidth=0)
                else:
                    ax.plot(x, y, marker="o", label=label, color=color)
            elif kind == "scatter":
                ax.scatter(x, y, label=label, color=color, s=52, alpha=0.85, edgecolors="none")
            elif kind == "bar":
                # category bars: mean per category; drawn together after the loop
                # so that several series do not hide each other.
                tmp = pd.DataFrame({"x": x, "y": y}).dropna()
                bar_series.append((label, tmp.groupby("x")["y"].mean(), color))
            elif kind == "hist":
                values = np.asarray(y, dtype=float)
                # NaN/inf make the histogram range undefined and raise in numpy.
                values = values[np.isfinite(values)]
                if values.size:
                    ax.hist(values, bins=hist_bins, density=hist_density, alpha=0.35, label=label, color=color)

        if bar_series:
            _draw_grouped_bars(ax, bar_series)

        clean_axes(ax)

        if title.strip():
            ax.set_title(title)

        # Column names may be non-strings (e.g. integer columns from JSON lists).
        y_names = ", ".join(str(c) for c in ycols) if ycols else ""
        if kind == "hist":
            # A histogram shows the variable on x and its frequency on y.
            default_xlabel = y_names
            default_ylabel = "density" if hist_density else "count"
        else:
            default_xlabel = str(xcol)
            default_ylabel = y_names
        ax.set_xlabel(xlabel if xlabel.strip() else default_xlabel)
        ax.set_ylabel(ylabel if ylabel.strip() else default_ylabel)

        if logx and kind != "hist":
            ax.set_xscale("log")
        if logy:
            ax.set_yscale("log")

        if legend_mode != "none":
            handles, _ = ax.get_legend_handles_labels()
            # Skip the legend when nothing was drawn (avoids a Matplotlib warning).
            if handles:
                if legend_mode == "outside":
                    ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
                else:
                    ax.legend(loc="best", frameon=False)
    except Exception:
        # Do not leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise

    return fig


def fig_to_bytes(fig, fmt: str) -> bytes:
    """Render ``fig`` with ``savefig`` in format ``fmt`` and return the raw bytes."""
    buf = io.BytesIO()
    fig.savefig(buf, format=fmt)
    return buf.getvalue()


# =========================
# Streamlit UI
# =========================
st.set_page_config(page_title="PaperPlot (Matplotlib)", layout="wide")
st.title("PaperPlot: upload data → tweak params → live preview → export")

left, right = st.columns([1, 2])

with left:
    uploaded = st.file_uploader("Upload data", type=["csv", "json", "npz", "npy"])
    st.caption("Supported: .csv / .json / .npz / .npy")

    kind = st.selectbox("Plot kind", ["line", "scatter", "bar", "hist"], index=0)
    exaggerated_text = st.toggle("Exaggerate text (paper readability)", value=True)
    size_preset = st.selectbox(
        "Figure size preset",
        ["single", "single_tall", "double", "double_tall", "square", "wide"],
        index=0
    )

    title = st.text_input("Title", value="")
    xlabel = st.text_input("X label (optional)", value="")
    ylabel = st.text_input("Y label (optional)", value="")

    logx = st.toggle("Log X", value=False)
    logy = st.toggle("Log Y", value=False)

    legend_mode = st.selectbox("Legend", ["best", "outside", "none"], index=0)

    agg = st.selectbox("Aggregate repeated x (line only)", ["none", "std", "sem"], index=0)
    fill_band = st.toggle("Show error band (line + agg)", value=True)

    hist_bins = st.slider("Hist bins", 5, 200, 30)
    hist_density = st.toggle("Hist density", value=True)

with right:
    if not uploaded:
        st.info("Upload a dataset to start.")
        st.stop()

    # Top-level UI boundary: report any loader failure to the user.
    try:
        df = load_to_df(uploaded)
    except Exception as e:
        st.error(f"Failed to load file: {e}")
        st.stop()

    st.subheader("Data preview")
    st.dataframe(df.head(50), use_container_width=True)

    cols = df.columns.tolist()
    numeric_cols = [c for c in cols if pd.api.types.is_numeric_dtype(df[c])]
    value_options = numeric_cols if numeric_cols else cols

    if kind == "hist":
        xcol = None
    else:
        # Bars are grouped by category, so non-numeric columns must be
        # selectable as x; other kinds keep preferring numeric columns.
        x_options = cols if kind == "bar" else value_options
        xcol = st.selectbox("X column", options=x_options)

    default_y = value_options[:1]
    ycols = st.multiselect("Y column(s)", options=value_options, default=default_y)

    hue = None
    if kind != "hist":
        hue = st.selectbox("Group / hue (optional)", options=["(none)"] + cols, index=0)
        hue = None if hue == "(none)" else hue

    if not ycols:
        st.warning("Pick at least one Y column.")
        st.stop()

    # Top-level UI boundary: show plotting problems instead of a raw traceback.
    try:
        fig = make_plot(
            df=df,
            kind=kind,
            xcol=xcol,
            ycols=ycols,
            hue=hue,
            agg=agg if kind == "line" else "none",
            fill_band=fill_band,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            logx=logx,
            logy=logy,
            legend_mode=legend_mode,
            size_preset=size_preset,
            hist_bins=hist_bins,
            hist_density=hist_density,
            exaggerated_text=exaggerated_text,
        )
    except Exception as e:
        st.error(f"Failed to build plot: {e}")
        st.stop()

    st.subheader("Live preview")
    st.pyplot(fig, use_container_width=True)

    c1, c2 = st.columns(2)
    with c1:
        st.download_button(
            "Download PDF",
            data=fig_to_bytes(fig, "pdf"),
            file_name="figure.pdf",
            mime="application/pdf",
        )
    with c2:
        st.download_button(
            "Download PNG",
            data=fig_to_bytes(fig, "png"),
            file_name="figure.png",
            mime="image/png",
        )

    # The script reruns on every widget change; without closing, pyplot keeps
    # every figure alive. The export bytes above are already rendered.
    plt.close(fig)