|
|
|
|
|
import io |
|
|
import json |
|
|
import os |
|
|
from typing import Dict, Any, Optional, Tuple, List |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
import streamlit as st |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-wide Matplotlib default: every figure this app draws uses monospace text.
plt.rcParams.update({"font.family": "monospace"})
|
|
|
|
|
# Base colours, written as 8-bit RGB triples and scaled to floats in [0, 1]
# (the form Matplotlib accepts directly as a colour).
PRIMARY = np.asarray((166, 0, 0)) / 255.0  # dark red
CONTRARY = np.asarray((0, 166, 166)) / 255.0  # teal
NEUTRAL_MEDIUM_GREY = np.asarray((128, 128, 128)) / 255.0
NEUTRAL_DARK_GREY = np.asarray((64, 64, 64)) / 255.0
|
|
|
|
|
|
|
|
def _mix(c1, c2, t: float): |
|
|
c1 = np.array(c1, dtype=float) |
|
|
c2 = np.array(c2, dtype=float) |
|
|
return (1 - t) * c1 + t * c2 |
|
|
|
|
|
|
|
|
def palette():
    """Return the eight-colour cycle used for plotted series.

    The first four entries are the base colours themselves (primary, contrary,
    dark grey, medium grey). The last four are the same colours, in the same
    order, blended toward white by 35 %, 35 %, 45 % and 35 % respectively.
    """
    bases = [PRIMARY, CONTRARY, NEUTRAL_DARK_GREY, NEUTRAL_MEDIUM_GREY]
    whiten_by = [0.35, 0.35, 0.45, 0.35]
    white = np.ones(3)
    tints = [_mix(base, white, amount) for base, amount in zip(bases, whiten_by)]
    return bases + tints
|
|
|
|
|
|
|
|
def set_paper_style(exaggerated: bool = True):
    """Apply the app's paper-figure look to the global Matplotlib rcParams.

    This mutates ``plt.rcParams`` in place, so it affects every figure created
    afterwards in the process.

    Args:
        exaggerated: when true, use the large text sizes (base 18, labels 22,
            title 24, ticks/legend 18); otherwise use the compact set
            (base 12, labels 14, title 16, ticks/legend 12).
    """
    # Font sizes in the order: base, axis label, title, tick label, legend.
    if exaggerated:
        sizes = (18, 22, 24, 18, 18)
    else:
        sizes = (12, 14, 16, 12, 12)
    base, label, title, tick, legend = sizes

    style = {
        # Text.
        "font.size": base,
        "axes.titlesize": title,
        "axes.labelsize": label,
        "xtick.labelsize": tick,
        "ytick.labelsize": tick,
        "legend.fontsize": legend,
        # Strokes and grid.
        "axes.linewidth": 1.6,
        "lines.linewidth": 2.8,
        "lines.markersize": 7.0,
        "grid.alpha": 0.25,
        "grid.linewidth": 1.0,
        # Screen resolution vs. export resolution and export cropping.
        "figure.dpi": 120,
        "savefig.dpi": 600,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.03,
        # Tick marks.
        "xtick.direction": "out",
        "ytick.direction": "out",
        "xtick.major.size": 6.0,
        "ytick.major.size": 6.0,
        "xtick.major.width": 1.4,
        "ytick.major.width": 1.4,
    }
    plt.rcParams.update(style)
|
|
|
|
|
|
|
|
def clean_axes(ax):
    """Give *ax* the paper look: a major grid on both axes, no top/right spines.

    Returns the same axes object so the call can be chained.
    """
    ax.grid(True, which="major", axis="both")
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    return ax
|
|
|
|
|
|
|
|
def figure_size(preset: str) -> Tuple[float, float]:
    """Return ``(width, height)`` in inches for a named figure-size preset.

    Known presets: ``single``, ``single_tall``, ``double``, ``double_tall``,
    ``square`` and ``wide``. The two recurring widths (3.45 and 7.10 in)
    presumably target single- and double-column journal layouts.

    Raises:
        KeyError: if *preset* is not one of the known names.
    """
    narrow, broad = 3.45, 7.10
    sizes = dict(
        single=(narrow, 2.60),
        single_tall=(narrow, 3.20),
        double=(broad, 2.90),
        double_tall=(broad, 3.80),
        square=(4.00, 4.00),
        wide=(broad, 2.40),
    )
    return sizes[preset]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _json_to_df(data: bytes) -> pd.DataFrame:
    """Build a DataFrame from JSON bytes holding a dict-of-lists or list-of-dicts."""
    # "utf-8-sig" also accepts files that begin with a UTF-8 byte-order mark
    # (common in exports from Windows tools). Plain "utf-8" keeps the BOM in
    # the string, and json.loads then rejects the whole document.
    obj = json.loads(data.decode("utf-8-sig"))
    if isinstance(obj, (dict, list)):
        return pd.DataFrame(obj)
    raise ValueError("Unsupported JSON: use dict-of-lists or list-of-dicts.")


def _npz_to_df(data: bytes) -> pd.DataFrame:
    """Turn every 1-D array in an ``.npz`` archive into a DataFrame column."""
    # SECURITY(review): allow_pickle=True lets a crafted upload execute
    # arbitrary code while object arrays are unpickled. It is kept so that
    # object-dtype archives still load; consider allow_pickle=False for
    # untrusted files.
    # The NpzFile wraps an open zip archive, so close it deterministically.
    with np.load(io.BytesIO(data), allow_pickle=True) as archive:
        df = pd.DataFrame()
        for key in archive.files:
            arr = np.asarray(archive[key])
            # Arrays of any other rank have no natural column mapping; skip them.
            if arr.ndim == 1:
                df[key] = arr
    if len(df.columns) == 0:
        raise ValueError(".npz has no 1D arrays to treat as columns.")
    return df


def _npy_to_df(data: bytes) -> pd.DataFrame:
    """Convert a single ``.npy`` array (structured, 1-D or 2-D) into a DataFrame."""
    # SECURITY(review): allow_pickle=True on uploaded bytes can execute
    # arbitrary code during unpickling. Consider allow_pickle=False for
    # untrusted files.
    arr = np.asarray(np.load(io.BytesIO(data), allow_pickle=True))
    if arr.dtype.names:
        return pd.DataFrame({n: arr[n] for n in arr.dtype.names})
    if arr.ndim == 1:
        return pd.DataFrame({"y": arr})
    if arr.ndim == 2:
        return pd.DataFrame(arr, columns=[f"y{i}" for i in range(arr.shape[1])])
    raise ValueError("Unsupported .npy shape. Use 1D or 2D array or structured array.")


def load_to_df(uploaded_file) -> pd.DataFrame:
    """Parse an uploaded data file into a ``pandas.DataFrame``.

    The format is chosen from the file extension (case-insensitive):

    * ``.csv``: read with ``pandas.read_csv``.
    * ``.json``: a dict-of-lists or list-of-dicts, passed to ``pd.DataFrame``.
      A leading UTF-8 byte-order mark is accepted.
    * ``.npz``: every 1-D array in the archive becomes a column. Arrays of
      any other rank are skipped.
    * ``.npy``: a structured array (one column per field), a 1-D array
      (column ``"y"``) or a 2-D array (columns ``"y0"``, ``"y1"``, ...).

    Args:
        uploaded_file: object with a ``.name`` (used only for its extension)
            and a ``.getvalue()`` returning the raw bytes. Examples are a
            Streamlit upload or an ``io.BytesIO`` with ``name`` set.

    Returns:
        The parsed DataFrame.

    Raises:
        ValueError: for an unsupported extension, an unsupported JSON
            top-level type or ``.npy`` shape, or an ``.npz`` with no 1-D arrays.
            Parser errors from pandas/numpy/json propagate unchanged.
    """
    ext = os.path.splitext(uploaded_file.name)[1].lower()
    data = uploaded_file.getvalue()

    if ext == ".csv":
        return pd.read_csv(io.BytesIO(data))
    if ext == ".json":
        return _json_to_df(data)
    if ext == ".npz":
        return _npz_to_df(data)
    if ext == ".npy":
        return _npy_to_df(data)

    raise ValueError(f"Unsupported file extension: {ext}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def aggregate_xy(x: np.ndarray, y: np.ndarray, mode: str):
    """Collapse repeated x values into a per-x mean of y plus an error term.

    Rows where x or y is missing are dropped first. The returned x values are
    the distinct x values in pandas' default groupby (sorted) order.

    Args:
        x, y: paired sample arrays.
        mode: ``"std"`` gives the sample standard deviation (ddof=1).
            ``"sem"`` gives that standard deviation divided by sqrt(group size).
            Any other value gives zero error.

    Returns:
        ``(x_unique, y_mean, y_err)`` as numpy arrays. Groups with one sample
        get an error of 0 instead of NaN.
    """
    grouped = pd.DataFrame({"x": x, "y": y}).dropna().groupby("x")["y"]
    centre = grouped.mean()

    if mode in ("std", "sem"):
        spread = grouped.std(ddof=1)
        if mode == "sem":
            spread = spread / np.sqrt(grouped.count())
        # Single-sample groups yield NaN std; report them as zero error.
        spread = spread.fillna(0.0)
    else:
        spread = pd.Series(0.0, index=centre.index)

    return centre.index.to_numpy(), centre.to_numpy(), spread.to_numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _draw_grouped_bars(ax, series, colors, group_width: float = 0.8):
    """Draw each series' per-x mean as side-by-side bars on one shared category axis.

    ``series`` is a list of ``(label, x, y)`` tuples. The categories are the
    union of every series' distinct x values, sorted where pandas can sort them.
    With a single series this matches a plain ``ax.bar`` call (width 0.8, bars
    centred on integer positions 0..k-1).
    """
    if not series:
        return
    # Mean y per distinct x for each series, after dropping incomplete rows.
    means = [
        pd.DataFrame({"x": sx, "y": sy}).dropna().groupby("x")["y"].mean()
        for _, sx, sy in series
    ]
    categories = means[0].index
    for m in means[1:]:
        categories = categories.union(m.index)

    n = len(series)
    width = group_width / n
    for i, ((label, _, _), m) in enumerate(zip(series, means)):
        # Centre the group of n bars on each category's integer position.
        offset = (i - (n - 1) / 2) * width
        pos = categories.get_indexer(m.index) + offset
        ax.bar(pos, m.to_numpy(), width=width, label=label, color=colors[i % len(colors)])
    # Set ticks once, from the shared categories, so every bar stays correctly labelled.
    ax.set_xticks(np.arange(len(categories)), categories.tolist())


def make_plot(
    df: pd.DataFrame,
    kind: str,
    xcol: Optional[str],
    ycols: List[str],
    hue: Optional[str],
    agg: str,
    fill_band: bool,
    title: str,
    xlabel: str,
    ylabel: str,
    logx: bool,
    logy: bool,
    legend_mode: str,
    size_preset: str,
    hist_bins: int,
    hist_density: bool,
    exaggerated_text: bool,
):
    """Render selected columns of *df* as one paper-styled Matplotlib figure.

    Args:
        df: source table.
        kind: ``"line"``, ``"scatter"``, ``"bar"`` or ``"hist"``. Any other
            value draws no data.
        xcol: x column. Required unless ``kind == "hist"``.
        ycols: columns plotted as separate series.
        hue: optional grouping column (compared as strings). Each
            (group, y column) pair becomes its own series. Ignored for
            histograms or when it is not a column of *df*.
        agg: line plots only. ``"std"``/``"sem"`` average repeated x values
            (see ``aggregate_xy``); anything else plots the raw points.
        fill_band: shade mean ± error when ``agg`` is active and some error > 0.
        title, xlabel, ylabel: blank strings fall back to defaults: no title,
            the x column name (empty for histograms), and the comma-joined y
            column names.
        logx, logy: log-scale an axis. ``logx`` is ignored for histograms.
        legend_mode: ``"none"``, ``"outside"`` (right of the axes), or anything
            else for ``loc="best"``.
        size_preset: preset name understood by ``figure_size``.
        hist_bins, hist_density: forwarded to ``Axes.hist``.
        exaggerated_text: forwarded to ``set_paper_style``.

    Returns:
        The new Matplotlib figure. The caller owns it and should close it.

    Raises:
        ValueError: if *xcol* is ``None`` for a non-histogram kind.
    """
    set_paper_style(exaggerated=exaggerated_text)  # updates global rcParams
    w, h = figure_size(size_preset)
    fig, ax = plt.subplots(figsize=(w, h), constrained_layout=True)
    colors = palette()

    # Collect every (label, x, y) series first. Colours follow series order,
    # and bar charts need all series at once to place them side by side.
    series: List[Tuple[str, Any, Any]] = []
    if kind == "hist":
        series = [(yc, None, df[yc].to_numpy()) for yc in ycols]
    else:
        if xcol is None:
            raise ValueError(f"xcol is required for kind={kind!r}.")
        if hue and hue in df.columns:
            hue_keys = df[hue].astype(str)
            for g in hue_keys.unique().tolist():
                sub = df[hue_keys == g]
                gx = sub[xcol].to_numpy()
                for yc in ycols:
                    series.append((f"{yc} | {hue}={g}", gx, sub[yc].to_numpy()))
        else:
            x = df[xcol].to_numpy()
            series = [(yc, x, df[yc].to_numpy()) for yc in ycols]

    if kind == "bar":
        _draw_grouped_bars(ax, series, colors)
    else:
        for i, (label, sx, sy) in enumerate(series):
            color = colors[i % len(colors)]
            if kind == "line":
                if agg in ("std", "sem"):
                    xu, ym, ye = aggregate_xy(sx, sy, agg)
                    ax.plot(xu, ym, marker="o", label=label, color=color)
                    if fill_band and np.any(ye > 0):
                        ax.fill_between(xu, ym - ye, ym + ye, alpha=0.18, color=color, linewidth=0)
                else:
                    ax.plot(sx, sy, marker="o", label=label, color=color)
            elif kind == "scatter":
                ax.scatter(sx, sy, label=label, color=color, s=52, alpha=0.85, edgecolors="none")
            elif kind == "hist":
                ax.hist(np.asarray(sy, dtype=float), bins=hist_bins, density=hist_density,
                        alpha=0.35, label=label, color=color)

    clean_axes(ax)
    if title.strip():
        ax.set_title(title)
    if kind != "hist":
        ax.set_xlabel(xlabel if xlabel.strip() else xcol)
    else:
        ax.set_xlabel(xlabel if xlabel.strip() else "")
    ax.set_ylabel(ylabel if ylabel.strip() else (", ".join(ycols) if ycols else ""))

    if logx and kind != "hist":
        ax.set_xscale("log")
    if logy:
        ax.set_yscale("log")

    if legend_mode == "none":
        # Nothing above creates a legend; this only guards against an existing one.
        if ax.get_legend() is not None:
            ax.get_legend().remove()
    elif legend_mode == "outside":
        ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)
    else:
        ax.legend(loc="best", frameon=False)

    return fig
|
|
|
|
|
|
|
|
def fig_to_bytes(fig, fmt: str) -> bytes:
    """Render *fig* to an in-memory file of format *fmt* (e.g. "pdf", "png").

    Returns:
        The complete encoded file contents. ``savefig.*`` rcParams in effect
        (such as dpi and bbox) apply to the output.
    """
    with io.BytesIO() as sink:
        fig.savefig(sink, format=fmt)
        return sink.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Streamlit UI. Streamlit re-executes this whole script on every widget change.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="PaperPlot (Matplotlib)", layout="wide")
st.title("PaperPlot: upload data → tweak params → live preview → export")

left, right = st.columns([1, 2])

# Left column: upload plus all plotting controls.
with left:
    uploaded = st.file_uploader("Upload data", type=["csv", "json", "npz", "npy"])
    st.caption("Supported: .csv / .json / .npz / .npy")

    kind = st.selectbox("Plot kind", ["line", "scatter", "bar", "hist"], index=0)
    exaggerated_text = st.toggle("Exaggerate text (paper readability)", value=True)

    size_preset = st.selectbox(
        "Figure size preset",
        ["single", "single_tall", "double", "double_tall", "square", "wide"],
        index=0
    )

    title = st.text_input("Title", value="")
    xlabel = st.text_input("X label (optional)", value="")
    ylabel = st.text_input("Y label (optional)", value="")

    logx = st.toggle("Log X", value=False)
    logy = st.toggle("Log Y", value=False)

    legend_mode = st.selectbox("Legend", ["best", "outside", "none"], index=0)

    agg = st.selectbox("Aggregate repeated x (line only)", ["none", "std", "sem"], index=0)
    fill_band = st.toggle("Show error band (line + agg)", value=True)

    hist_bins = st.slider("Hist bins", 5, 200, 30)
    hist_density = st.toggle("Hist density", value=True)

# Right column: data preview, column selection, live figure and exports.
with right:
    if not uploaded:
        st.info("Upload a dataset to start.")
        st.stop()

    try:
        df = load_to_df(uploaded)
    except Exception as e:
        # UI boundary: report any parse failure to the user, then end this run.
        st.error(f"Failed to load file: {e}")
        st.stop()

    st.subheader("Data preview")
    st.dataframe(df.head(50), use_container_width=True)

    cols = df.columns.tolist()
    numeric_cols = [c for c in cols if pd.api.types.is_numeric_dtype(df[c])]

    if kind != "hist":
        xcol = st.selectbox("X column", options=numeric_cols if numeric_cols else cols)
    else:
        xcol = None

    if numeric_cols:
        default_y = numeric_cols[:1]
    else:
        default_y = cols[:1]

    ycols = st.multiselect("Y column(s)", options=numeric_cols if numeric_cols else cols, default=default_y)

    hue = None
    if kind != "hist":
        hue = st.selectbox("Group / hue (optional)", options=["(none)"] + cols, index=0)
        hue = None if hue == "(none)" else hue

    if not ycols:
        st.warning("Pick at least one Y column.")
        st.stop()

    fig = make_plot(
        df=df,
        kind=kind,
        xcol=xcol,
        ycols=ycols,
        hue=hue,
        agg=agg if kind == "line" else "none",
        fill_band=fill_band,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        logx=logx,
        logy=logy,
        legend_mode=legend_mode,
        size_preset=size_preset,
        hist_bins=hist_bins,
        hist_density=hist_density,
        exaggerated_text=exaggerated_text,
    )

    try:
        st.subheader("Live preview")
        st.pyplot(fig, use_container_width=True)

        c1, c2 = st.columns(2)
        with c1:
            st.download_button(
                "Download PDF",
                data=fig_to_bytes(fig, "pdf"),
                file_name="figure.pdf",
                mime="application/pdf",
            )
        with c2:
            st.download_button(
                "Download PNG",
                data=fig_to_bytes(fig, "png"),
                file_name="figure.png",
                mime="image/png",
            )
    finally:
        # pyplot keeps a reference to every open figure. Without closing, each
        # rerun of this script would leak one more figure.
        plt.close(fig)
|
|
|