# filepath: src/model/inference.py
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel, PeftConfig
import torch
import torch.nn.functional as F
from Idiom_lexicon import KNOWN_IDIOMS
import fitz  # PyMuPDF
import pdfplumber
import tempfile
from PIL import Image
import pytesseract
import nltk
import spacy
import json
import re
from pathlib import Path
from fastapi import HTTPException
from langdetect import detect
from nltk.tokenize import sent_tokenize

nltk.download('punkt_tab', quiet=True)

# Map ISO language codes to the language names expected by NLTK's sentence tokenizer.
LANG_MAP = {
    'en': 'english',
    'es': 'spanish',
    # add more if needed
}


def split_text_by_language(text, language: str):
    # Map input language (e.g., 'en', 'es') to NLTK language codes
    nltk_lang = LANG_MAP.get(language.lower(), 'english')
    sentences = sent_tokenize(text, language=nltk_lang)
    return sentences


def load_model(checkpoint_path):
    config = PeftConfig.from_pretrained(checkpoint_path)
    base_model = AutoModelForTokenClassification.from_pretrained(
        config.base_model_name_or_path,
        num_labels=3  # O, B-IDIOM, I-IDIOM
    )
    model = PeftModel.from_pretrained(base_model, checkpoint_path)
    model.eval()  # inference only
    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    return model, tokenizer


def normalize_text(text):
    # Join hyphenated words split across lines
    text = re.sub(r'-\s*\n\s*', '', text)
    # Replace newlines with spaces
    text = re.sub(r'\n+', ' ', text)
    # Collapse multiple spaces into one
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def filter_idioms(candidate_idioms, known_idioms, min_len=2):
    # Keep candidates that are in the known-idiom lexicon or span at least min_len words.
    filtered = []
    for idiom in candidate_idioms:
        norm = idiom.lower().strip()
        if norm in known_idioms or len(norm.split()) >= min_len:
            filtered.append(idiom)
    return filtered


class IdiomMatcher:
    """Lemma-based lookup of known idioms in a sentence, per language."""

    def __init__(self, idiom_files: dict[str, str]):
        self.models = {
            "en": spacy.load("en_core_web_sm"),
            "es": spacy.load("es_core_news_sm"),
        }
        self.idioms_by_lang = {lang: [] for lang in idiom_files}
        self._load_idioms(idiom_files)

    def _lemmatize(self, text: str, lang: str) -> str:
        doc = self.models[lang](text)
        return " ".join(token.lemma_ for token in doc)

    def _load_idioms(self, idiom_files: dict[str, str]):
        # Each file is expected to be JSON Lines with at least an "idiom" field per entry.
        for lang, file_path in idiom_files.items():
            path = Path(file_path)
            if not path.exists():
                raise FileNotFoundError(f"Idiom file not found for {lang}: {file_path}")
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    entry = json.loads(line)
                    idiom_text = entry.get("idiom", "").strip()
                    if not idiom_text:
                        continue
                    entry["lemmatized"] = self._lemmatize(idiom_text, lang)
                    self.idioms_by_lang[lang].append(entry)

    def match(self, sentence: str, lang: str):
        if lang not in self.models:
            raise ValueError(f"Unsupported language: {lang}")
        sent_lemma = self._lemmatize(sentence, lang)
        return [
            idiom for idiom in self.idioms_by_lang[lang]
            if idiom["lemmatized"] in sent_lemma
        ]


def predict_idiom(text, model, tokenizer, device, conf_threshold=0.9):
    words = text.split()
    if not words:
        print("[⚠️] Empty input text")
        return []

    inputs = tokenizer(
        words,
        is_split_into_words=True,
        truncation=True,
        padding=True,
        max_length=128,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)
        max_probs, predictions = torch.max(probs, dim=-1)

    max_probs = max_probs.cpu().numpy()[0]
    predictions = predictions.cpu().numpy()[0]
    word_ids = inputs.word_ids(batch_index=0)

    idioms = []
    current_idiom_start = -1
    current_idiom_end = -1

    # Walk token-level predictions, mapping them back to word indices and
    # collecting contiguous B-IDIOM / I-IDIOM spans.
    for pred_label, conf, word_idx in zip(predictions, max_probs, word_ids):
        if word_idx is None:
            # Special token ([CLS]/[SEP]/padding): close any open span.
            if current_idiom_start != -1:
                idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
                current_idiom_start = -1
                current_idiom_end = -1
            continue

        # Treat low-confidence predictions as O.
        if conf < conf_threshold:
            pred_label = 0

        if pred_label == 1:  # B-IDIOM
            if current_idiom_start != -1:
                idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
            current_idiom_start = word_idx
            current_idiom_end = word_idx
        elif pred_label == 2:  # I-IDIOM
            # Only extend if the token continues the current word or the next one.
            if current_idiom_start != -1 and (word_idx == current_idiom_end or word_idx == current_idiom_end + 1):
                current_idiom_end = word_idx
            else:
                if current_idiom_start != -1:
                    idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
                current_idiom_start = -1
                current_idiom_end = -1
        else:  # O
            if current_idiom_start != -1:
                idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
                current_idiom_start = -1
                current_idiom_end = -1

    # Flush a span that runs to the end of the sequence.
    if current_idiom_start != -1:
        idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))

    idioms = filter_idioms(idioms, known_idioms=KNOWN_IDIOMS)
    return idioms


def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(pdf_bytes)
        tmp_path = tmp.name

    try:
        doc = fitz.open(tmp_path)
        text = ""
        for i, page in enumerate(doc):
            page_text = page.get_text()
            print(f"[DEBUG] Page {i+1} extracted text (first 100 chars): {repr(page_text[:100])}")
            text += page_text
        doc.close()
    finally:
        Path(tmp_path).unlink(missing_ok=True)  # clean up the temporary file

    text = normalize_text(text)
    print("[DEBUG] Cleaned extracted text from PDF (first 500 chars):", repr(text[:500]))

    if not text:
        print("[⚠️] No text extracted from PDF. It may be blank or not readable.")
    return text


def reconstruct_words(tokens, labels):
    """
    Reconstruct words from BERT tokens and their corresponding labels.
    This function is used to map the BERT token predictions back to the original words.
    """
    words = []
    current_word = []
    current_label = None
    for token, label in zip(tokens, labels):
        if label == 'O':
            if current_word:
                words.append(''.join(current_word))
                current_word = []
            continue
        if label.startswith('B-'):
            if current_word:
                words.append(''.join(current_word))
                current_word = []
            current_label = label[2:]  # Get the idiom type
            current_word.append(token)
        elif label.startswith('I-') and current_label == label[2:]:
            current_word.append(token)
    if current_word:
        words.append(''.join(current_word))
    return words
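

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how these pieces are expected to fit together:
# load the LoRA checkpoint, detect the language, split into sentences, and run
# token-level idiom prediction. The checkpoint path below is a hypothetical
# placeholder, not a real project path.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_model("checkpoints/idiom-lora")  # hypothetical path
    model.to(device)

    sample = "He decided to bite the bullet and finish the report."
    lang = detect(sample)  # e.g. 'en'

    for sentence in split_text_by_language(sample, lang):
        found = predict_idiom(sentence, model, tokenizer, device)
        print(f"{sentence!r} -> {found}")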