"""
PhySH Taxonomy Classifier — Gradio App
Two-stage hierarchical cascade:
Stage 1 → Discipline prediction (18-class multi-label)
Stage 2 → Concept prediction (186-class multi-label, conditioned on discipline probabilities)
Models were trained on APS PhySH labels with google/embeddinggemma-300m embeddings.
"""
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
# ---------------------------------------------------------------------------
# Model definitions (mirror the training code exactly)
# ---------------------------------------------------------------------------
class MultiLabelMLP(nn.Module):
def __init__(self, input_dim: int, output_dim: int,
hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3):
super().__init__()
layers = []
prev_dim = input_dim
for hidden_dim in hidden_layers:
layers.extend([nn.Linear(prev_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)])
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
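# The concept head below concatenates the text embedding with the Stage 1
# discipline probabilities (or, when use_logits=True, their inverse-sigmoid /
# logit transform) before passing them through the MLP.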
class DisciplineConditionedMLP(nn.Module):
def __init__(self, embedding_dim: int, discipline_dim: int, output_dim: int,
hidden_layers: Tuple[int, ...] = (1024, 512), dropout: float = 0.3,
discipline_dropout: float = 0.0, use_logits: bool = False):
super().__init__()
self.use_logits = use_logits
self.discipline_dropout = nn.Dropout(discipline_dropout)
layers = []
prev_dim = embedding_dim + discipline_dim
for hidden_dim in hidden_layers:
layers.extend([nn.Linear(prev_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)])
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, output_dim))
self.network = nn.Sequential(*layers)
def forward(self, embedding: torch.Tensor, discipline_probs: torch.Tensor) -> torch.Tensor:
if self.use_logits:
disc_features = torch.clamp(discipline_probs, 1e-7, 1 - 1e-7)
disc_features = torch.log(disc_features / (1 - disc_features))
else:
disc_features = discipline_probs
disc_features = self.discipline_dropout(disc_features)
return self.network(torch.cat([embedding, disc_features], dim=1))
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
MODELS_DIR = Path(__file__).resolve().parent
DISCIPLINE_MODEL_PATH = MODELS_DIR / "discipline_classifier_gemma_20260130_140842.pt"
CONCEPT_MODEL_PATH = MODELS_DIR / "concept_conditioned_gemma_20260130_140842.pt"
EMBEDDING_MODEL_NAME = "google/embeddinggemma-300m"
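# Disciplines listed here are hidden from the displayed Stage 1 results (hot fix);
# their probabilities are still computed and passed to the concept model unchanged.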
EXCLUDED_DISCIPLINES = {"Quantum Physics"}
# ---------------------------------------------------------------------------
# Globals (loaded once at startup)
# ---------------------------------------------------------------------------
device: str = "cpu"
embedding_model: Optional[SentenceTransformer] = None
discipline_model: Optional[MultiLabelMLP] = None
concept_model: Optional[DisciplineConditionedMLP] = None
discipline_labels: List[Dict] = []
concept_labels: List[Dict] = []
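# Each checkpoint is assumed to be a dict with "model_config", "model_state_dict",
# and "class_labels" (a list of {"label": ...} entries), as produced by the training code.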
def load_models():
global device, embedding_model, discipline_model, concept_model
global discipline_labels, concept_labels
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
print(f"Loading embedding model ({EMBEDDING_MODEL_NAME}) on {device} …")
hf_token = os.environ.get("HF_TOKEN")
embedding_model = SentenceTransformer(
EMBEDDING_MODEL_NAME, device=device, token=hf_token,
)
# --- discipline model ---
disc_ckpt = torch.load(DISCIPLINE_MODEL_PATH, map_location=device, weights_only=False)
dc = disc_ckpt["model_config"]
discipline_model = MultiLabelMLP(
dc["input_dim"], dc["output_dim"],
tuple(dc["hidden_layers"]), dc["dropout"],
)
discipline_model.load_state_dict(disc_ckpt["model_state_dict"])
discipline_model.to(device).eval()
discipline_labels = disc_ckpt["class_labels"]
# --- concept model ---
conc_ckpt = torch.load(CONCEPT_MODEL_PATH, map_location=device, weights_only=False)
cc = conc_ckpt["model_config"]
concept_model = DisciplineConditionedMLP(
cc["embedding_dim"], cc["discipline_dim"], cc["output_dim"],
tuple(cc["hidden_layers"]), cc["dropout"],
cc.get("discipline_dropout", 0.0), cc.get("use_logits", False),
)
concept_model.load_state_dict(conc_ckpt["model_state_dict"])
concept_model.to(device).eval()
concept_labels = conc_ckpt["class_labels"]
print(f"Loaded {len(discipline_labels)} disciplines, {len(concept_labels)} concepts")
# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
def clean_text(text: str) -> str:
if not text:
return ""
return re.sub(r"\s+", " ", text).strip()
def predict(title: str, abstract: str, threshold: float, top_k: int):
"""Run the two-stage cascade and return formatted results."""
combined = clean_text(title)
abs_clean = clean_text(abstract)
if combined and abs_clean:
combined = f"{combined} [SEP] {abs_clean}"
elif abs_clean:
combined = abs_clean
if not combined.strip():
return "Please enter at least a title or abstract.", ""
# Embed
embedding = embedding_model.encode(
[combined], normalize_embeddings=True, convert_to_numpy=True,
)
emb_tensor = torch.FloatTensor(embedding).to(device)
with torch.no_grad():
# Stage 1
disc_logits = discipline_model(emb_tensor)
disc_probs = torch.sigmoid(disc_logits).cpu().numpy()[0]
# Stage 2
disc_probs_tensor = torch.FloatTensor(disc_probs).unsqueeze(0).to(device)
conc_logits = concept_model(emb_tensor, disc_probs_tensor)
conc_probs = torch.sigmoid(conc_logits).cpu().numpy()[0]
# Format discipline results (skip excluded labels)
disc_order = np.argsort(disc_probs)[::-1]
disc_lines = []
rank = 0
for idx in disc_order:
label = discipline_labels[idx].get("label", f"Discipline_{idx}")
if label in EXCLUDED_DISCIPLINES:
continue
rank += 1
if rank > top_k:
break
prob = disc_probs[idx]
marker = "**" if prob >= threshold else ""
disc_lines.append(f"{rank}. {marker}{label}{marker}: {prob:.1%}")
# Format concept results
conc_order = np.argsort(conc_probs)[::-1]
conc_lines = []
for rank, idx in enumerate(conc_order[:top_k], 1):
prob = conc_probs[idx]
label = concept_labels[idx].get("label", f"Concept_{idx}")
marker = "**" if prob >= threshold else ""
conc_lines.append(f"{rank}. {marker}{label}{marker}: {prob:.1%}")
disc_md = f"### Disciplines (threshold ≥ {threshold:.0%})\n\n" + "\n".join(disc_lines)
conc_md = f"### Research-Area Concepts (threshold ≥ {threshold:.0%})\n\n" + "\n".join(conc_lines)
return disc_md, conc_md
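# Programmatic use (sketch): after load_models(), the cascade can be called directly,
# e.g. predict("Paper title", "Paper abstract ...", threshold=0.35, top_k=10),
# which returns two Markdown strings (disciplines, concepts).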
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
EXAMPLES = [
[
"Quantum Computing: Vision and Challenges",
(
"The recent development of quantum computing, which uses entanglement, superposition, and other quantum fundamental concepts, "
"can provide substantial processing advantages over traditional computing. These quantum features help solve many complex "
"problems that cannot be solved otherwise with conventional computing methods. These problems include modeling quantum mechanics, "
"logistics, chemical-based advances, drug design, statistical science, sustainable energy, banking, reliable communication, and "
"quantum chemical engineering. The last few years have witnessed remarkable progress in quantum software and algorithm creation "
"and quantum hardware research, which has significantly advanced the prospect of realizing quantum computers. It would be helpful "
"to have comprehensive literature research on this area to grasp the current status and find outstanding problems that require "
"considerable attention from the research community working in the quantum computing industry. To better understand quantum computing, "
"this paper examines the foundations and vision based on current research in this area. We discuss cutting-edge developments in quantum "
"computer hardware advancement and subsequent advances in quantum cryptography, quantum software, and high-scalability quantum computers. "
"Many potential challenges and exciting new trends for quantum technology research and development are highlighted in this paper for a broader debate."
),
],
[
"Topological Insulators and Superconductors",
(
"Topological insulators are electronic materials that have a bulk band gap like an ordinary insulator but have protected conducting states "
"on their edge or surface. We review the theoretical foundation for topological insulators and superconductors and describe recent experiments."
),
],
[
"Floquet Topological Insulator in Semiconductor Quantum Wells",
(
"Topological phase transitions between a conventional insulator and a state of matter with topological properties have been proposed and observed "
"in mercury telluride - cadmium telluride quantum wells. We show that a topological state can be induced in such a device, initially in the trivial "
"phase, by irradiation with microwave frequencies, without closing the gap and crossing the phase transition. We show that the quasi-energy spectrum "
"exhibits a single pair of helical edge states. The velocity of the edge states can be tuned by adjusting the intensity of the microwave radiation. "
"We discuss the necessary experimental parameters for our proposal. This proposal provides an example and a proof of principle of a new non-equilibrium "
"topological state, Floquet topological insulator, introduced in this paper."
),
],
]
def build_app() -> gr.Blocks:
with gr.Blocks(
title="PhySH Taxonomy Classifier",
theme=gr.themes.Soft(primary_hue="indigo", secondary_hue="slate"),
) as demo:
gr.Markdown(
"# PhySH Taxonomy Classifier\n"
"Enter a paper **title** and **abstract** to predict APS PhySH disciplines "
"and research-area concepts using a two-stage hierarchical cascade.\n\n"
"Labels above the threshold are **bolded**."
)
with gr.Row():
with gr.Column(scale=2):
title_box = gr.Textbox(label="Title", lines=2, placeholder="Paper title …")
abstract_box = gr.Textbox(label="Abstract", lines=8, placeholder="Paper abstract …")
with gr.Row():
threshold_slider = gr.Slider(
minimum=0.05, maximum=0.95, value=0.35, step=0.05,
label="Threshold",
)
topk_slider = gr.Slider(
minimum=1, maximum=20, value=10, step=1, label="Top-K",
)
predict_btn = gr.Button("Classify", variant="primary", size="lg")
with gr.Column(scale=3):
disc_output = gr.Markdown(label="Disciplines")
conc_output = gr.Markdown(label="Concepts")
predict_btn.click(
fn=predict,
inputs=[title_box, abstract_box, threshold_slider, topk_slider],
outputs=[disc_output, conc_output],
)
gr.Examples(
examples=EXAMPLES,
inputs=[title_box, abstract_box],
label="Example papers",
)
return demo
if __name__ == "__main__":
load_models()
app = build_app()
app.launch()