def predict_with_threshold(text, model_dir="causal-classifier", threshold=None, threshold_type=None):
    """
    Make a prediction using the saved model and a custom threshold.

    Args:
        text: Input text to classify.
        model_dir: Directory where the model and threshold are saved.
        threshold: Custom threshold to use (if None, loads from the saved config).
        threshold_type: Type of threshold to use ('f1' or 'balanced') when loading
            from all_thresholds.json.

    Returns:
        Dictionary with prediction results.
    """
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    import torch
    import json

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)

    threshold_source = "custom"

    # Resolve the decision threshold: an explicit argument wins, then
    # all_thresholds.json (selected by threshold_type), then threshold_config.json,
    # and finally a default of 0.5.
    if threshold is None:
        if threshold_type is not None:
            try:
                with open(f"{model_dir}/all_thresholds.json", "r") as f:
                    all_thresholds = json.load(f)
                if threshold_type in all_thresholds:
                    threshold = all_thresholds[threshold_type]["threshold"]
                    threshold_source = f"all_thresholds.json ({threshold_type})"
                else:
                    print(f"Threshold type '{threshold_type}' not found. Available types: {list(all_thresholds.keys())}")
            except FileNotFoundError:
                pass

        if threshold is None:
            try:
                with open(f"{model_dir}/threshold_config.json", "r") as f:
                    config = json.load(f)
                threshold = config["threshold"]
                threshold_source = "threshold_config.json"
            except FileNotFoundError:
                threshold = 0.5
                threshold_source = "default"
                print("No threshold configuration found. Using default threshold of 0.5.")
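
    # Note: the JSON layouts are assumptions inferred from the lookups above, e.g.
    #   threshold_config.json -> {"threshold": 0.55}
    #   all_thresholds.json   -> {"f1": {"threshold": 0.60}, "balanced": {"threshold": 0.45}}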

    # Tokenize the input and run a single forward pass with gradients disabled
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)

    # Convert logits to class probabilities
    probs = torch.nn.functional.softmax(outputs.logits, dim=1).squeeze().tolist()

    # Apply the threshold to the probability of the positive ("Causal") class
    if isinstance(probs, list):
        prediction = 1 if probs[1] > threshold else 0
        class_probs = {
            "Causal": probs[1],
            "Descriptive": probs[0],
        }
    else:
        # Fallback in case squeeze() collapsed the output to a single scalar
        prediction = 1 if probs > threshold else 0
        class_probs = {
            "Causal": probs,
            "Descriptive": 1 - probs,
        }

    label_names = {1: "Causal", 0: "Descriptive"}

    return {
        "prediction": label_names[prediction],
        "probabilities": class_probs,
        "threshold_used": threshold,
        "threshold_source": threshold_source,
    }


sample_text = 'This is a causal study that aims to investigate the relationship between smoking and lung cancer.'
result = predict_with_threshold(sample_text)
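
# A minimal usage sketch (assumes the saved artifacts exist under "causal-classifier"
# and that all_thresholds.json defines a 'balanced' entry; adjust to your setup):
print(result["prediction"], result["probabilities"], result["threshold_used"])

balanced_result = predict_with_threshold(sample_text, threshold_type="balanced")
print(balanced_result["prediction"], balanced_result["threshold_source"])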