paperplot / src /streamlit_app.py
zhangify's picture
Update src/streamlit_app.py
3de369a verified
# streamlit_app.py
import io
import json
import os
from typing import Dict, Any, Optional, Tuple, List
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import streamlit as st
# =========================
# Theme (colour palette + paper-figure style knobs)
# =========================
# Use a monospace font family for all text Matplotlib renders.
plt.rcParams.update({"font.family": "monospace"})

# Theme colours as float RGB arrays: 8-bit channel values divided by 255.
PRIMARY = np.array((166, 0, 0)) / 255
CONTRARY = np.array((0, 166, 166)) / 255
NEUTRAL_MEDIUM_GREY = np.full(3, 128) / 255
NEUTRAL_DARK_GREY = np.full(3, 64) / 255
def _mix(c1, c2, t: float):
    """Linearly interpolate between two colours.

    Both inputs are converted to float arrays first, so lists and tuples are
    accepted. The result is ``(1 - t) * c1 + t * c2``.
    """
    start, end = (np.array(colour, dtype=float) for colour in (c1, c2))
    return (1 - t) * start + t * end
def palette():
    """Return the eight-colour plotting cycle as a list of RGB arrays.

    The first four entries are the base theme colours; the last four are the
    same colours, in the same order, blended towards white.
    """
    white = np.ones(3)
    base = [PRIMARY, CONTRARY, NEUTRAL_DARK_GREY, NEUTRAL_MEDIUM_GREY]
    # Blend fraction towards white, one per base colour (same order as above).
    tint_amounts = [0.35, 0.35, 0.45, 0.35]
    tints = [_mix(colour, white, amount) for colour, amount in zip(base, tint_amounts)]
    return base + tints
def set_paper_style(exaggerated: bool = True):
    """Write the paper-figure settings into Matplotlib's global ``rcParams``.

    Args:
        exaggerated: When true, use the larger set of font sizes; otherwise
            use the smaller set. All other settings are the same either way.
    """
    if exaggerated:
        font_sizes = {
            "font.size": 18,
            "axes.titlesize": 24,
            "axes.labelsize": 22,
            "xtick.labelsize": 18,
            "ytick.labelsize": 18,
            "legend.fontsize": 18,
        }
    else:
        font_sizes = {
            "font.size": 12,
            "axes.titlesize": 16,
            "axes.labelsize": 14,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 12,
        }

    # Settings that do not depend on the text-size mode.
    fixed = {
        # Stroke weights and marker size.
        "axes.linewidth": 1.6,
        "lines.linewidth": 2.8,
        "lines.markersize": 7.0,
        # Grid appearance.
        "grid.alpha": 0.25,
        "grid.linewidth": 1.0,
        # On-screen and export resolution / cropping.
        "figure.dpi": 120,
        "savefig.dpi": 600,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.03,
        # Tick geometry.
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 6.0,
        "ytick.major.size": 6.0,
        "xtick.major.width": 1.4,
        "ytick.major.width": 1.4,
    }
    plt.rcParams.update({**font_sizes, **fixed})
def clean_axes(ax):
    """Turn on the major grid, hide the top and right spines, return ``ax``."""
    ax.grid(True, which="major", axis="both")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    return ax
def figure_size(preset: str) -> Tuple[float, float]:
    """Look up the ``(width, height)`` pair for a named figure-size preset.

    Raises:
        KeyError: If ``preset`` is not one of the known preset names.
    """
    # The two widths shared by several presets.
    narrow, full = 3.45, 7.10
    sizes = dict(
        single=(narrow, 2.60),
        single_tall=(narrow, 3.20),
        double=(full, 2.90),
        double_tall=(full, 3.80),
        square=(4.00, 4.00),
        wide=(full, 2.40),
    )
    return sizes[preset]
# =========================
# Loading: csv / json / npz / npy
# =========================
def load_to_df(uploaded_file) -> pd.DataFrame:
    """Parse an uploaded data file into a DataFrame, dispatching on extension.

    Supported formats (extension is matched case-insensitively):

    * ``.csv``  - parsed with ``pandas.read_csv``.
    * ``.json`` - a top-level dict or list, handed to ``pandas.DataFrame``.
    * ``.npz``  - every 1-D array becomes a column; other arrays are skipped.
    * ``.npy``  - a structured array (one column per field), a 1-D array
      (single column ``y``) or a 2-D array (columns ``y0``, ``y1``, ...).

    Args:
        uploaded_file: Object exposing ``.name`` (file name) and
            ``.getvalue()`` (the raw bytes of the file).

    Returns:
        The parsed data as a ``pandas.DataFrame``.

    Raises:
        ValueError: Unsupported extension, unsupported JSON top-level type,
            an ``.npz`` without any 1-D array, or an unsupported ``.npy``
            shape. Parser errors from pandas/json/numpy propagate unchanged.
    """
    ext = os.path.splitext(uploaded_file.name)[1].lower()
    data = uploaded_file.getvalue()

    if ext == ".csv":
        return pd.read_csv(io.BytesIO(data))

    if ext == ".json":
        # Pass the raw bytes: json.loads detects UTF-8/16/32 itself and
        # accepts a UTF-8 byte-order mark, which a plain .decode("utf-8")
        # followed by json.loads rejects.
        obj = json.loads(data)
        if isinstance(obj, (dict, list)):
            return pd.DataFrame(obj)
        raise ValueError("Unsupported JSON: use dict-of-lists or list-of-dicts.")

    if ext == ".npz":
        # SECURITY NOTE(review): allow_pickle=True unpickles object arrays
        # from an uploaded (untrusted) file, which can execute arbitrary
        # code. Kept so object-dtype archives keep loading; switch to
        # allow_pickle=False if this app is exposed to untrusted users.
        with np.load(io.BytesIO(data), allow_pickle=True) as archive:
            # Members are read lazily, so materialise them before the
            # archive is closed.
            columns: Dict[str, Any] = {}
            for key in archive.files:
                values = np.asarray(archive[key])
                if values.ndim == 1:
                    columns[key] = values
        if not columns:
            raise ValueError(".npz has no 1D arrays to treat as columns.")
        return pd.DataFrame(columns)

    if ext == ".npy":
        # SECURITY NOTE(review): same pickle caveat as the .npz branch.
        arr = np.asarray(np.load(io.BytesIO(data), allow_pickle=True))
        if arr.dtype.names:
            return pd.DataFrame({n: arr[n] for n in arr.dtype.names})
        if arr.ndim == 1:
            return pd.DataFrame({"y": arr})
        if arr.ndim == 2:
            # columns: y0, y1, ...
            return pd.DataFrame(arr, columns=[f"y{i}" for i in range(arr.shape[1])])
        raise ValueError("Unsupported .npy shape. Use 1D or 2D array or structured array.")

    raise ValueError(f"Unsupported file extension: {ext}")
# =========================
# Aggregation for error bars
# =========================
def aggregate_xy(x: np.ndarray, y: np.ndarray, mode: str):
    """Collapse repeated x values into a mean and an error estimate.

    Rows where either ``x`` or ``y`` is missing are dropped, then ``y`` is
    grouped by exact ``x`` value.

    Args:
        x: Sample positions.
        y: Sample values, same length as ``x``.
        mode: ``"std"`` for the sample standard deviation (``ddof=1``),
            ``"sem"`` for that deviation divided by ``sqrt(count)``; any
            other value yields an all-zero error.

    Returns:
        ``(unique_x, mean_y, err)`` as NumPy arrays. Undefined errors (for
        example groups with a single sample) are reported as ``0.0``.
    """
    frame = pd.DataFrame({"x": x, "y": y}).dropna()
    grouped = frame.groupby("x")["y"]
    centre = grouped.mean()

    if mode in ("std", "sem"):
        spread = grouped.std(ddof=1)
        if mode == "sem":
            spread = spread / np.sqrt(grouped.count())
        spread = spread.fillna(0.0)
    else:
        spread = pd.Series(0.0, index=centre.index)

    return centre.index.to_numpy(), centre.to_numpy(), spread.to_numpy()
# =========================
# Plotting
# =========================
def _draw_grouped_bars(ax, series):
    """Draw per-category mean bars, placing the series side by side.

    ``series`` is a list of ``(label, x, y, color)`` tuples. All series share
    one category axis (the union of their x values) so that a tick label
    always names the bars drawn at that tick.
    """
    means = [
        pd.DataFrame({"x": x, "y": y}).dropna().groupby("x")["y"].mean()
        for _, x, y, _ in series
    ]
    if not means:
        return

    categories = means[0].index
    for m in means[1:]:
        categories = categories.union(m.index)

    slots = np.arange(len(categories))
    n = len(series)
    # 0.8 is Matplotlib's default bar width; split it between the series so
    # a single series looks exactly as a plain bar chart would.
    width = 0.8 / n
    for i, ((label, _, _, color), m) in enumerate(zip(series, means)):
        offset = (i - (n - 1) / 2) * width
        # Only draw bars for categories this series actually has.
        where = categories.get_indexer(m.index)
        ax.bar(slots[where] + offset, m.to_numpy(), width=width, label=label, color=color)
    ax.set_xticks(slots, categories.tolist())


def make_plot(
    df: pd.DataFrame,
    kind: str,
    xcol: Optional[str],
    ycols: List[str],
    hue: Optional[str],
    agg: str,
    fill_band: bool,
    title: str,
    xlabel: str,
    ylabel: str,
    logx: bool,
    logy: bool,
    legend_mode: str,
    size_preset: str,
    hist_bins: int,
    hist_density: bool,
    exaggerated_text: bool,
):
    """Build a Matplotlib figure from ``df`` according to the given options.

    One series is drawn per y column, or per (hue group, y column) pair when
    ``hue`` names a column of ``df``. Colours cycle through ``palette()``.

    Args:
        df: Source data.
        kind: ``"line"``, ``"scatter"``, ``"bar"`` or ``"hist"``. Any other
            value draws no data.
        xcol: Column used for x. Required unless ``kind == "hist"``.
        ycols: Columns to plot.
        hue: Optional column whose values (compared as strings) split the
            data into groups. Ignored for histograms.
        agg: For line plots, ``"std"`` or ``"sem"`` aggregates repeated x
            values via ``aggregate_xy``; anything else plots the raw points.
        fill_band: Shade mean +/- error for aggregated line plots when any
            error is positive.
        title: Axes title; skipped when blank.
        xlabel: X-axis label; falls back to ``xcol`` (empty for histograms)
            when blank.
        ylabel: Y-axis label; falls back to the comma-joined ``ycols``.
        logx: Use a log x scale (ignored for histograms).
        logy: Use a log y scale.
        legend_mode: ``"none"``, ``"outside"`` (right of the axes) or
            anything else for ``loc="best"``.
        size_preset: Name understood by ``figure_size``.
        hist_bins: Number of histogram bins.
        hist_density: Normalise histograms to a density.
        exaggerated_text: Forwarded to ``set_paper_style``.

    Returns:
        The new figure. It is created through ``pyplot`` and therefore stays
        registered there until the caller closes it with ``plt.close(fig)``.

    Raises:
        ValueError: If ``xcol`` is ``None`` for a non-histogram plot.
    """
    if kind != "hist" and xcol is None:
        # Explicit check rather than ``assert`` so it also holds under -O.
        raise ValueError("xcol is required for every plot kind except 'hist'.")

    set_paper_style(exaggerated=exaggerated_text)
    colors = palette()

    # Collect every series first: grouped bars need to know how many series
    # there are before the first one can be positioned.
    series = []

    def _add(label, x, y):
        series.append((label, x, y, colors[len(series) % len(colors)]))

    if kind == "hist":
        for yc in ycols:
            _add(yc, None, df[yc].to_numpy())
    elif hue and hue in df.columns:
        hue_values = df[hue].astype(str)
        for g in hue_values.unique().tolist():
            sub = df[hue_values == g]
            gx = sub[xcol].to_numpy()
            for yc in ycols:
                _add(f"{yc} | {hue}={g}", gx, sub[yc].to_numpy())
    else:
        x_all = df[xcol].to_numpy()
        for yc in ycols:
            _add(yc, x_all, df[yc].to_numpy())

    fig, ax = plt.subplots(figsize=figure_size(size_preset), constrained_layout=True)
    try:
        if kind == "bar":
            _draw_grouped_bars(ax, series)
        else:
            for label, x, y, color in series:
                if kind == "line":
                    if agg in ("std", "sem"):
                        xu, ym, ye = aggregate_xy(x, y, agg)
                        ax.plot(xu, ym, marker="o", label=label, color=color)
                        if fill_band and np.any(ye > 0):
                            ax.fill_between(xu, ym - ye, ym + ye, alpha=0.18, color=color, linewidth=0)
                    else:
                        ax.plot(x, y, marker="o", label=label, color=color)
                elif kind == "scatter":
                    ax.scatter(x, y, label=label, color=color, s=52, alpha=0.85, edgecolors="none")
                elif kind == "hist":
                    ax.hist(np.asarray(y, dtype=float), bins=hist_bins, density=hist_density,
                            alpha=0.35, label=label, color=color)

        clean_axes(ax)

        if title.strip():
            ax.set_title(title)
        if kind != "hist":
            ax.set_xlabel(xlabel if xlabel.strip() else xcol)
        else:
            ax.set_xlabel(xlabel if xlabel.strip() else "")
        ax.set_ylabel(ylabel if ylabel.strip() else (", ".join(ycols) if ycols else ""))

        if logx and kind != "hist":
            ax.set_xscale("log")
        if logy:
            ax.set_yscale("log")

        if legend_mode == "none":
            if ax.get_legend() is not None:
                ax.get_legend().remove()
        elif legend_mode == "outside":
            ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
        else:
            ax.legend(loc="best", frameon=False)
    except Exception:
        # Do not leave a half-built figure registered with pyplot.
        plt.close(fig)
        raise
    return fig
def fig_to_bytes(fig, fmt: str) -> bytes:
    """Save ``fig`` in format ``fmt`` to an in-memory buffer and return its bytes."""
    with io.BytesIO() as sink:
        fig.savefig(sink, format=fmt)
        return sink.getvalue()
# =========================
# Streamlit UI
# =========================
st.set_page_config(page_title="PaperPlot (Matplotlib)", layout="wide")
st.title("PaperPlot: upload data → tweak params → live preview → export")

# Left column: upload + plot options. Right column: data preview, column
# selection, rendered figure and downloads.
left, right = st.columns([1, 2])

with left:
    uploaded = st.file_uploader("Upload data", type=["csv", "json", "npz", "npy"])
    st.caption("Supported: .csv / .json / .npz / .npy")
    kind = st.selectbox("Plot kind", ["line", "scatter", "bar", "hist"], index=0)
    exaggerated_text = st.toggle("Exaggerate text (paper readability)", value=True)
    size_preset = st.selectbox(
        "Figure size preset",
        ["single", "single_tall", "double", "double_tall", "square", "wide"],
        index=0
    )
    title = st.text_input("Title", value="")
    xlabel = st.text_input("X label (optional)", value="")
    ylabel = st.text_input("Y label (optional)", value="")
    logx = st.toggle("Log X", value=False)
    logy = st.toggle("Log Y", value=False)
    legend_mode = st.selectbox("Legend", ["best", "outside", "none"], index=0)
    agg = st.selectbox("Aggregate repeated x (line only)", ["none", "std", "sem"], index=0)
    fill_band = st.toggle("Show error band (line + agg)", value=True)
    hist_bins = st.slider("Hist bins", 5, 200, 30)
    hist_density = st.toggle("Hist density", value=True)

with right:
    if not uploaded:
        st.info("Upload a dataset to start.")
        st.stop()

    try:
        df = load_to_df(uploaded)
    except Exception as e:
        st.error(f"Failed to load file: {e}")
        st.stop()

    st.subheader("Data preview")
    st.dataframe(df.head(50), use_container_width=True)

    cols = df.columns.tolist()
    numeric_cols = [c for c in cols if pd.api.types.is_numeric_dtype(df[c])]
    # Numeric columns are preferred; fall back to everything if there are none.
    value_options = numeric_cols if numeric_cols else cols

    if kind != "hist":
        # Bars are drawn per category, so the x column may be non-numeric:
        # offer every column for bar charts, not only the numeric ones.
        x_options = cols if kind == "bar" else value_options
        xcol = st.selectbox("X column", options=x_options)
    else:
        xcol = None

    default_y = value_options[:1]
    ycols = st.multiselect("Y column(s)", options=value_options, default=default_y)

    hue = None
    if kind != "hist":
        hue = st.selectbox("Group / hue (optional)", options=["(none)"] + cols, index=0)
        hue = None if hue == "(none)" else hue

    if not ycols:
        st.warning("Pick at least one Y column.")
        st.stop()

    try:
        fig = make_plot(
            df=df,
            kind=kind,
            xcol=xcol,
            ycols=ycols,
            hue=hue,
            # Aggregation is only meaningful for line plots.
            agg=agg if kind == "line" else "none",
            fill_band=fill_band,
            title=title,
            xlabel=xlabel,
            ylabel=ylabel,
            logx=logx,
            logy=logy,
            legend_mode=legend_mode,
            size_preset=size_preset,
            hist_bins=hist_bins,
            hist_density=hist_density,
            exaggerated_text=exaggerated_text,
        )
    except Exception as e:
        # Same reporting style as the load step: show the problem in the page
        # instead of a raw traceback (e.g. for non-numeric columns).
        st.error(f"Failed to render plot: {e}")
        st.stop()

    try:
        st.subheader("Live preview")
        st.pyplot(fig, use_container_width=True)

        c1, c2 = st.columns(2)
        with c1:
            st.download_button(
                "Download PDF",
                data=fig_to_bytes(fig, "pdf"),
                file_name="figure.pdf",
                mime="application/pdf",
            )
        with c2:
            st.download_button(
                "Download PNG",
                data=fig_to_bytes(fig, "png"),
                file_name="figure.png",
                mime="image/png",
            )
    finally:
        # pyplot keeps every figure it creates until it is closed, and this
        # script builds a new one on each run, so release it once the preview
        # and both exports have been produced.
        plt.close(fig)