# app.py — Streamlit Legal LED Summarizer with Source Mapping (HF Hub-ready)
#
# Run locally:
#   streamlit run app.py
#
# On Hugging Face Spaces:
#   Put this file + requirements.txt in the Space repo.
#
# It will:
#   - Download your fine-tuned LED checkpoint from HF Hub
#   - Run summarization
#   - Map generated sentences back to source sentences via LegalBERT + FAISS
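#
# A minimal requirements.txt sketch for the Space, covering only the packages
# imported below (unpinned; exact versions/pins are an assumption, not the
# repo's actual file):
#
#     streamlit
#     torch
#     transformers
#     huggingface_hub
#     numpy
#     faiss-cpu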
import os
import re
import textwrap
from typing import List, Tuple

import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from transformers import (
    LEDTokenizerFast,
    LEDForConditionalGeneration,
    AutoTokenizer,
    AutoModel,
)

# Avoid OpenMP duplicate errors in some environments
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# -----------------------------
# CONFIG
# -----------------------------
DEFAULT_LED_MODEL = "allenai/led-base-16384"
DEFAULT_MAX_INPUT_LEN = 4096
DEFAULT_BEAMS = 5
DEFAULT_MAX_TARGET_LEN = 512

# Mapping defaults
LEGALBERT_NAME = "nlpaueb/legal-bert-base-uncased"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SIM_DEFAULT = 0.85  # similarity threshold to call a sentence SUPPORTED
TOP_K_SOURCES = 3  # how many source sentences to show per generated sentence
# -----------------------------
# HF Hub checkpoint config
# -----------------------------
# 🔁 Change these to your actual model repo + filename
HF_REPO_ID = "samraatd/legal-longdoc-summarization"
HF_CHECKPOINT_FILE = "checkpoint_epoch_50.pt"
def get_checkpoint_path_from_hub() -> str:
    """
    Download the fine-tuned LED checkpoint from Hugging Face Hub and
    return the local file path.
    """
    try:
        ckpt_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=HF_CHECKPOINT_FILE,
        )
        return ckpt_path
    except Exception as e:
        st.error(f"❌ Failed to download checkpoint from Hugging Face Hub: {e}")
        return ""
# -----------------------------
# Caches / Loads
# -----------------------------
@st.cache_resource
def load_led_model_and_tokenizer(model_name=DEFAULT_LED_MODEL, device=DEVICE):
    # Cached so Streamlit reruns don't reload the base LED weights from disk.
    tokenizer = LEDTokenizerFast.from_pretrained(model_name)
    model = LEDForConditionalGeneration.from_pretrained(model_name).to(device)
    return tokenizer, model
def load_checkpoint_weights_into_led(checkpoint_path, led_model):
    if not checkpoint_path or not os.path.exists(checkpoint_path):
        st.warning(f"Checkpoint not found at: {checkpoint_path}")
        return {}
    ck = torch.load(checkpoint_path, map_location="cpu")
    loaded = {}

    # Try common keys first; stop at the first one that loads cleanly
    for keyname in ("led_state", "led_state_dict", "led_model", "led"):
        if keyname in ck:
            try:
                led_model.load_state_dict(ck[keyname], strict=False)
                loaded["led"] = keyname
                st.info(f"Loaded LED weights from checkpoint key: '{keyname}'")
                break
            except Exception as e:
                st.warning(f"Failed to load LED weights from key '{keyname}': {e}")

    # Fallback: scan for a dict-like value that overlaps with the model state dict
    if "led" not in loaded:
        for k, v in ck.items():
            if isinstance(v, dict) and set(v.keys()) & set(led_model.state_dict().keys()):
                try:
                    led_model.load_state_dict(v, strict=False)
                    loaded["led"] = k
                    st.info(f"Loaded LED weights from checkpoint top-level key: '{k}'")
                    break
                except Exception:
                    pass

    if "led" not in loaded:
        st.warning("Could not find LED weights key in checkpoint. Using base HF LED.")
    return loaded
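

# Illustrative checkpoint layouts this loader can handle (the exact key names
# in the real checkpoint are an assumption; only the four probed above are
# matched directly, anything else falls through to the state-dict overlap scan):
#
#     {"led_state_dict": {<LED weights>}, "optimizer": {...}, "epoch": 50}
#     {"model": {<keys overlapping led_model.state_dict()>}}   # caught by the fallback scan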
# -----------------------------
# Input building (original)
# -----------------------------
def build_condensed_natural_from_text(
    raw_text,
    max_chars=20000,
    facts=None,
    max_facts=8,
    max_chunks=12,
):
    text = raw_text.strip()
    if not text:
        return "[NO_INPUT_TEXT_PROVIDED]"
    if len(text) > max_chars:
        text = text[:max_chars] + "\n\n[TRUNCATED]"

    # Facts
    if facts:
        enumerated = "\n".join([f"{i+1}. {f}" for i, f in enumerate(facts[:max_facts])])
        facts_part = f"Relevant facts:\n{enumerated}\n"
    else:
        sentences = [s.strip() for s in text.replace("\n", " ").split(".") if s.strip()]
        top_facts = sentences[:max_facts]
        enumerated = "\n".join([f"{i+1}. {s}" for i, s in enumerate(top_facts)])
        facts_part = (
            f"Relevant facts:\n{enumerated}\n" if enumerated else "Relevant facts:\n\n"
        )

    # Chunks = paragraphs
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    if not paras:
        paras = [" ".join(text.split(".")[:5])]
    paras = paras[:max_chunks]
    para_lines = []
    for i, p in enumerate(paras):
        head = f"- Paragraph {i+1}: "
        content = p if len(p) < 1200 else (p[:1200] + " [TRUNCATED]")
        para_lines.append(head + content)
    chunks_part = "Important paragraphs:\n" + "\n".join(para_lines) + "\n"

    instruction = "\nPlease write a concise, professional summary in fluent English (3-5 sentences)."
    combined = facts_part + "\n" + chunks_part + "\n" + instruction
    return combined
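

# Sketch of the condensed input this builder produces (the fact and paragraph
# text is illustrative, only the headers and layout come from the code above):
#
#     Relevant facts:
#     1. The appellant filed the claim in 2019
#     2. ...
#
#     Important paragraphs:
#     - Paragraph 1: <first paragraph, cut to 1200 chars> [TRUNCATED]
#     - Paragraph 2: ...
#
#     Please write a concise, professional summary in fluent English (3-5 sentences).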
def find_subsequence_indices(seq, sub):
    if len(sub) == 0 or len(seq) < len(sub):
        return []
    res = []
    Ls = len(sub)
    for i in range(len(seq) - Ls + 1):
        if seq[i : i + Ls] == sub:
            res.append(i)
    return res
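

# Example (values are illustrative token ids):
#     find_subsequence_indices([5, 6, 7, 6, 7], [6, 7]) -> [1, 3]
# i.e. every start offset at which the token sub-sequence occurs.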
def build_global_attention_mask_for_headers(tokenizer, input_ids_batch, header_texts, device):
    if isinstance(input_ids_batch, torch.Tensor):
        input_ids = input_ids_batch.cpu().tolist()
    else:
        input_ids = [list(map(int, row)) for row in input_ids_batch]
    B = len(input_ids)
    T = max(len(r) for r in input_ids)
    gmask = [[0] * T for _ in range(B)]

    # Tokenize each header once so we can search for its token sub-sequence
    header_token_seqs = []
    for h in header_texts:
        if not h:
            header_token_seqs.append([])
            continue
        enc = tokenizer(h, add_special_tokens=False, truncation=True, return_tensors=None)
        header_token_seqs.append(enc["input_ids"])

    for b, seq in enumerate(input_ids):
        L = len(seq)
        if L > 0:
            gmask[b][0] = 1
        for hseq in header_token_seqs:
            if not hseq:
                continue
            starts = find_subsequence_indices(seq, hseq)
            for s in starts:
                for offs in range(len(hseq)):
                    idx = s + offs
                    if idx < T:
                        gmask[b][idx] = 1
    return torch.tensor(gmask, dtype=torch.long, device=device)
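

# Resulting mask, roughly: position 0 always gets global attention (the usual
# LED convention for the first token), plus every token span matching one of
# the header strings, e.g. the tokens of "Relevant facts:" wherever they occur.
# All other positions stay 0 (local attention only).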
# -----------------------------
# Source mapping helpers
# -----------------------------
SENT_SPLIT_REGEX = re.compile(
    r"(?<=[.!?])\s+(?=[A-Z(\[]|\d+\.|\•|\-)"
)


def split_sentences(text: str) -> List[str]:
    parts = [s.strip() for s in SENT_SPLIT_REGEX.split(text) if s and s.strip()]
    return [s for s in parts if len(s) > 1]
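

# Example:
#     split_sentences("The claim was dismissed. The appellant appealed.")
#     -> ["The claim was dismissed.", "The appellant appealed."]
# The regex only splits after ./!/? when the next chunk starts with an
# uppercase letter, bracket, bullet, dash, or a numbered item like "2.".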
@st.cache_resource
def load_legalbert(name: str = LEGALBERT_NAME):
    # Cached: embed_texts_legalbert calls this on every embedding pass, so
    # avoid reloading LegalBERT each time.
    tok = AutoTokenizer.from_pretrained(name)
    mdl = AutoModel.from_pretrained(name)
    mdl.to(DEVICE)
    mdl.eval()
    return tok, mdl


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Masked mean over the token dimension: (B, T, H) + (B, T) -> (B, H)
    mask = attention_mask.unsqueeze(-1)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts
def embed_texts_legalbert(texts: List[str], batch_size: int = 16) -> np.ndarray:
    tok, mdl = load_legalbert()
    vecs = []
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            enc = tok(
                batch,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(DEVICE)
            out = mdl(**enc).last_hidden_state
            mean = (
                mean_pool(out, enc["attention_mask"])
                .detach()
                .cpu()
                .numpy()
                .astype("float32")
            )
            vecs.append(mean)
    return np.vstack(vecs) if vecs else np.zeros((0, 768), dtype="float32")
def build_faiss_index(sentences: List[str]):
    try:
        import faiss
    except Exception:
        st.error("FAISS is required. Install with `pip install faiss-cpu`.")
        st.stop()
    embs = embed_texts_legalbert(sentences)
    faiss.normalize_L2(embs)
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    return index, embs.shape[1]
def map_generated_to_sources(
    gen_sents: List[str],
    index,
    source_sents: List[str],
    k: int = TOP_K_SOURCES,
):
    try:
        import faiss
    except Exception:
        st.error("FAISS is required. Install with `pip install faiss-cpu`.")
        st.stop()
    if not gen_sents:
        return []
    embs = embed_texts_legalbert(gen_sents)
    faiss.normalize_L2(embs)
    D, I = index.search(embs, k)
    results = []
    for i in range(len(gen_sents)):
        triples = []
        for idx, sim in zip(I[i], D[i]):
            if 0 <= idx < len(source_sents):
                triples.append((int(idx), float(sim), source_sents[idx]))
        results.append(triples)
    return results
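

# Example return value for two generated sentences (scores and text are
# illustrative): one inner list per generated sentence, each holding up to k
# (source_index, cosine_similarity, source_sentence) triples, best match first,
# since the index stores L2-normalized vectors and ranks by inner product:
#
#     [
#         [(12, 0.91, "The tribunal held that ..."), (3, 0.74, "...")],
#         [(7, 0.58, "...")],
#     ]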
# -----------------------------
# STREAMLIT UI
# -----------------------------
st.set_page_config(page_title="LDS - Validation-style Summarizer", layout="wide")
st.title("Legal Long Document Summarizer — with source mapping")

st.sidebar.header("Model & Checkpoint")
st.sidebar.write(f"**Base LED model**: `{DEFAULT_LED_MODEL}`")
st.sidebar.write(f"**Checkpoint (HF Hub)**: `{HF_REPO_ID}/{HF_CHECKPOINT_FILE}`")
st.sidebar.write("Device: " + DEVICE)

max_input_len = st.sidebar.number_input(
    "LED max input tokens",
    value=DEFAULT_MAX_INPUT_LEN,
    step=512,
)
beam = st.sidebar.number_input("num_beams (generate)", value=DEFAULT_BEAMS, step=1)
max_target_len = st.sidebar.number_input(
    "max_target_len",
    value=DEFAULT_MAX_TARGET_LEN,
    step=16,
)

st.sidebar.markdown("---")
st.sidebar.header("Input options")
use_naturalized = st.sidebar.checkbox(
    "Build naturalized condensed input", value=False
)
show_condensed = st.sidebar.checkbox("Show condensed input", value=True)

st.sidebar.markdown("---")
st.sidebar.header("Citations / Mapping")
sim_threshold = st.sidebar.slider(
    "Similarity threshold", 0.5, 0.99, SIM_DEFAULT, step=0.01
)
topk_sources = st.sidebar.slider(
    "Top-K sources per sentence", 1, 10, TOP_K_SOURCES
)
# Main input
st.subheader("Document input")
raw_text = st.text_area(
    "Paste your long document text here (or small text for testing).",
    height=360,
)
if not raw_text:
    st.info("Paste a document above to get started.")

# Controls
col1, col2 = st.columns([1, 3])
with col1:
    if st.button("Load LED model + checkpoint from HF Hub"):
        st.session_state["loaded"] = False
        st.session_state["loaded_led"] = False
        try:
            tokenizer, led_model = load_led_model_and_tokenizer(
                DEFAULT_LED_MODEL, device=DEVICE
            )
            st.session_state["tokenizer"] = tokenizer
            st.session_state["led_model"] = led_model
            st.success("✅ Loaded HF LED base model and tokenizer.")
            ckpt_path = get_checkpoint_path_from_hub()
            if ckpt_path:
                loaded = load_checkpoint_weights_into_led(ckpt_path, led_model)
                if loaded:
                    st.session_state["loaded_led"] = True
                    st.success("✅ Loaded fine-tuned checkpoint from HF Hub.")
            st.session_state["loaded"] = True
        except Exception as e:
            st.error(f"Failed to load LED model/tokenizer or checkpoint: {e}")

with col2:
    run_generate = st.button("Generate Summary")
# Generation step
if run_generate:
    if "led_model" not in st.session_state:
        st.error(
            "Model not loaded. Click 'Load LED model + checkpoint from HF Hub' first."
        )
    elif not raw_text or raw_text.strip() == "":
        st.error("Please paste some input text in the document input area.")
    else:
        tokenizer = st.session_state["tokenizer"]
        led_model = st.session_state["led_model"]

        # Build condensed input
        if use_naturalized:
            condensed = build_condensed_natural_from_text(raw_text, facts=None)
        else:
            condensed = raw_text.strip()

        # Tokenize
        enc = tokenizer(
            [condensed],
            truncation=True,
            padding="longest",
            max_length=int(max_input_len),
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(DEVICE)
        attention_mask = enc["attention_mask"].to(DEVICE)

        # Global attention mask
        header_texts = ["Relevant facts:", "Important paragraphs:", "Please write"]
        global_attn = build_global_attention_mask_for_headers(
            tokenizer,
            input_ids,
            header_texts,
            device=DEVICE,
        )

        # Generate
        try:
            led_model.eval()
            with torch.no_grad():
                gen_ids = led_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    global_attention_mask=global_attn,
                    num_beams=int(beam),
                    max_length=int(max_target_len),
                    no_repeat_ngram_size=3,
                    length_penalty=1.2,
                    early_stopping=True,
                )
            preds = [
                tokenizer.decode(
                    g,
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                )
                for g in gen_ids
            ]
            pred = preds[0] if preds else ""
        except Exception as e:
            st.error(f"Generation failed: {e}")
            pred = ""

        # Show outputs
        st.markdown("### Generated summary")
        st.write(pred)
        st.markdown("### Stats")
        st.write(
            {
                "input_token_count": int(input_ids.size(1)),
                "pred_token_count": len(tokenizer.encode(pred)),
            }
        )
        if show_condensed:
            st.markdown("### Condensed input used (truncated to 2000 chars)")
            st.code(
                textwrap.shorten(
                    condensed, width=2000, placeholder="... [TRUNCATED]"
                ),
                language="text",
            )
        # -----------------------------
        # Sentence-to-source mapping
        # -----------------------------
        if raw_text and pred:
            st.markdown("---")
            st.markdown("## 🔗 Sentence-to-Source Mapping")

            # Split sentences
            source_sents = split_sentences(raw_text)
            gen_sents = split_sentences(pred)
            if not source_sents:
                st.info("Could not split the source into sentences.")
            elif not gen_sents:
                st.info("Could not split the generated summary into sentences.")
            else:
                # Build FAISS index over source sentences
                index, dim = build_faiss_index(source_sents)
                mappings = map_generated_to_sources(
                    gen_sents, index, source_sents, k=int(topk_sources)
                )

                # Render per-sentence with tags
                for i, sent in enumerate(gen_sents, start=1):
                    hits = mappings[i - 1] if i - 1 < len(mappings) else []
                    strong = [
                        (idx, sim, s)
                        for (idx, sim, s) in hits
                        if sim >= float(sim_threshold)
                    ]
                    tag = (
                        "[EXTRACTIVE]"
                        if strong and strong[0][1] >= 0.995
                        else ("[SUPPORTED]" if strong else "[UNSUPPORTED]")
                    )
                    with st.expander(f"{i}. {sent} {tag}"):
                        if strong:
                            for rank, (idx, sim, src) in enumerate(strong, start=1):
                                st.markdown(
                                    f"**Source #{rank}** · line **{idx + 1}** · sim **{sim:.3f}**"
                                )
                                st.write(src)
                                st.markdown("---")
                        else:
                            # Show top-1 anyway to help debugging
                            if hits:
                                idx, sim, src = hits[0]
                                st.info(
                                    f"No hit above threshold {sim_threshold:.2f}. Closest:"
                                )
                                st.markdown(
                                    f"**Closest** · line **{idx + 1}** · sim **{sim:.3f}**"
                                )
                                st.write(src)
                            else:
                                st.info("No close source sentence found.")