causal-language-ss-abstracts / predict_ft_model_causal.py
calvini's picture
Upload 10 files
2f4c052 verified
def _resolve_threshold(model_dir, threshold=None, threshold_type=None):
    """Pick the decision threshold and report where it came from.

    Lookup order:
        1. the explicit ``threshold`` argument ("custom");
        2. ``all_thresholds.json[threshold_type]["threshold"]`` when
           ``threshold_type`` is given;
        3. ``threshold_config.json["threshold"]``;
        4. the hard default 0.5.

    Args:
        model_dir: Directory expected to contain the threshold JSON files.
        threshold: Caller-supplied threshold, or None to load one.
        threshold_type: Key to look up in all_thresholds.json, or None.

    Returns:
        Tuple ``(threshold, threshold_source)``.

    Raises:
        KeyError: If a threshold file exists but has no "threshold" entry
            where one is expected.
    """
    import json
    import os

    if threshold is not None:
        return threshold, "custom"

    if threshold_type is not None:
        all_thresholds_path = os.path.join(model_dir, "all_thresholds.json")
        try:
            with open(all_thresholds_path, "r", encoding="utf-8") as f:
                all_thresholds = json.load(f)
        except FileNotFoundError:
            # The caller asked for a specific threshold type, so say why it
            # is not being honoured instead of silently using another one.
            print(f"all_thresholds.json not found in '{model_dir}'. Falling back to threshold_config.json.")
        else:
            if threshold_type in all_thresholds:
                candidate = all_thresholds[threshold_type]["threshold"]
                # A null entry counts as "no threshold" and falls through.
                if candidate is not None:
                    return candidate, f"all_thresholds.json ({threshold_type})"
            else:
                print(f"Threshold type '{threshold_type}' not found. Available types: {list(all_thresholds.keys())}")

    try:
        with open(os.path.join(model_dir, "threshold_config.json"), "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        print("No threshold configuration found. Using default threshold of 0.5.")
        return 0.5, "default"
    return config["threshold"], "threshold_config.json"


def _class_probabilities(logits):
    """Convert the logits of ONE classified text into class probabilities.

    Args:
        logits: Tensor whose first dimension indexes the texts in the batch
            and whose second indexes the classes (assumed to be the
            ``.logits`` of a sequence-classification output - confirm).

    Returns:
        Tuple ``(descriptive_probability, causal_probability)`` of floats,
        taking class index 0 as "Descriptive" and index 1 as "Causal".

    Raises:
        ValueError: If ``logits`` does not hold exactly one row.
    """
    import torch

    if logits.shape[0] != 1:
        raise ValueError(
            "predict_with_threshold classifies one text at a time, "
            f"but the model returned logits for {logits.shape[0]} inputs."
        )

    row = logits[0]
    if row.numel() == 1:
        # Softmax over a single logit is always 1.0, which would label every
        # input "Causal". Read the lone logit as a binary score instead.
        # NOTE(review): assumes a one-logit head is a binary logit for the
        # "Causal" class - confirm against how the model was trained.
        causal = torch.sigmoid(row[0]).item()
        return 1.0 - causal, causal

    probs = torch.nn.functional.softmax(row, dim=-1).tolist()
    return probs[0], probs[1]


def predict_with_threshold(text, model_dir="causal-classifier", threshold=None, threshold_type=None):
    """Classify one text as "Causal" or "Descriptive" using a decision threshold.

    The tokenizer and model are loaded from ``model_dir`` on every call. The
    text is labelled "Causal" when the probability of class index 1 is
    strictly greater than the threshold, otherwise "Descriptive".

    Args:
        text: Input text to classify (a single text, not a batch).
        model_dir: Directory where the model and threshold files are saved.
        threshold: Custom threshold to use (if None, loads from saved config).
        threshold_type: Key to look up in all_thresholds.json ('f1' or
            'balanced') when no custom threshold is given.

    Returns:
        Dictionary with keys:
            "prediction": "Causal" or "Descriptive";
            "probabilities": {"Causal": float, "Descriptive": float};
            "threshold_used": the threshold that was applied;
            "threshold_source": "custom", "all_thresholds.json (<type>)",
                "threshold_config.json" or "default".

    Raises:
        ValueError: If the model returns logits for more than one input.
    """
    # Imported here so that importing this module does not require the
    # heavy ML dependencies until a prediction is actually requested.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    import torch

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)

    threshold, threshold_source = _resolve_threshold(model_dir, threshold, threshold_type)

    # Tokenize exactly as before: pad/truncate to 512 tokens, PyTorch tensors.
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    # Inference only: disable dropout and gradient tracking.
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    descriptive_prob, causal_prob = _class_probabilities(outputs.logits)
    prediction = 1 if causal_prob > threshold else 0

    # Map prediction back to original label
    label_names = {1: "Causal", 0: "Descriptive"}
    return {
        "prediction": label_names[prediction],
        "probabilities": {
            "Causal": causal_prob,
            "Descriptive": descriptive_prob,
        },
        "threshold_used": threshold,
        "threshold_source": threshold_source,
    }
# Example run: classify one hand-written sentence with every optional
# argument left at its default; the returned dictionary is kept in `result`.
sample_text = (
    'This is a causal study that aims to investigate the relationship between smoking and lung cancer.'
)
result = predict_with_threshold(text=sample_text)