"""
Inference script for Code Comment Quality Classifier
"""

import argparse
import torch
import logging
from typing import List, Dict, Union, Optional
from transformers import AutoTokenizer, AutoModelForSequenceClassification

logger = logging.getLogger(__name__)


class CommentQualityClassifier:
    """Wrapper class for easy inference with optimizations."""

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        use_fp16: bool = False,
        max_length: int = 512,
    ):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the trained model or Hugging Face model ID
            device: Device to run inference on ('cuda', 'cpu', or None for auto-detect)
            use_fp16: Whether to use half precision for faster inference. Only
                applied on CUDA devices; on other devices a warning is logged
                and FP32 is used.
            max_length: Maximum number of tokens per comment; longer comments
                are truncated by the tokenizer.
        """
        logger.info("Loading model from %s...", model_path)

        # Auto-detect device
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.max_length = max_length

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)

        # Move model to device
        self.model.to(self.device)

        # Half precision is only enabled on GPU; make the CPU fallback visible
        # instead of silently ignoring the request.
        if use_fp16:
            if self.device.type == 'cuda':
                self.model.half()
                logger.info("Using FP16 precision for inference")
            else:
                logger.warning(
                    "FP16 requested but device is %s; falling back to FP32", self.device
                )

        self.model.eval()

        # Label mapping from the model config (class index <-> label name)
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id

        labels = list(self.id2label.values())
        logger.info("Model loaded successfully on %s", self.device)
        logger.info("Labels: %s", labels)
        print(f"✓ Model loaded successfully on {self.device}")
        print(f"✓ Labels: {labels}")

    def _predict_probabilities(self, texts: List[str]) -> List[List[float]]:
        """Run one forward pass over `texts` and return per-text class probabilities.

        Returns a list (one entry per input text) of Python floats, indexed by
        class id.
        """
        inputs = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding=True,
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits
            # Softmax in FP32: with use_fp16 the logits are half precision and
            # computing the distribution there loses precision.
            probabilities = torch.nn.functional.softmax(logits.float(), dim=-1)

        # A single device->host transfer instead of one sync per `.item()` call.
        return probabilities.cpu().tolist()

    @staticmethod
    def _argmax(values: List[float]) -> int:
        """Return the index of the largest value (first one on ties, like torch.argmax)."""
        return max(range(len(values)), key=values.__getitem__)

    def _probabilities_to_dict(self, probs: List[float]) -> Dict[str, float]:
        """Map a row of class probabilities to {label_name: probability}."""
        return {self.id2label[i]: prob for i, prob in enumerate(probs)}

    def predict(
        self,
        comment: str,
        return_probabilities: bool = False,
        confidence_threshold: Optional[float] = None
    ) -> Optional[Union[str, Dict]]:
        """
        Predict the quality of a code comment.

        Args:
            comment: The code comment text
            return_probabilities: If True, return probabilities for all classes
            confidence_threshold: Optional minimum confidence threshold (returns None if below)

        Returns:
            If return_probabilities is False: predicted label (str), or None if the
                comment is empty/whitespace or the confidence is below the threshold.
            If return_probabilities is True: dict with 'label', 'confidence' and
                'probabilities'. 'label' is None for an empty comment (with empty
                'probabilities') or when below the threshold, in which case the dict
                also contains 'below_threshold': True.
        """
        if not comment or not comment.strip():
            logger.warning("Empty comment provided")
            if return_probabilities:
                return {
                    'label': None,
                    'confidence': 0.0,
                    'probabilities': {}
                }
            return None

        probs = self._predict_probabilities([comment])[0]
        predicted_class = self._argmax(probs)
        confidence = probs[predicted_class]

        # Check confidence threshold (`is not None` so any explicit value counts)
        if confidence_threshold is not None and confidence < confidence_threshold:
            if return_probabilities:
                return {
                    'label': None,
                    'confidence': confidence,
                    'probabilities': self._probabilities_to_dict(probs),
                    'below_threshold': True
                }
            return None

        predicted_label = self.id2label[predicted_class]
        if return_probabilities:
            return {
                'label': predicted_label,
                'confidence': confidence,
                'probabilities': self._probabilities_to_dict(probs)
            }
        return predicted_label

    def predict_batch(
        self,
        comments: List[str],
        batch_size: int = 32,
        return_probabilities: bool = False
    ) -> Union[List[str], List[Dict]]:
        """
        Predict quality for multiple comments with batching support.

        NOTE: unlike `predict`, empty/whitespace comments are not filtered out
        here; they are classified like any other input.

        Args:
            comments: List of code comment texts
            batch_size: Batch size for processing (must be positive)
            return_probabilities: If True, return full probability dicts

        Returns:
            List of predicted labels or dicts with probabilities, in input order

        Raises:
            ValueError: If batch_size is not positive (and comments is non-empty).
        """
        if not comments:
            return []
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        all_results = []

        # Process in batches
        for start in range(0, len(comments), batch_size):
            batch = comments[start:start + batch_size]
            for probs in self._predict_probabilities(batch):
                predicted_class = self._argmax(probs)
                label = self.id2label[predicted_class]
                if return_probabilities:
                    all_results.append({
                        'label': label,
                        'confidence': probs[predicted_class],
                        'probabilities': self._probabilities_to_dict(probs)
                    })
                else:
                    all_results.append(label)

        return all_results


def main():
    """Main inference function with example usage."""
    parser = argparse.ArgumentParser(description="Classify code comment quality")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./results/final_model",
        help="Path to the trained model"
    )
    parser.add_argument(
        "--comment",
        type=str,
        help="Code comment to classify"
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=['cuda', 'cpu', 'auto'],
        default='auto',
        help="Device to run inference on"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Use FP16 precision for faster inference (GPU only)"
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=None,
        help="Minimum confidence threshold (0.0-1.0)"
    )

    args = parser.parse_args()

    # Reject thresholds outside the documented probability range up front.
    if args.confidence_threshold is not None and not 0.0 <= args.confidence_threshold <= 1.0:
        parser.error("--confidence-threshold must be between 0.0 and 1.0")

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Initialize classifier
    device = None if args.device == 'auto' else args.device
    classifier = CommentQualityClassifier(
        args.model_path,
        device=device,
        use_fp16=args.fp16
    )

    # Example comments if none provided (an explicit empty --comment is still classified)
    if args.comment is not None:
        comments = [args.comment]
    else:
        print("\nNo comment provided. Using example comments...\n")
        comments = [
            "This function calculates the Fibonacci sequence using dynamic programming to avoid redundant calculations. Time complexity: O(n), Space complexity: O(n)",
            "does stuff with numbers",
            "TODO: fix this later",
            "Calculates sum of two numbers",
            "This function adds two integers and returns the result. Parameters: a (int), b (int). Returns: int sum",
            "loop through array",
            "DEPRECATED: Use calculate_new() instead. This method will be removed in v2.0",
        ]

    print("=" * 80)
    print("Code Comment Quality Classification")
    print("=" * 80)

    for i, comment in enumerate(comments, 1):
        print(f"\n[{i}] Comment: {comment}")
        print("-" * 80)

        result = classifier.predict(
            comment,
            return_probabilities=True,
            confidence_threshold=args.confidence_threshold
        )

        if result['label'] is None:
            # Distinguish a genuine low-confidence prediction from an empty input.
            if result.get('below_threshold'):
                print("Predicted Quality: LOW CONFIDENCE (below threshold)")
            else:
                print("Predicted Quality: N/A (empty comment)")
            print(f"Confidence: {result['confidence']:.2%}")
        else:
            print(f"Predicted Quality: {result['label'].upper()}")
            print(f"Confidence: {result['confidence']:.2%}")

        print("\nAll Probabilities:")
        for label, prob in sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True):
            bar = "█" * int(prob * 50)
            print(f"  {label:10s}: {bar:50s} {prob:.2%}")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    main()