import streamlit as st
import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2TokenizerFast, DebertaV2Config, AutoTokenizer
from pathlib import Path
import numpy as np
import logging
from dataclasses import dataclass
from typing import Optional, Dict, List, Tuple
from tqdm import tqdm
from skimage.filters import threshold_otsu

# ----------------------------------
# Logging
# ----------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------------
# Config / Model
# ----------------------------------
@dataclass
class TrainingConfig:
    """Training configuration for link token classification"""
    model_name: str = "microsoft/deberta-v3-large"
    num_labels: int = 2  # 0: not link, 1: link token
    # Inference windowing
    max_length: int = 512
    doc_stride: int = 128  # match _prep.py for consistent windowing
    # Train-only placeholders
    train_file: str = ""
    val_file: str = ""
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    num_epochs: int = 1
    learning_rate: float = 1e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    label_smoothing: float = 0.0
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    num_workers: int = 0
    bf16: bool = False
    seed: int = 42
    logging_steps: int = 1
    eval_steps: int = 100
    save_steps: int = 100
    output_dir: str = "./deberta_link_output"  # model is loaded from here
    wandb_project: str = ""
    wandb_name: str = ""
    patience: int = 2
    min_delta: float = 0.0001


class DeBERTaForTokenClassification(nn.Module):
    """DeBERTa model for token classification"""

    def __init__(self, model_name: str, num_labels: int, dropout_rate: float = 0.1):
        super().__init__()
        self.config = DebertaV2Config.from_pretrained(model_name)
        self.deberta = DebertaV2Model.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)
        nn.init.xavier_uniform_(self.classifier.weight)
        nn.init.zeros_(self.classifier.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None  # accepted for API parity; loss is not computed at inference
    ) -> Dict[str, torch.Tensor]:
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)
        return {'loss': None, 'logits': logits}
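# Shape sketch (illustrative, not executed): for a batch of B windows of
# length L, the tensors flowing through forward() are
#   outputs.last_hidden_state -> (B, L, hidden_size)
#   logits                    -> (B, L, num_labels)
# i.e. one [not-link, link] score pair per token, before softmax.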
# ----------------------------------
# Load model/tokenizer (robust)
# ----------------------------------
@st.cache_resource
def load_model():
    """Loads pre-trained model and tokenizer.
    Handles raw state_dict and wrapped checkpoints."""
    config = TrainingConfig()
    final_dir = Path(config.output_dir) / "final_model"
    model_path = final_dir / "pytorch_model.bin"
    if not model_path.exists():
        st.error(f"Model checkpoint not found at {model_path}.")
        st.stop()
    logger.info(f"Loading model from {model_path}...")
    model = DeBERTaForTokenClassification(config.model_name, config.num_labels)

    # Load checkpoint robustly; older torch versions lack the weights_only kwarg
    try:
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
    except TypeError:
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

    # Determine state_dict
    state_dict = None
    if isinstance(checkpoint, dict):
        # Case A: raw state_dict (keys -> tensors)
        if checkpoint and all(isinstance(v, torch.Tensor) for v in checkpoint.values()):
            state_dict = checkpoint
            logger.info("Detected raw state_dict checkpoint.")
        # Case B: wrapped dicts
        elif 'model_state_dict' in checkpoint and isinstance(checkpoint['model_state_dict'], dict):
            state_dict = checkpoint['model_state_dict']
            logger.info("Detected 'model_state_dict' in checkpoint.")
        elif 'state_dict' in checkpoint and isinstance(checkpoint['state_dict'], dict):
            state_dict = checkpoint['state_dict']
            logger.info("Detected 'state_dict' in checkpoint.")
        else:
            raise KeyError(f"Unrecognized checkpoint format keys: {list(checkpoint.keys())}")
    else:
        raise TypeError(f"Unexpected checkpoint type: {type(checkpoint)}")

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning(f"Missing keys: {missing}")
    if unexpected:
        logger.warning(f"Unexpected keys: {unexpected}")

    model.to(config.device)
    model.eval()

    logger.info(f"Loading tokenizer {config.model_name}...")
    tokenizer = DebertaV2TokenizerFast.from_pretrained(config.model_name)
    logger.info("Tokenizer loaded.")
    return model, tokenizer, config.device, config.max_length, config.doc_stride


# set_page_config must be the first Streamlit call in the script; load_model()
# may emit st.error/st.stop, so configure the page before invoking it.
st.set_page_config(layout="wide", page_title="LinkBERT by DEJAN AI")

model, tokenizer, device, MAX_LENGTH, DOC_STRIDE = load_model()

# ----------------------------------
# Inference helpers
# ----------------------------------
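# Windowing sketch (illustrative arithmetic using the TrainingConfig defaults,
# assuming DeBERTa-v3's two special tokens per sequence):
#   cap  = max_length - specials = 512 - 2 = 510 content tokens per window
#   step = cap - doc_stride      = 510 - 128 = 382 tokens between window starts
# so consecutive windows overlap by doc_stride = 128 tokens; overlapping
# characters later keep the maximum link probability seen across windows.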
def windowize_inference(
    plain_text: str,
    tokenizer: AutoTokenizer,
    max_length: int,
    doc_stride: int
) -> List[Dict]:
    """Slice long text into overlapping windows for inference."""
    specials = tokenizer.num_special_tokens_to_add(pair=False)
    cap = max_length - specials
    if cap <= 0:
        raise ValueError(f"max_length too small; specials={specials}")

    full_encoding = tokenizer(
        plain_text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        return_attention_mask=False,
        return_token_type_ids=False,
        truncation=False,
    )
    input_ids_no_special = full_encoding["input_ids"]
    offsets_no_special = full_encoding["offset_mapping"]
    # Take word_ids from the same no-specials encoding so indices stay aligned
    # with input_ids_no_special (an encoding with special tokens would shift
    # every slice by the leading CLS token).
    full_word_ids = full_encoding.word_ids(batch_index=0)

    windows_data = []
    step = max(cap - doc_stride, 1)
    start_token_idx = 0
    total_tokens_no_special = len(input_ids_no_special)

    while start_token_idx < total_tokens_no_special:
        end_token_idx = min(start_token_idx + cap, total_tokens_no_special)
        ids_slice_no_special = input_ids_no_special[start_token_idx:end_token_idx]
        offsets_slice_no_special = offsets_no_special[start_token_idx:end_token_idx]
        word_ids_slice = full_word_ids[start_token_idx:end_token_idx]

        input_ids_with_special = tokenizer.build_inputs_with_special_tokens(ids_slice_no_special)
        attention_mask_with_special = [1] * len(input_ids_with_special)
        padding_length = max_length - len(input_ids_with_special)
        if padding_length > 0:
            input_ids_with_special.extend([tokenizer.pad_token_id] * padding_length)
            attention_mask_with_special.extend([0] * padding_length)

        # Mirror the special tokens and padding in the offset/word_id views
        window_offset_mapping = offsets_slice_no_special[:]
        window_word_ids = word_ids_slice[:]
        if tokenizer.cls_token_id is not None:
            window_offset_mapping.insert(0, (0, 0))
            window_word_ids.insert(0, None)
        if tokenizer.sep_token_id is not None and len(window_offset_mapping) < max_length:
            window_offset_mapping.append((0, 0))
            window_word_ids.append(None)
        while len(window_offset_mapping) < max_length:
            window_offset_mapping.append((0, 0))
            window_word_ids.append(None)

        windows_data.append({
            "input_ids": torch.tensor(input_ids_with_special, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask_with_special, dtype=torch.long),
            "word_ids": window_word_ids,
            "offset_mapping": window_offset_mapping,
        })

        if end_token_idx == total_tokens_no_special:
            break
        start_token_idx += step

    return windows_data


def classify_text(
    text: str,
    otsu_mode: str,
    prediction_threshold_override: Optional[float] = None
) -> Tuple[str, Optional[str], Optional[float]]:
    """Classify link tokens with windowing. Returns (html, warning, threshold%)."""
    if not text.strip():
        return "", None, None

    windows = windowize_inference(text, tokenizer, MAX_LENGTH, DOC_STRIDE)
    if not windows:
        return "", "Could not generate any windows for processing.", None

    char_link_probabilities = np.zeros(len(text), dtype=np.float32)
    char_covered = np.zeros(len(text), dtype=bool)
    all_content_token_probs = []

    with torch.no_grad():
        for window in tqdm(windows, desc="Processing windows"):
            inputs = {
                'input_ids': window['input_ids'].unsqueeze(0).to(device),
                'attention_mask': window['attention_mask'].unsqueeze(0).to(device)
            }
            outputs = model(**inputs)
            logits = outputs['logits'].squeeze(0)
            probabilities = torch.softmax(logits, dim=-1)
            link_probs_for_window_tokens = probabilities[:, 1].cpu().numpy()
            for i, (offset_start, offset_end) in enumerate(window['offset_mapping']):
                if window['word_ids'][i] is not None and offset_start < offset_end:
                    char_link_probabilities[offset_start:offset_end] = np.maximum(
                        char_link_probabilities[offset_start:offset_end],
                        link_probs_for_window_tokens[i]
                    )
                    char_covered[offset_start:offset_end] = True
                    all_content_token_probs.append(link_probs_for_window_tokens[i])

    # Threshold selection (Otsu or manual)
    determined_threshold_float = None
    determined_threshold_for_display = None  # 0-100%
    if prediction_threshold_override is not None:
        determined_threshold_float = prediction_threshold_override / 100.0
        determined_threshold_for_display = prediction_threshold_override
    else:
        if len(all_content_token_probs) > 1:
            try:
                otsu_base_threshold = threshold_otsu(np.array(all_content_token_probs))
                conservative_delta = 0.1  # raise the cut: stricter, fewer links
                generous_delta = 0.1      # lower the cut: more lenient, more links
                if otsu_mode == 'conservative':
                    determined_threshold_float = otsu_base_threshold + conservative_delta
                elif otsu_mode == 'generous':
                    determined_threshold_float = otsu_base_threshold - generous_delta
                else:
                    determined_threshold_float = otsu_base_threshold
                determined_threshold_float = max(0.0, min(1.0, determined_threshold_float))
                determined_threshold_for_display = determined_threshold_float * 100
            except ValueError:
                logger.warning("Otsu failed; defaulting to 0.5.")
                determined_threshold_float = 0.5
                determined_threshold_for_display = 50.0
        else:
            logger.warning("Insufficient tokens for Otsu; defaulting to 0.5.")
            determined_threshold_float = 0.5
            determined_threshold_for_display = 50.0

    final_threshold = determined_threshold_float
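    # Threshold note (illustrative): threshold_otsu above treats the pooled
    # token probabilities as a histogram and picks the cut that minimises
    # intra-class variance, e.g. values clustered near 0.1 and near 0.9 yield
    # a threshold between the clusters; conservative/generous shift it by +/-0.1.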
    # Word-level aggregation
    full_text_encoding = tokenizer(text, return_offsets_mapping=True, truncation=False, padding=False)
    full_word_ids = full_text_encoding.word_ids(batch_index=0)
    full_offset_mapping = full_text_encoding['offset_mapping']

    word_prob_map: Dict[int, List[float]] = {}
    word_char_spans: Dict[int, List[int]] = {}
    for i, word_id in enumerate(full_word_ids):
        if word_id is not None:
            start_char, end_char = full_offset_mapping[i]
            if start_char < end_char and np.any(char_covered[start_char:end_char]):
                if word_id not in word_prob_map:
                    word_prob_map[word_id] = []
                    word_char_spans[word_id] = [start_char, end_char]
                else:
                    word_char_spans[word_id][0] = min(word_char_spans[word_id][0], start_char)
                    word_char_spans[word_id][1] = max(word_char_spans[word_id][1], end_char)
                token_span_probs = char_link_probabilities[start_char:end_char]
                word_prob_map[word_id].append(np.max(token_span_probs) if token_span_probs.size > 0 else 0.0)
            elif word_id not in word_prob_map:
                # Word never covered by any window: keep it with zero probability
                word_prob_map[word_id] = [0.0]
                word_char_spans[word_id] = list(full_offset_mapping[i])

    words_to_highlight_status: Dict[int, bool] = {}
    for word_id, probs in word_prob_map.items():
        max_word_prob = np.max(probs) if probs else 0.0
        words_to_highlight_status[word_id] = (max_word_prob >= final_threshold)

    # Reconstruct HTML with highlights
    html_output_parts: List[str] = []
    current_char_idx = 0
    sorted_word_ids = sorted(word_char_spans.keys(), key=lambda k: word_char_spans[k][0])
    for word_id in sorted_word_ids:
        start_char, end_char = word_char_spans[word_id]
        if start_char > current_char_idx:
            html_output_parts.append(text[current_char_idx:start_char])
        word_text = text[start_char:end_char]
        if words_to_highlight_status.get(word_id, False):
            # <mark> renders as a highlight via st.markdown(..., unsafe_allow_html=True)
            html_output_parts.append("<mark>" + word_text + "</mark>")
        else:
            html_output_parts.append(word_text)
        current_char_idx = end_char
    if current_char_idx < len(text):
        html_output_parts.append(text[current_char_idx:])

    return "".join(html_output_parts), None, determined_threshold_for_display


# ----------------------------------
# Streamlit UI
# ----------------------------------
st.title("LinkBERT")

user_input = st.text_area(
    "Paste your text here:",
    "DEJAN AI is the world's leading AI SEO agency.",
    height=200
)

with st.expander('Settings'):
    auto_threshold_enabled = st.checkbox(
        "Automagic",
        value=True,
        help="Uncheck to set manual threshold value for link prediction."
    )
    otsu_mode_options = ['Conservative', 'Standard', 'Generous']
    selected_otsu_mode = 'Standard'
    if auto_threshold_enabled:
        selected_otsu_mode = st.radio(
            "Generosity:",
            otsu_mode_options,
            index=1,
            help="Generous suggests more links; conservative suggests fewer."
        )
    prediction_threshold_manual = 50.0
    if not auto_threshold_enabled:
        prediction_threshold_manual = st.slider(
            "Manual Link Probability Threshold (%)",
            min_value=0, max_value=100, value=50, step=1,
            help="Minimum probability to classify a token as a link when Automagic is off."
        )

if st.button("Classify Text"):
    if not user_input.strip():
        st.warning("Please enter some text to classify.")
    else:
        threshold_to_pass = None if auto_threshold_enabled else prediction_threshold_manual
        highlighted_html, warning_message, determined_threshold_for_display = classify_text(
            user_input, selected_otsu_mode.lower(), threshold_to_pass
        )
        if warning_message:
            st.warning(warning_message)
        if determined_threshold_for_display is not None and auto_threshold_enabled:
            st.info(f"Auto threshold: {determined_threshold_for_display:.1f}% ({selected_otsu_mode})")
        st.markdown(highlighted_html, unsafe_allow_html=True)
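# Usage sketch (assuming this file is saved as app.py):
#   streamlit run app.py
# The app expects a fine-tuned checkpoint at
# ./deberta_link_output/final_model/pytorch_model.bin; without it, load_model()
# stops the app with an error.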