# filepath: src/model/inference.py
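"""Inference utilities for idiom detection.

This module loads a PEFT (LoRA) token-classification checkpoint on top of
multilingual BERT, extracts and normalizes text from PDFs with PyMuPDF, splits
it into sentences per language with NLTK, and tags idiom spans using the
O / B-IDIOM / I-IDIOM scheme. A spaCy-based ``IdiomMatcher`` provides
lemmatized lexicon matching alongside the model.
"""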
from transformers import AutoTokenizer, AutoModelForTokenClassification
from peft import PeftModel, PeftConfig
import torch
import torch.nn.functional as F
from Idiom_lexicon import KNOWN_IDIOMS
import fitz # PyMuPDF
import tempfile
from PIL import Image
import pytesseract
import nltk
import spacy
import json
import re
from pathlib import Path
from fastapi import HTTPException
from langdetect import detect
from nltk.tokenize import sent_tokenize

# Fetch the Punkt sentence-tokenizer data once at import time (no-op if already cached).
nltk.download('punkt_tab', quiet=True)
LANG_MAP = {
'en': 'english',
'es': 'spanish',
# add more if needed
}
def split_text_by_language(text, language: str):
# Map input language (e.g., 'en', 'es') to NLTK language codes
nltk_lang = LANG_MAP.get(language.lower(), 'english')
sentences = sent_tokenize(text, language=nltk_lang)
return sentences
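# Minimal usage sketch (relies on the Punkt data downloaded above):
#   split_text_by_language("Hello there. How are you?", "en")
#   -> ['Hello there.', 'How are you?']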
def load_model(checkpoint_path):
    config = PeftConfig.from_pretrained(checkpoint_path)
    base_model = AutoModelForTokenClassification.from_pretrained(
        config.base_model_name_or_path,
        num_labels=3  # O, B-IDIOM, I-IDIOM
    )
    model = PeftModel.from_pretrained(base_model, checkpoint_path)
    model.eval()  # inference only: disable dropout
    # The tokenizer is pinned to the multilingual BERT base used for training.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
    return model, tokenizer
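# Usage sketch ("checkpoints/idiom-lora" is a placeholder path, not part of this repo):
#   model, tokenizer = load_model("checkpoints/idiom-lora")
#   model.to("cuda" if torch.cuda.is_available() else "cpu")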
def normalize_text(text):
# Join hyphenated words split across lines
text = re.sub(r'-\s*\n\s*', '', text)
# Replace newlines with spaces
text = re.sub(r'\n+', ' ', text)
# Collapse multiple spaces into one
text = re.sub(r'\s+', ' ', text)
return text.strip()
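# Example of what normalize_text repairs (line-broken hyphenation and stray whitespace):
#   normalize_text("trans-\nformer models\nwork  well") -> "transformer models work well"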
def filter_idioms(candidate_idioms, known_idioms, min_len=2):
filtered = []
for idiom in candidate_idioms:
norm = idiom.lower().strip()
if norm in known_idioms or len(norm.split()) >= min_len:
filtered.append(idiom)
return filtered
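# Example: single-word candidates are dropped unless they appear in the lexicon,
# while multi-word candidates (>= min_len words) always pass:
#   filter_idioms(["kick the bucket", "cat"], known_idioms=set()) -> ["kick the bucket"]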
class IdiomMatcher:
def __init__(self, idiom_files: dict[str, str]):
self.models = {
"en": spacy.load("en_core_web_sm"),
"es": spacy.load("es_core_news_sm"),
}
self.idioms_by_lang = {lang: [] for lang in idiom_files}
self._load_idioms(idiom_files)
def _lemmatize(self, text: str, lang: str) -> str:
doc = self.models[lang](text)
return " ".join(token.lemma_ for token in doc)
def _load_idioms(self, idiom_files: dict[str, str]):
for lang, file_path in idiom_files.items():
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"Idiom file not found for {lang}: {file_path}")
with open(path, "r", encoding="utf-8") as f:
for line in f:
entry = json.loads(line)
idiom_text = entry.get("idiom", "").strip()
if not idiom_text:
continue
entry["lemmatized"] = self._lemmatize(idiom_text, lang)
self.idioms_by_lang[lang].append(entry)
def match(self, sentence: str, lang: str):
if lang not in self.models:
raise ValueError(f"Unsupported language: {lang}")
sent_lemma = self._lemmatize(sentence, lang)
return [
idiom for idiom in self.idioms_by_lang[lang]
if idiom["lemmatized"] in sent_lemma
]
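# Usage sketch (the JSONL paths are placeholders; each line is expected to carry
# an "idiom" field, as _load_idioms assumes). Note that matching is a substring
# test on lemmatized text, so very short idioms can over-match.
#   matcher = IdiomMatcher({"en": "data/idioms_en.jsonl", "es": "data/idioms_es.jsonl"})
#   hits = matcher.match("He kicked the bucket yesterday.", "en")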
def predict_idiom(text, model, tokenizer, device, conf_threshold=0.9):
words = text.split()
if not words:
print("[⚠️] Empty input text")
return []
inputs = tokenizer(
words,
is_split_into_words=True,
truncation=True,
padding=True,
max_length=128,
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
probs = F.softmax(logits, dim=-1)
max_probs, predictions = torch.max(probs, dim=-1)
max_probs = max_probs.cpu().numpy()[0]
predictions = predictions.cpu().numpy()[0]
word_ids = inputs.word_ids(batch_index=0)
idioms = []
current_idiom_start = -1
current_idiom_end = -1
    for pred_label, conf, word_idx in zip(predictions, max_probs, word_ids):
        if word_idx is None:
            if current_idiom_start != -1:
                idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
                current_idiom_start = -1
                current_idiom_end = -1
            continue
        if conf < conf_threshold:
            pred_label = 0
        if pred_label == 1:  # B-IDIOM
            if current_idiom_start != -1:
                if word_idx == current_idiom_end:
                    # Later sub-token of the word that opened this span: keep it open.
                    continue
                idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
            current_idiom_start = word_idx
            current_idiom_end = word_idx
elif pred_label == 2: # I-IDIOM
if current_idiom_start != -1 and (word_idx == current_idiom_end or word_idx == current_idiom_end + 1):
current_idiom_end = word_idx
else:
if current_idiom_start != -1:
idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
current_idiom_start = -1
current_idiom_end = -1
else: # O
if current_idiom_start != -1:
idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
current_idiom_start = -1
current_idiom_end = -1
if current_idiom_start != -1:
idioms.append(' '.join(words[current_idiom_start:current_idiom_end + 1]))
idioms = filter_idioms(idioms, known_idioms=KNOWN_IDIOMS)
return idioms
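# Usage sketch (device handling is an assumption; the model must already sit on `device`):
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model, tokenizer = load_model("checkpoints/idiom-lora")  # placeholder path
#   model.to(device)
#   predict_idiom("It was raining cats and dogs.", model, tokenizer, device)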
def extract_text_from_pdf(pdf_bytes: bytes) -> str:
    # Write the upload to a temporary file so PyMuPDF can open it, and make sure
    # the file is removed afterwards instead of leaking into the temp directory.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(pdf_bytes)
        tmp_path = tmp.name
    try:
        doc = fitz.open(tmp_path)
        text = ""
        for i, page in enumerate(doc):
            page_text = page.get_text()
            print(f"[DEBUG] Page {i+1} extracted text (first 100 chars): {repr(page_text[:100])}")
            text += page_text
        doc.close()
    finally:
        Path(tmp_path).unlink(missing_ok=True)
    text = normalize_text(text)
    print("[DEBUG] Cleaned extracted text from PDF (first 500 chars):", repr(text[:500]))
    if not text:
        print("[⚠️] No text extracted from PDF. It may be blank or not readable.")
    return text
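# Usage sketch ("document.pdf" is a placeholder; `detect` comes from langdetect above):
#   with open("document.pdf", "rb") as f:
#       raw_text = extract_text_from_pdf(f.read())
#   sentences = split_text_by_language(raw_text, detect(raw_text))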
def reconstruct_words(tokens, labels):
    """
    Reconstruct words from BERT tokens and their corresponding labels.
    This function is used to map the BERT token predictions back to the original words.
    WordPiece continuation markers ('##') are stripped before the pieces are joined.
    """
    words = []
    current_word = []
    current_label = None
    for token, label in zip(tokens, labels):
        # Remove the WordPiece '##' prefix so joined pieces read as normal text.
        piece = token[2:] if token.startswith('##') else token
        if label == 'O':
            if current_word:
                words.append(''.join(current_word))
                current_word = []
            continue
        if label.startswith('B-'):
            if current_word:
                words.append(''.join(current_word))
                current_word = []
            current_label = label[2:]  # Get the idiom type
            current_word.append(piece)
        elif label.startswith('I-') and current_label == label[2:]:
            current_word.append(piece)
    if current_word:
        words.append(''.join(current_word))
    return words
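# Minimal end-to-end sketch, guarded by __main__ so it only runs when this file is
# executed directly. The checkpoint path is a placeholder, not a file in this repo.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_model("checkpoints/idiom-lora")  # placeholder path
    model.to(device)
    sample = "It was raining cats and dogs, so we decided to call it a day."
    for sentence in split_text_by_language(sample, "en"):
        print(sentence, "->", predict_idiom(sentence, model, tokenizer, device))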