# code-comment-classifier / inference.py
# Uploaded by Snaseem2026 with huggingface_hub (commit aa71938).
"""
Inference script for Code Comment Quality Classifier
"""
import argparse
import torch
import logging
from typing import List, Dict, Union, Optional
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class CommentQualityClassifier:
    """Wrapper around a Hugging Face sequence-classification model for rating
    code-comment quality, with optional FP16 and batched inference."""

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        use_fp16: bool = False,
        max_length: int = 512,
    ):
        """
        Load the tokenizer and model and put the model in eval mode.

        Args:
            model_path: Path to the trained model or Hugging Face model ID
            device: Device to run inference on ('cuda', 'cpu', or None for auto-detect)
            use_fp16: Whether to use half precision for faster inference (CUDA only;
                on other devices the request is ignored and a warning is logged)
            max_length: Maximum number of tokens per comment; longer inputs are truncated
        """
        logging.info(f"Loading model from {model_path}...")
        # Auto-detect device
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.max_length = max_length

        # Load tokenizer and model, then move the model to the target device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)

        if use_fp16:
            if self.device.type == 'cuda':
                self.model.half()
                logging.info("Using FP16 precision for inference")
            else:
                # FP16 is only enabled on CUDA; tell the caller their request was not applied.
                logging.warning("FP16 requested but device is %s; using FP32", self.device)
        self.model.eval()

        # Label mapping stored in the model config
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id
        logging.info(f"Model loaded successfully on {self.device}")
        logging.info(f"Labels: {list(self.id2label.values())}")
        print(f"✓ Model loaded successfully on {self.device}")
        print(f"✓ Labels: {list(self.id2label.values())}")

    def _classify(self, texts: List[str]) -> torch.Tensor:
        """Tokenize *texts*, run the model, and return a probability tensor
        with one row per input text and one column per label."""
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = self.model(**inputs).logits
        # Upcast before softmax so probabilities are float32 even when the model runs in FP16.
        return torch.nn.functional.softmax(logits.float(), dim=-1)

    @staticmethod
    def _empty_result(return_probabilities: bool) -> Optional[Dict]:
        """Result reported for an empty/whitespace-only comment (a fresh dict each call)."""
        if return_probabilities:
            return {
                'label': None,
                'confidence': 0.0,
                'probabilities': {},
            }
        return None

    def _format_result(
        self,
        probs: torch.Tensor,
        return_probabilities: bool,
        confidence_threshold: Optional[float],
    ) -> Optional[Union[str, Dict]]:
        """Convert one row of class probabilities into the public return value."""
        predicted_class = int(torch.argmax(probs).item())
        confidence = probs[predicted_class].item()
        below_threshold = (
            confidence_threshold is not None and confidence < confidence_threshold
        )
        label = None if below_threshold else self.id2label[predicted_class]
        if not return_probabilities:
            return label
        result = {
            'label': label,
            'confidence': confidence,
            'probabilities': {
                self.id2label[i]: prob.item() for i, prob in enumerate(probs)
            },
        }
        if below_threshold:
            result['below_threshold'] = True
        return result

    def predict(
        self,
        comment: str,
        return_probabilities: bool = False,
        confidence_threshold: Optional[float] = None
    ) -> Optional[Union[str, Dict]]:
        """
        Predict the quality of a code comment.

        Args:
            comment: The code comment text
            return_probabilities: If True, return probabilities for all classes
            confidence_threshold: Optional minimum confidence; predictions below it
                are reported with label None

        Returns:
            If return_probabilities is False: predicted label (str), or None if the
            comment is empty or the confidence is below the threshold.
            If return_probabilities is True: dict with 'label', 'confidence' and
            'probabilities' (plus 'below_threshold': True when under the threshold).
        """
        if not comment or not comment.strip():
            logging.warning("Empty comment provided")
            return self._empty_result(return_probabilities)
        probabilities = self._classify([comment])
        return self._format_result(probabilities[0], return_probabilities, confidence_threshold)

    def predict_batch(
        self,
        comments: List[str],
        batch_size: int = 32,
        return_probabilities: bool = False,
        confidence_threshold: Optional[float] = None,
    ) -> List[Optional[Union[str, Dict]]]:
        """
        Predict quality for multiple comments with batching support.

        Results match what predict() would return for each comment individually:
        empty/whitespace-only comments are not sent to the model and yield the
        empty result.

        Args:
            comments: List of code comment texts
            batch_size: Number of non-empty comments per forward pass (must be >= 1)
            return_probabilities: If True, return full probability dicts
            confidence_threshold: Optional minimum confidence, as in predict()

        Returns:
            List of results in the same order as *comments*.

        Raises:
            ValueError: If batch_size is less than 1 and comments is non-empty.
        """
        if not comments:
            return []
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Pre-fill every slot with the empty result; only non-blank comments hit the model.
        results = [self._empty_result(return_probabilities) for _ in comments]
        valid_indices = [i for i, c in enumerate(comments) if c and c.strip()]
        if len(valid_indices) < len(comments):
            logging.warning("%d empty comment(s) provided", len(comments) - len(valid_indices))

        for start in range(0, len(valid_indices), batch_size):
            chunk = valid_indices[start:start + batch_size]
            probabilities = self._classify([comments[i] for i in chunk])
            for i, probs in zip(chunk, probabilities):
                results[i] = self._format_result(probs, return_probabilities, confidence_threshold)
        return results
def main():
    """Command-line entry point: classify one comment (or built-in examples)
    and print the predicted label, confidence and per-class probabilities."""
    parser = argparse.ArgumentParser(description="Classify code comment quality")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./results/final_model",
        help="Path to the trained model"
    )
    parser.add_argument(
        "--comment",
        type=str,
        help="Code comment to classify"
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=['cuda', 'cpu', 'auto'],
        default='auto',
        help="Device to run inference on"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Use FP16 precision for faster inference (GPU only)"
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=None,
        help="Minimum confidence threshold (0.0-1.0)"
    )
    args = parser.parse_args()
    # Enforce the range promised in the --confidence-threshold help text.
    if args.confidence_threshold is not None and not 0.0 <= args.confidence_threshold <= 1.0:
        parser.error("--confidence-threshold must be between 0.0 and 1.0")

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Initialize classifier
    device = None if args.device == 'auto' else args.device
    classifier = CommentQualityClassifier(
        args.model_path,
        device=device,
        use_fp16=args.fp16
    )

    # Example comments if none provided
    if args.comment:
        comments = [args.comment]
    else:
        print("\nNo comment provided. Using example comments...\n")
        comments = [
            "This function calculates the Fibonacci sequence using dynamic programming to avoid redundant calculations. Time complexity: O(n), Space complexity: O(n)",
            "does stuff with numbers",
            "TODO: fix this later",
            "Calculates sum of two numbers",
            "This function adds two integers and returns the result. Parameters: a (int), b (int). Returns: int sum",
            "loop through array",
            "DEPRECATED: Use calculate_new() instead. This method will be removed in v2.0",
        ]

    print("=" * 80)
    print("Code Comment Quality Classification")
    print("=" * 80)
    for i, comment in enumerate(comments, 1):
        print(f"\n[{i}] Comment: {comment}")
        print("-" * 80)
        result = classifier.predict(
            comment,
            return_probabilities=True,
            confidence_threshold=args.confidence_threshold
        )
        if result['label'] is None:
            # label is None both for below-threshold and for empty input; only the
            # former carries the 'below_threshold' flag.
            if result.get('below_threshold'):
                print("Predicted Quality: LOW CONFIDENCE (below threshold)")
            else:
                print("Predicted Quality: NONE (empty comment)")
            print(f"Confidence: {result['confidence']:.2%}")
        else:
            print(f"Predicted Quality: {result['label'].upper()}")
            print(f"Confidence: {result['confidence']:.2%}")
        if result['probabilities']:
            print("\nAll Probabilities:")
            for label, prob in sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True):
                bar = "█" * int(prob * 50)
                print(f" {label:10s}: {bar:50s} {prob:.2%}")
    print("\n" + "=" * 80)
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    # main() returns None, so SystemExit(None) exits with status 0 as before.
    raise SystemExit(main())