# app.py - Streamlit Legal LED Summarizer with Source Mapping (HF Hub-ready)
#
# Run locally:
#   streamlit run app.py
#
# On Hugging Face Spaces:
#   Put this file + requirements.txt in the Space repo.
#
# It will:
#   - Download your fine-tuned LED checkpoint from HF Hub
#   - Run summarization
#   - Map generated sentences back to source sentences via LegalBERT + FAISS

import os
import re
import textwrap
from typing import List, Tuple

import streamlit as st
import torch
import numpy as np
from transformers import (
    LEDTokenizerFast,
    LEDForConditionalGeneration,
    AutoTokenizer,
    AutoModel,
)
from huggingface_hub import hf_hub_download

# Avoid OpenMP duplicate errors in some environments
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# -----------------------------
# CONFIG
# -----------------------------
DEFAULT_LED_MODEL = "allenai/led-base-16384"
DEFAULT_MAX_INPUT_LEN = 4096
DEFAULT_BEAMS = 5
DEFAULT_MAX_TARGET_LEN = 512

# Mapping defaults
LEGALBERT_NAME = "nlpaueb/legal-bert-base-uncased"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SIM_DEFAULT = 0.85  # similarity threshold to call a sentence SUPPORTED
TOP_K_SOURCES = 3  # how many source sentences to show per generated sentence

# -----------------------------
# HF Hub checkpoint config
# -----------------------------
# Change these to your actual model repo + filename
HF_REPO_ID = "samraatd/legal-longdoc-summarization"
HF_CHECKPOINT_FILE = "checkpoint_epoch_50.pt"


def get_checkpoint_path_from_hub() -> str:
    """
    Download the fine-tuned LED checkpoint from Hugging Face Hub
    and return the local file path.
    """
    try:
        ckpt_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_CHECKPOINT_FILE,
        )
        return ckpt_path
    except Exception as e:
        st.error(f"❌ Failed to download checkpoint from Hugging Face Hub: {e}")
        return ""


# -----------------------------
# Caches / Loads
# -----------------------------
@st.cache_resource(show_spinner=False)
def load_led_model_and_tokenizer(model_name=DEFAULT_LED_MODEL, device=DEVICE):
    tokenizer = LEDTokenizerFast.from_pretrained(model_name)
    model = LEDForConditionalGeneration.from_pretrained(model_name).to(device)
    return tokenizer, model


def load_checkpoint_weights_into_led(checkpoint_path, led_model):
    if not checkpoint_path or not os.path.exists(checkpoint_path):
        st.warning(f"Checkpoint not found at: {checkpoint_path}")
        return {}
    ck = torch.load(checkpoint_path, map_location="cpu")
    loaded = {}
    # Try common keys first; stop at the first one that loads successfully
    for keyname in ("led_state", "led_state_dict", "led_model", "led"):
        if keyname in ck:
            try:
                led_model.load_state_dict(ck[keyname], strict=False)
                loaded["led"] = keyname
                st.info(f"Loaded LED weights from checkpoint key: '{keyname}'")
                break
            except Exception as e:
                st.warning(f"Failed to load LED weights from key '{keyname}': {e}")
    # Fallback: scan for a dict-like value that overlaps with the model state dict
    if "led" not in loaded:
        for k, v in ck.items():
            if isinstance(v, dict) and set(v.keys()) & set(led_model.state_dict().keys()):
                try:
                    led_model.load_state_dict(v, strict=False)
                    loaded["led"] = k
                    st.info(f"Loaded LED weights from checkpoint top-level key: '{k}'")
                    break
                except Exception:
                    pass
    if "led" not in loaded:
        st.warning("Could not find LED weights key in checkpoint. Using base HF LED.")
    return loaded
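# Hedged example: the loader above expects a torch checkpoint whose LED weights
# sit under a key such as "led_state". A fine-tuning script might have saved it
# roughly like this (illustrative only; the actual training code and any extra
# keys may differ):
#
#     torch.save(
#         {"led_state": led_model.state_dict(), "epoch": 50},
#         "checkpoint_epoch_50.pt",
#     )
#
# If a different key name was used, the fallback scan still picks up any
# top-level dict whose keys overlap the LED state dict.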
# -----------------------------
# Input building (original)
# -----------------------------
def build_condensed_natural_from_text(
    raw_text,
    max_chars=20000,
    facts=None,
    max_facts=8,
    max_chunks=12,
):
    text = raw_text.strip()
    if not text:
        return "[NO_INPUT_TEXT_PROVIDED]"
    if len(text) > max_chars:
        text = text[:max_chars] + "\n\n[TRUNCATED]"

    # Facts
    if facts:
        enumerated = "\n".join([f"{i+1}. {f}" for i, f in enumerate(facts[:max_facts])])
        facts_part = f"Relevant facts:\n{enumerated}\n"
    else:
        sentences = [s.strip() for s in text.replace("\n", " ").split(".") if s.strip()]
        top_facts = sentences[:max_facts]
        enumerated = "\n".join([f"{i+1}. {s}" for i, s in enumerate(top_facts)])
        facts_part = (
            f"Relevant facts:\n{enumerated}\n" if enumerated else "Relevant facts:\n\n"
        )

    # Chunks = paragraphs
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    if not paras:
        paras = [" ".join(text.split(".")[:5])]
    paras = paras[:max_chunks]
    para_lines = []
    for i, p in enumerate(paras):
        head = f"- Paragraph {i+1}: "
        content = p if len(p) < 1200 else (p[:1200] + " [TRUNCATED]")
        para_lines.append(head + content)
    chunks_part = "Important paragraphs:\n" + "\n".join(para_lines) + "\n"

    instruction = "\nPlease write a concise, professional summary in fluent English (3-5 sentences)."
    combined = facts_part + "\n" + chunks_part + "\n" + instruction
    return combined


def find_subsequence_indices(seq, sub):
    if len(sub) == 0 or len(seq) < len(sub):
        return []
    res = []
    Ls = len(sub)
    for i in range(len(seq) - Ls + 1):
        if seq[i : i + Ls] == sub:
            res.append(i)
    return res


def build_global_attention_mask_for_headers(tokenizer, input_ids_batch, header_texts, device):
    if isinstance(input_ids_batch, torch.Tensor):
        input_ids = input_ids_batch.cpu().tolist()
    else:
        input_ids = [list(map(int, row)) for row in input_ids_batch]
    B = len(input_ids)
    T = max(len(r) for r in input_ids)
    gmask = [[0] * T for _ in range(B)]
    header_token_seqs = []
    for h in header_texts:
        if not h:
            header_token_seqs.append([])
            continue
        enc = tokenizer(h, add_special_tokens=False, truncation=True, return_tensors=None)
        header_token_seqs.append(enc["input_ids"])
    for b, seq in enumerate(input_ids):
        L = len(seq)
        if L > 0:
            gmask[b][0] = 1
        for hseq in header_token_seqs:
            if not hseq:
                continue
            starts = find_subsequence_indices(seq, hseq)
            for s in starts:
                for offs in range(len(hseq)):
                    idx = s + offs
                    if idx < T:
                        gmask[b][idx] = 1
    return torch.tensor(gmask, dtype=torch.long, device=device)


# -----------------------------
# Source mapping helpers
# -----------------------------
SENT_SPLIT_REGEX = re.compile(
    r"(?<=[.!?])\s+(?=[A-Z(\[]|\d+\.|•|-)"
)


def split_sentences(text: str) -> List[str]:
    parts = [s.strip() for s in SENT_SPLIT_REGEX.split(text) if s and s.strip()]
    return [s for s in parts if len(s) > 1]


@st.cache_resource(show_spinner=False)
def load_legalbert(name: str = LEGALBERT_NAME):
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModel.from_pretrained(name)
    mdl.to(DEVICE)
    mdl.eval()
    return tok, mdl


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts


def embed_texts_legalbert(texts: List[str], batch_size: int = 16) -> np.ndarray:
    tok, mdl = load_legalbert()
    vecs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            enc = tok(
                batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(DEVICE)
            out = mdl(**enc).last_hidden_state
            mean = (
                mean_pool(out, enc["attention_mask"])
                .detach()
                .cpu()
                .numpy()
                .astype("float32")
            )
            vecs.append(mean)
    return np.vstack(vecs) if vecs else np.zeros((0, 768), dtype="float32")
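# Quick sanity check (hedged sketch, not executed by the app): mean-pooled
# LegalBERT embeddings of near-paraphrases should have a high cosine similarity.
# The two sentences below are made-up examples.
#
#     a, b = embed_texts_legalbert([
#         "The court dismissed the appeal.",
#         "The appeal was dismissed by the court.",
#     ])
#     cos = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
#     # cos should be close to 1.0; the same cosine score drives the
#     # SUPPORTED / UNSUPPORTED tagging further down.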
def build_faiss_index(sentences: List[str]):
    try:
        import faiss
    except Exception:
        st.error("FAISS is required. Install with `pip install faiss-cpu`.")
        st.stop()
    embs = embed_texts_legalbert(sentences)
    faiss.normalize_L2(embs)
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index, embs.shape[1]


def map_generated_to_sources(
    gen_sents: List[str],
    index,
    source_sents: List[str],
    k: int = TOP_K_SOURCES,
):
    try:
        import faiss
    except Exception:
        st.error("FAISS is required. Install with `pip install faiss-cpu`.")
        st.stop()
    if not gen_sents:
        return []
    embs = embed_texts_legalbert(gen_sents)
    faiss.normalize_L2(embs)
    D, I = index.search(embs, k)
    results = []
    for i in range(len(gen_sents)):
        triples = []
        for idx, sim in zip(I[i], D[i]):
            if 0 <= idx < len(source_sents):
                triples.append((int(idx), float(sim), source_sents[idx]))
        results.append(triples)
    return results
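# End-to-end mapping sketch (illustrative only; the Streamlit flow below does the
# same thing interactively). Vectors are L2-normalized before indexing, so the
# inner-product scores returned by FAISS are cosine similarities.
#
#     source_sents = split_sentences(document_text)
#     index, _dim = build_faiss_index(source_sents)
#     for hits in map_generated_to_sources(
#         split_sentences(summary_text), index, source_sents, k=3
#     ):
#         # each hit is (source sentence index, cosine similarity, sentence text)
#         if hits:
#             best_idx, best_sim, best_src = hits[0]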
# -----------------------------
# STREAMLIT UI
# -----------------------------
st.set_page_config(page_title="LDS - Validation-style Summarizer", layout="wide")
st.title("Legal Long Document Summarizer - with source mapping")

st.sidebar.header("Model & Checkpoint")
st.sidebar.write(f"**Base LED model**: `{DEFAULT_LED_MODEL}`")
st.sidebar.write(f"**Checkpoint (HF Hub)**: `{HF_REPO_ID}/{HF_CHECKPOINT_FILE}`")
st.sidebar.write("Device: " + DEVICE)

max_input_len = st.sidebar.number_input(
    "LED max input tokens",
    value=DEFAULT_MAX_INPUT_LEN,
    step=512,
)
beam = st.sidebar.number_input("num_beams (generate)", value=DEFAULT_BEAMS, step=1)
max_target_len = st.sidebar.number_input(
    "max_target_len",
    value=DEFAULT_MAX_TARGET_LEN,
    step=16,
)

st.sidebar.markdown("---")
st.sidebar.header("Input options")
use_naturalized = st.sidebar.checkbox(
    "Build naturalized condensed input", value=False
)
show_condensed = st.sidebar.checkbox("Show condensed input", value=True)

st.sidebar.markdown("---")
st.sidebar.header("Citations / Mapping")
sim_threshold = st.sidebar.slider(
    "Similarity threshold", 0.5, 0.99, SIM_DEFAULT, step=0.01
)
topk_sources = st.sidebar.slider(
    "Top-K sources per sentence", 1, 10, TOP_K_SOURCES
)

# Main input
st.subheader("Document input")
raw_text = st.text_area(
    "Paste your long document text here (or small text for testing).",
    height=360,
)
if not raw_text:
    st.info("Paste a document above to get started.")

# Controls
col1, col2 = st.columns([1, 3])
with col1:
    if st.button("Load LED model + checkpoint from HF Hub"):
        st.session_state["loaded"] = False
        st.session_state["loaded_led"] = False
        try:
            tokenizer, led_model = load_led_model_and_tokenizer(
                DEFAULT_LED_MODEL, device=DEVICE
            )
            st.session_state["tokenizer"] = tokenizer
            st.session_state["led_model"] = led_model
            st.success("✅ Loaded HF LED base model and tokenizer.")

            ckpt_path = get_checkpoint_path_from_hub()
            if ckpt_path:
                loaded = load_checkpoint_weights_into_led(ckpt_path, led_model)
                if loaded:
                    st.session_state["loaded_led"] = True
                    st.success("✅ Loaded fine-tuned checkpoint from HF Hub.")
            st.session_state["loaded"] = True
        except Exception as e:
            st.error(f"Failed to load LED model/tokenizer or checkpoint: {e}")

with col2:
    run_generate = st.button("Generate Summary")

# Generation step
if run_generate:
    if "led_model" not in st.session_state:
        st.error(
            "Model not loaded. Click 'Load LED model + checkpoint from HF Hub' first."
        )
    elif not raw_text or raw_text.strip() == "":
        st.error("Please paste some input text in the document input area.")
    else:
        tokenizer = st.session_state["tokenizer"]
        led_model = st.session_state["led_model"]

        # Build condensed input
        if use_naturalized:
            condensed = build_condensed_natural_from_text(raw_text, facts=None)
        else:
            condensed = raw_text.strip()

        # Tokenize
        enc = tokenizer(
            [condensed],
            truncation=True,
            padding="longest",
            max_length=int(max_input_len),
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(DEVICE)
        attention_mask = enc["attention_mask"].to(DEVICE)

        # Global attention mask over the section headers
        header_texts = ["Relevant facts:", "Important paragraphs:", "Please write"]
        global_attn = build_global_attention_mask_for_headers(
            tokenizer,
            input_ids,
            header_texts,
            device=DEVICE,
        )

        # Generate
        try:
            led_model.eval()
            with torch.no_grad():
                gen_ids = led_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    global_attention_mask=global_attn,
                    num_beams=int(beam),
                    max_length=int(max_target_len),
                    no_repeat_ngram_size=3,
                    length_penalty=1.2,
                    early_stopping=True,
                )
            preds = [
                tokenizer.decode(
                    g,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
                for g in gen_ids
            ]
            pred = preds[0] if preds else ""
        except Exception as e:
            st.error(f"Generation failed: {e}")
            pred = ""

        # Show outputs
        st.markdown("### Generated summary")
        st.write(pred)

        st.markdown("### Stats")
        st.write(
            {
                "input_token_count": int(input_ids.size(1)),
                "pred_token_count": len(tokenizer.encode(pred)),
            }
        )

        if show_condensed:
            st.markdown("### Condensed input used (truncated to 2000 chars)")
            st.code(
                textwrap.shorten(
                    condensed, width=2000, placeholder="... [TRUNCATED]"
                ),
                language="text",
            )

        # -----------------------------
        # Sentence-to-source mapping
        # -----------------------------
        if raw_text and pred:
            st.markdown("---")
            st.markdown("## 🔗 Sentence-to-Source Mapping")

            # Split sentences
            source_sents = split_sentences(raw_text)
            gen_sents = split_sentences(pred)

            if not source_sents:
                st.info("Could not split the source into sentences.")
            elif not gen_sents:
                st.info("Could not split the generated summary into sentences.")
            else:
                # Build FAISS index over source sentences
                index, dim = build_faiss_index(source_sents)
                mappings = map_generated_to_sources(
                    gen_sents, index, source_sents, k=int(topk_sources)
                )

                # Render per-sentence with tags
                for i, sent in enumerate(gen_sents, start=1):
                    hits = mappings[i - 1] if i - 1 < len(mappings) else []
                    strong = [
                        (idx, sim, s)
                        for (idx, sim, s) in hits
                        if sim >= float(sim_threshold)
                    ]
                    tag = (
                        "[EXTRACTIVE]"
                        if strong and strong[0][1] >= 0.995
                        else ("[SUPPORTED]" if strong else "[UNSUPPORTED]")
                    )
                    with st.expander(f"{i}. {sent} {tag}"):
                        if strong:
                            for rank, (idx, sim, src) in enumerate(strong, start=1):
                                st.markdown(
                                    f"**Source #{rank}** · sentence **{idx+1}** · sim **{sim:.3f}**"
                                )
                                st.write(src)
                                st.markdown("---")
                        else:
                            # Show the closest hit anyway to help debugging
                            if hits:
                                idx, sim, src = hits[0]
                                st.info(
                                    f"No hit above threshold {sim_threshold:.2f}. Closest:"
                                )
                                st.markdown(
                                    f"**Closest** · sentence **{idx+1}** · sim **{sim:.3f}**"
                                )
                                st.write(src)
                            else:
                                st.info("No close source sentence found.")
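# -----------------------------
# Requirements (hedged sketch)
# -----------------------------
# A minimal requirements.txt for the Space, inferred from the imports above
# (versions unpinned; pin them to match your Space runtime as needed):
#
#     streamlit
#     torch
#     numpy
#     transformers
#     huggingface_hub
#     faiss-cpu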