# app.py — Streamlit Legal LED Summarizer with Source Mapping (HF Hub-ready)
#
# Run locally:
#   streamlit run app.py
#
# On Hugging Face Spaces:
#   Put this file + requirements.txt in the Space repo.
#
# The app will:
# - Download the fine-tuned LED checkpoint from the HF Hub
# - Generate an abstractive summary with LED
# - Map each generated sentence back to source sentences via LegalBERT + FAISS
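#
# A minimal requirements.txt sketch for the Space (a guess derived from the
# imports below — the exact packages and version pins you tested with may differ):
#
#   streamlit
#   torch
#   numpy
#   transformers
#   huggingface_hub
#   faiss-cpu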
import os
import re
import textwrap
from typing import List
import streamlit as st
import torch
import numpy as np
from transformers import (
LEDTokenizerFast,
LEDForConditionalGeneration,
AutoTokenizer,
AutoModel,
)
from huggingface_hub import hf_hub_download
# Avoid OpenMP duplicate errors in some environments
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# -----------------------------
# CONFIG
# -----------------------------
DEFAULT_LED_MODEL = "allenai/led-base-16384"
DEFAULT_MAX_INPUT_LEN = 4096
DEFAULT_BEAMS = 5
DEFAULT_MAX_TARGET_LEN = 512
# Mapping defaults
LEGALBERT_NAME = "nlpaueb/legal-bert-base-uncased"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SIM_DEFAULT = 0.85 # similarity threshold to call a sentence SUPPORTED
TOP_K_SOURCES = 3 # how many source sentences to show per generated sentence
# -----------------------------
# HF Hub checkpoint config
# -----------------------------
# 🔁 Change these to your actual model repo + filename
HF_REPO_ID = "samraatd/legal-longdoc-summarization"
HF_CHECKPOINT_FILE = "checkpoint_epoch_50.pt"
def get_checkpoint_path_from_hub() -> str:
"""
Download the fine-tuned LED checkpoint from Hugging Face Hub and
return the local file path.
"""
try:
ckpt_path = hf_hub_download(
repo_id=HF_REPO_ID,
filename=HF_CHECKPOINT_FILE,
)
return ckpt_path
except Exception as e:
st.error(f"❌ Failed to download checkpoint from Hugging Face Hub: {e}")
return ""
# -----------------------------
# Caches / Loads
# -----------------------------
@st.cache_resource(show_spinner=False)
def load_led_model_and_tokenizer(model_name=DEFAULT_LED_MODEL, device=DEVICE):
tokenizer = LEDTokenizerFast.from_pretrained(model_name)
model = LEDForConditionalGeneration.from_pretrained(model_name).to(device)
return tokenizer, model
def load_checkpoint_weights_into_led(checkpoint_path, led_model):
    if not checkpoint_path or not os.path.exists(checkpoint_path):
        st.warning(f"Checkpoint not found at: {checkpoint_path}")
        return {}
    ck = torch.load(checkpoint_path, map_location="cpu")
    loaded = {}
    # Try common keys first; stop at the first one that loads successfully
    for keyname in ("led_state", "led_state_dict", "led_model", "led"):
        if keyname in ck:
            try:
                led_model.load_state_dict(ck[keyname], strict=False)
                loaded["led"] = keyname
                st.info(f"Loaded LED weights from checkpoint key: '{keyname}'")
                break
            except Exception as e:
                st.warning(f"Failed to load LED weights from key '{keyname}': {e}")
    # Fallback: scan for a nested dict whose keys overlap with the model state dict
    if "led" not in loaded:
        for k, v in ck.items():
            if isinstance(v, dict) and set(v.keys()) & set(led_model.state_dict().keys()):
                try:
                    led_model.load_state_dict(v, strict=False)
                    loaded["led"] = k
                    st.info(f"Loaded LED weights from checkpoint top-level key: '{k}'")
                    break
                except Exception:
                    pass
    # Last resort: the checkpoint file may itself be a plain state_dict
    if "led" not in loaded and isinstance(ck, dict) and set(ck.keys()) & set(led_model.state_dict().keys()):
        try:
            led_model.load_state_dict(ck, strict=False)
            loaded["led"] = "<raw state_dict>"
            st.info("Loaded LED weights from a raw state_dict checkpoint.")
        except Exception as e:
            st.warning(f"Failed to load raw state_dict checkpoint: {e}")
    if "led" not in loaded:
        st.warning("Could not find LED weights in the checkpoint. Using base HF LED.")
    return loaded
# -----------------------------
# Input building (original)
# -----------------------------
def build_condensed_natural_from_text(
raw_text,
max_chars=20000,
facts=None,
max_facts=8,
max_chunks=12,
):
text = raw_text.strip()
if not text:
return "[NO_INPUT_TEXT_PROVIDED]"
if len(text) > max_chars:
text = text[:max_chars] + "\n\n[TRUNCATED]"
# Facts
if facts:
enumerated = "\n".join([f"{i+1}. {f}" for i, f in enumerate(facts[:max_facts])])
facts_part = f"Relevant facts:\n{enumerated}\n"
else:
sentences = [s.strip() for s in text.replace("\n", " ").split(".") if s.strip()]
top_facts = sentences[:max_facts]
enumerated = "\n".join([f"{i+1}. {s}" for i, s in enumerate(top_facts)])
facts_part = (
f"Relevant facts:\n{enumerated}\n" if enumerated else "Relevant facts:\n\n"
)
# Chunks = paragraphs
paras = [p.strip() for p in text.split("\n\n") if p.strip()]
if not paras:
paras = [" ".join(text.split(".")[:5])]
paras = paras[:max_chunks]
para_lines = []
for i, p in enumerate(paras):
head = f"- Paragraph {i+1}: "
content = p if len(p) < 1200 else (p[:1200] + " [TRUNCATED]")
para_lines.append(head + content)
chunks_part = "Important paragraphs:\n" + "\n".join(para_lines) + "\n"
instruction = "\nPlease write a concise, professional summary in fluent English (3-5 sentences)."
combined = facts_part + "\n" + chunks_part + "\n" + instruction
return combined
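# Naive subsequence search: return every start index at which the token-id list
# `sub` occurs contiguously inside `seq` (used to locate header token spans).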
def find_subsequence_indices(seq, sub):
if len(sub) == 0 or len(seq) < len(sub):
return []
res = []
Ls = len(sub)
for i in range(len(seq) - Ls + 1):
if seq[i : i + Ls] == sub:
res.append(i)
return res
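# Build an LED global_attention_mask: the first token of each sequence and every
# token belonging to one of the given header strings is marked (1) for global attention.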
def build_global_attention_mask_for_headers(tokenizer, input_ids_batch, header_texts, device):
if isinstance(input_ids_batch, torch.Tensor):
input_ids = input_ids_batch.cpu().tolist()
else:
input_ids = [list(map(int, row)) for row in input_ids_batch]
B = len(input_ids)
T = max(len(r) for r in input_ids)
gmask = [[0] * T for _ in range(B)]
header_token_seqs = []
for h in header_texts:
if not h:
header_token_seqs.append([])
continue
enc = tokenizer(h, add_special_tokens=False, truncation=True, return_tensors=None)
header_token_seqs.append(enc["input_ids"])
for b, seq in enumerate(input_ids):
L = len(seq)
if L > 0:
gmask[b][0] = 1
for hseq in header_token_seqs:
if not hseq:
continue
starts = find_subsequence_indices(seq, hseq)
for s in starts:
for offs in range(len(hseq)):
idx = s + offs
if idx < T:
gmask[b][idx] = 1
return torch.tensor(gmask, dtype=torch.long, device=device)
# -----------------------------
# Source mapping helpers
# -----------------------------
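# Lightweight sentence splitter: break after '.', '!' or '?' when the following
# text looks like the start of a new sentence or list item.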
SENT_SPLIT_REGEX = re.compile(
r"(?<=[.!?])\s+(?=[A-Z(\[]|\d+\.|\•|\-)"
)
def split_sentences(text: str) -> List[str]:
parts = [s.strip() for s in SENT_SPLIT_REGEX.split(text) if s and s.strip()]
return [s for s in parts if len(s) > 1]
@st.cache_resource(show_spinner=False)
def load_legalbert(name: str = LEGALBERT_NAME):
tok = AutoTokenizer.from_pretrained(name)
mdl = AutoModel.from_pretrained(name)
mdl.to(DEVICE)
mdl.eval()
return tok, mdl
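# Mask-aware mean pooling over token embeddings, so padding positions are ignored.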
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
mask = attention_mask.unsqueeze(-1)
summed = (last_hidden_state * mask).sum(dim=1)
counts = mask.sum(dim=1).clamp(min=1)
return summed / counts
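# Encode texts with LegalBERT in batches and mean-pool each one into a single
# 768-d float32 vector (returned matrix has shape [n, 768]).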
def embed_texts_legalbert(texts: List[str], batch_size: int = 16) -> np.ndarray:
Tok, Mdl = load_legalbert()
vecs = []
with torch.no_grad():
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
enc = Tok(
batch,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt",
).to(DEVICE)
out = Mdl(**enc).last_hidden_state
mean = (
mean_pool(out, enc["attention_mask"])
.detach()
.cpu()
.numpy()
.astype("float32")
)
vecs.append(mean)
return np.vstack(vecs) if vecs else np.zeros((0, 768), dtype="float32")
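# Index L2-normalized embeddings with inner product, which is equivalent to cosine similarity.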
def build_faiss_index(sentences: List[str]):
try:
import faiss
except Exception:
st.error("FAISS is required. Install with `pip install faiss-cpu`.")
st.stop()
embs = embed_texts_legalbert(sentences)
faiss.normalize_L2(embs)
index = faiss.IndexFlatIP(embs.shape[1])
index.add(embs)
return index, embs.shape[1]
def map_generated_to_sources(
gen_sents: List[str],
index,
source_sents: List[str],
k: int = TOP_K_SOURCES,
):
try:
import faiss
except Exception:
st.error("FAISS is required. Install with `pip install faiss-cpu`.")
st.stop()
if not gen_sents:
return []
embs = embed_texts_legalbert(gen_sents)
faiss.normalize_L2(embs)
D, I = index.search(embs, k)
results = []
for i in range(len(gen_sents)):
triples = []
for idx, sim in zip(I[i], D[i]):
if 0 <= idx < len(source_sents):
triples.append((int(idx), float(sim), source_sents[idx]))
results.append(triples)
return results
# -----------------------------
# STREAMLIT UI
# -----------------------------
st.set_page_config(page_title="LDS - Validation-style Summarizer", layout="wide")
st.title("Legal Long Document Summarizer — with source mapping")
st.sidebar.header("Model & Checkpoint")
st.sidebar.write(f"**Base LED model**: `{DEFAULT_LED_MODEL}`")
st.sidebar.write(f"**Checkpoint (HF Hub)**: `{HF_REPO_ID}/{HF_CHECKPOINT_FILE}`")
st.sidebar.write("Device: " + DEVICE)
max_input_len = st.sidebar.number_input(
"LED max input tokens",
value=DEFAULT_MAX_INPUT_LEN,
step=512,
)
beam = st.sidebar.number_input("num_beams (generate)", value=DEFAULT_BEAMS, step=1)
max_target_len = st.sidebar.number_input(
"max_target_len",
value=DEFAULT_MAX_TARGET_LEN,
step=16,
)
st.sidebar.markdown("---")
st.sidebar.header("Input options")
use_naturalized = st.sidebar.checkbox(
"Build naturalized condensed input", value=False
)
show_condensed = st.sidebar.checkbox("Show condensed input", value=True)
st.sidebar.markdown("---")
st.sidebar.header("Citations / Mapping")
sim_threshold = st.sidebar.slider(
"Similarity threshold", 0.5, 0.99, SIM_DEFAULT, step=0.01
)
topk_sources = st.sidebar.slider(
"Top-K sources per sentence", 1, 10, TOP_K_SOURCES
)
# Main input
st.subheader("Document input")
raw_text = st.text_area(
"Paste your long document text here (or small text for testing).",
height=360,
)
if not raw_text:
st.info("Paste a document above to get started.")
# Controls
col1, col2 = st.columns([1, 3])
with col1:
if st.button("Load LED model + checkpoint from HF Hub"):
st.session_state["loaded"] = False
st.session_state["loaded_led"] = False
try:
tokenizer, led_model = load_led_model_and_tokenizer(
DEFAULT_LED_MODEL, device=DEVICE
)
st.session_state["tokenizer"] = tokenizer
st.session_state["led_model"] = led_model
st.success("✅ Loaded HF LED base model and tokenizer.")
ckpt_path = get_checkpoint_path_from_hub()
if ckpt_path:
loaded = load_checkpoint_weights_into_led(ckpt_path, led_model)
if loaded:
st.session_state["loaded_led"] = True
st.success("✅ Loaded fine-tuned checkpoint from HF Hub.")
st.session_state["loaded"] = True
except Exception as e:
st.error(f"Failed to load LED model/tokenizer or checkpoint: {e}")
with col2:
run_generate = st.button("Generate Summary")
# Generation step
if run_generate:
if "led_model" not in st.session_state:
st.error(
"Model not loaded. Click 'Load LED model + checkpoint from HF Hub' first."
)
elif not raw_text or raw_text.strip() == "":
st.error("Please paste some input text in the document input area.")
else:
tokenizer = st.session_state["tokenizer"]
led_model = st.session_state["led_model"]
# Build condensed input
if use_naturalized:
condensed = build_condensed_natural_from_text(raw_text, facts=None)
else:
condensed = raw_text.strip()
# Tokenize
enc = tokenizer(
[condensed],
truncation=True,
padding="longest",
max_length=int(max_input_len),
return_tensors="pt",
)
input_ids = enc["input_ids"].to(DEVICE)
attention_mask = enc["attention_mask"].to(DEVICE)
# Global attention mask
header_texts = ["Relevant facts:", "Important paragraphs:", "Please write"]
global_attn = build_global_attention_mask_for_headers(
tokenizer,
input_ids,
header_texts,
device=DEVICE,
)
# Generate
try:
led_model.eval()
with torch.no_grad():
gen_ids = led_model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
global_attention_mask=global_attn,
num_beams=int(beam),
max_length=int(max_target_len),
no_repeat_ngram_size=3,
length_penalty=1.2,
early_stopping=True,
)
preds = [
tokenizer.decode(
g,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
for g in gen_ids
]
pred = preds[0] if preds else ""
except Exception as e:
st.error(f"Generation failed: {e}")
pred = ""
# Show outputs
st.markdown("### Generated summary")
st.write(pred)
st.markdown("### Stats")
st.write(
{
"input_token_count": int(input_ids.size(1)),
"pred_token_count": len(tokenizer.encode(pred)),
}
)
if show_condensed:
st.markdown("### Condensed input used (truncated to 2000 chars)")
st.code(
textwrap.shorten(
condensed, width=2000, placeholder="... [TRUNCATED]"
),
language="text",
)
# -----------------------------
# Sentence-to-source mapping
# -----------------------------
if raw_text and pred:
st.markdown("---")
st.markdown("## 🔗 Sentence-to-Source Mapping")
# Split sentences
source_sents = split_sentences(raw_text)
gen_sents = split_sentences(pred)
if not source_sents:
st.info("Could not split the source into sentences.")
elif not gen_sents:
st.info("Could not split the generated summary into sentences.")
else:
# Build FAISS index over source sentences
index, dim = build_faiss_index(source_sents)
mappings = map_generated_to_sources(
gen_sents, index, source_sents, k=int(topk_sources)
)
# Render per-sentence with tags
for i, sent in enumerate(gen_sents, start=1):
hits = mappings[i - 1] if i - 1 < len(mappings) else []
strong = [
(idx, sim, s)
for (idx, sim, s) in hits
if sim >= float(sim_threshold)
]
tag = (
"[EXTRACTIVE]"
if strong and strong[0][1] >= 0.995
else ("[SUPPORTED]" if strong else "[UNSUPPORTED]")
)
with st.expander(f"{i}. {sent} {tag}"):
if strong:
for rank, (idx, sim, src) in enumerate(
strong, start=1
):
st.markdown(
f"**Source #{rank}** · line **{idx+1}** · sim **{sim:.3f}**"
)
st.write(src)
st.markdown("---")
else:
# show top-1 anyway to help debugging
if hits:
idx, sim, src = hits[0]
st.info(
f"No hit above threshold {sim_threshold:.2f}. Closest:"
)
st.markdown(
f"**Closest** · line **{idx+1}** · sim **{sim:.3f}**"
)
st.write(src)
else:
st.info("No close source sentence found.")