def _load_threshold(model_dir, threshold_type):
    """Resolve the decision threshold from the JSON files saved in ``model_dir``.

    Lookup order:
      1. ``all_thresholds.json[threshold_type]["threshold"]``, only when
         ``threshold_type`` is not None.
      2. ``threshold_config.json["threshold"]``.
      3. A default of 0.5 when ``threshold_config.json`` does not exist.

    Args:
        model_dir: Directory holding the threshold JSON files.
        threshold_type: Key to look up in ``all_thresholds.json`` (e.g. 'f1'
            or 'balanced'), or None to skip that file.

    Returns:
        Tuple ``(threshold, threshold_source)`` where ``threshold_source`` is a
        human-readable description of where the value came from.

    Raises:
        KeyError: if ``threshold_config.json`` exists but has no "threshold"
            key, or the selected ``all_thresholds.json`` entry has none.
    """
    import json

    if threshold_type is not None:
        try:
            with open(f"{model_dir}/all_thresholds.json", "r", encoding="utf-8") as f:
                all_thresholds = json.load(f)
        except FileNotFoundError:
            # Not fatal: report it instead of silently ignoring the requested
            # type, then fall through to threshold_config.json.
            print(
                f"all_thresholds.json not found in '{model_dir}'; "
                f"ignoring threshold_type '{threshold_type}'."
            )
        else:
            if threshold_type in all_thresholds:
                return (
                    all_thresholds[threshold_type]["threshold"],
                    f"all_thresholds.json ({threshold_type})",
                )
            print(f"Threshold type '{threshold_type}' not found. Available types: {list(all_thresholds.keys())}")

    try:
        with open(f"{model_dir}/threshold_config.json", "r", encoding="utf-8") as f:
            config = json.load(f)
    except FileNotFoundError:
        # Default to 0.5 if no threshold config is found.
        print("No threshold configuration found. Using default threshold of 0.5.")
        return 0.5, "default"
    return config["threshold"], "threshold_config.json"


def _class_probabilities(logits):
    """Return ``(p_causal, p_descriptive)`` for the first row of ``logits``.

    ``logits`` is indexed as ``(batch, num_labels)``; only row 0 is used
    because ``predict_with_threshold`` classifies a single text. Index 1 is
    treated as "Causal" and index 0 as "Descriptive".
    """
    import torch

    row = logits[0]
    if row.shape[-1] == 1:
        # Softmax over a single logit is always 1.0, which would make every
        # input "Causal". Treat the lone logit as binary log-odds instead.
        # NOTE(review): assumes a single-logit head was trained with a
        # sigmoid/BCE objective — confirm against the training setup.
        p_causal = torch.sigmoid(row[0]).item()
        return p_causal, 1.0 - p_causal
    probs = torch.nn.functional.softmax(row, dim=-1).tolist()
    return probs[1], probs[0]


def predict_with_threshold(text, model_dir="causal-classifier", threshold=None, threshold_type=None):
    """
    Make a prediction using the saved model and custom threshold.

    Args:
        text: Input text to classify (a single string).
        model_dir: Directory where the model, tokenizer and threshold files
            are saved.
        threshold: Custom threshold to use. If None, it is loaded from
            all_thresholds.json (when ``threshold_type`` is given), then
            threshold_config.json, falling back to 0.5.
        threshold_type: Type of threshold to use ('f1' or 'balanced') when
            loading from all_thresholds.json.

    Returns:
        Dictionary with keys "prediction" ("Causal" or "Descriptive"),
        "probabilities" ({"Causal": float, "Descriptive": float}),
        "threshold_used" and "threshold_source".
    """
    # Load the model and tokenizer
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    import torch

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)

    if threshold is None:
        threshold, threshold_source = _load_threshold(model_dir, threshold_type)
    else:
        threshold_source = "custom"

    # Tokenize the input text (padded/truncated to 512 tokens).
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    # Get model prediction without tracking gradients.
    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    p_causal, p_descriptive = _class_probabilities(logits)

    # Apply the threshold to the "Causal" probability (strictly greater).
    prediction = "Causal" if p_causal > threshold else "Descriptive"

    return {
        "prediction": prediction,
        "probabilities": {
            "Causal": p_causal,
            "Descriptive": p_descriptive,
        },
        "threshold_used": threshold,
        "threshold_source": threshold_source,
    }


sample_text = 'This is a causal study that aims to investigate the relationship between smoking and lung cancer.'
result = predict_with_threshold(sample_text)