argument-backend/relations/predict_bert.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def load_bert_model(model_path="../models/bert-argument", device=None):
    """Loads the fine-tuned BERT model and tokenizer and moves the model to the target device."""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return model, tokenizer, device


def predict_relation(parent_text, child_text, model, tokenizer, device, max_length=256):
    """
    Predicts whether the relation between parent and child is Support or Attack.
    """
    model.eval()
    # Tokenize the (parent, child) sentence pair; if the pair exceeds max_length,
    # only the child text is truncated
    encoding = tokenizer(
        parent_text,
        child_text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation='only_second',
        return_attention_mask=True,
        return_tensors='pt'
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Forward pass without gradient tracking
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=1)
        pred = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred].item()

    # Label 1 = Support, label 0 = Attack
    relation = "Support" if pred == 1 else "Attack"
    return {
        "predicted_label": relation,
        "probability": confidence,
        "confidence": confidence
    }
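

# Illustrative usage sketch (not part of the original module): the example parent/child
# sentences below are assumptions, chosen only to show how load_bert_model and
# predict_relation are meant to be called together with the default model path.
if __name__ == "__main__":
    model, tokenizer, device = load_bert_model()
    result = predict_relation(
        parent_text="We should ban single-use plastics.",
        child_text="Plastic waste is a major cause of ocean pollution.",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    # Prints a dict like {"predicted_label": ..., "probability": ..., "confidence": ...}
    print(result)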