|
|
""" |
|
|
Inference script for Code Comment Quality Classifier |
|
|
""" |
|
|
import argparse |
|
|
import torch |
|
|
import logging |
|
|
from typing import List, Dict, Union, Optional |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
|
|
|
class CommentQualityClassifier:
    """Wrapper class for easy inference with optimizations.

    Loads a tokenizer and a sequence-classification model once, then exposes
    single-comment (``predict``) and batched (``predict_batch``) inference.
    """

    # Token budget passed to the tokenizer; longer inputs are truncated.
    MAX_LENGTH = 512

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        use_fp16: bool = False
    ):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the trained model or Hugging Face model ID
            device: Device to run inference on ('cuda', 'cpu', or None for auto-detect)
            use_fp16: Whether to use half precision for faster inference (GPU only)
        """
        logging.info("Loading model from %s...", model_path)

        # Auto-detect: prefer CUDA when torch reports it as available.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)

        # Half precision is only applied on CUDA; the flag is ignored elsewhere.
        if use_fp16 and self.device.type == 'cuda':
            self.model.half()
            logging.info("Using FP16 precision for inference")

        self.model.eval()

        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id

        labels = list(self.id2label.values())
        logging.info("Model loaded successfully on %s", self.device)
        logging.info("Labels: %s", labels)
        print(f"✓ Model loaded successfully on {self.device}")
        print(f"✓ Labels: {labels}")

    def _class_probabilities(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Tokenize ``texts`` and return softmax probabilities, one row per text."""
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            truncation=True,
            max_length=self.MAX_LENGTH,
            padding=True
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits

        # Upcast before softmax so the normalisation is not done in half
        # precision when the model runs in FP16; a no-op for float32 logits.
        return torch.nn.functional.softmax(logits.float(), dim=-1)

    def _label_scores(self, scores: List[float]) -> Dict[str, float]:
        """Map one row of class probabilities to ``{label: probability}``."""
        return {self.id2label[idx]: score for idx, score in enumerate(scores)}

    def predict(
        self,
        comment: str,
        return_probabilities: bool = False,
        confidence_threshold: Optional[float] = None
    ) -> Optional[Union[str, Dict]]:
        """
        Predict the quality of a code comment.

        Args:
            comment: The code comment text
            return_probabilities: If True, return probabilities for all classes
            confidence_threshold: Optional minimum confidence threshold (returns None if below)

        Returns:
            If return_probabilities is False: predicted label (str), or None when
            the comment is empty/whitespace or the confidence is below the threshold.
            If return_probabilities is True: dict with 'label', 'confidence' and
            'probabilities'. For an empty comment the label is None, confidence is
            0.0 and probabilities is empty. Below the threshold the label is None
            and an extra 'below_threshold': True entry is included.
        """
        if not comment or not comment.strip():
            logging.warning("Empty comment provided")
            if return_probabilities:
                return {
                    'label': None,
                    'confidence': 0.0,
                    'probabilities': {}
                }
            return None

        row = self._class_probabilities(comment)[0]
        predicted_class = int(torch.argmax(row).item())
        scores = row.tolist()  # single device-to-host transfer

        confidence = scores[predicted_class]
        predicted_label = self.id2label[predicted_class]
        below_threshold = (
            confidence_threshold is not None
            and confidence < confidence_threshold
        )

        if not return_probabilities:
            return None if below_threshold else predicted_label

        result = {
            'label': None if below_threshold else predicted_label,
            'confidence': confidence,
            'probabilities': self._label_scores(scores)
        }
        if below_threshold:
            result['below_threshold'] = True
        return result

    def predict_batch(
        self,
        comments: List[str],
        batch_size: int = 32,
        return_probabilities: bool = False
    ) -> Union[List[str], List[Dict]]:
        """
        Predict quality for multiple comments with batching support.

        Note: unlike ``predict``, blank comments are not filtered out here;
        they are tokenized and classified like any other input.

        Args:
            comments: List of code comment texts
            batch_size: Batch size for processing (must be >= 1)
            return_probabilities: If True, return full probability dicts

        Returns:
            List of predicted labels or dicts with probabilities, in input order

        Raises:
            ValueError: If comments is non-empty and batch_size is less than 1
        """
        if not comments:
            return []

        # A negative step would make range() empty and silently drop every
        # comment, so reject it explicitly (zero already raised ValueError).
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        all_results = []

        for start in range(0, len(comments), batch_size):
            batch = comments[start:start + batch_size]

            probabilities = self._class_probabilities(batch)
            predictions = torch.argmax(probabilities, dim=-1).tolist()

            if return_probabilities:
                for pred, scores in zip(predictions, probabilities.tolist()):
                    all_results.append({
                        'label': self.id2label[pred],
                        'confidence': scores[pred],
                        'probabilities': self._label_scores(scores)
                    })
            else:
                all_results.extend(self.id2label[pred] for pred in predictions)

        return all_results
|
|
|
|
|
|
|
|
def main():
    """Main inference function with example usage.

    Parses the command line, loads the classifier, and prints the predicted
    quality, confidence and per-class probabilities for either the comment
    given via --comment or a built-in list of example comments.
    """
    parser = argparse.ArgumentParser(description="Classify code comment quality")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./results/final_model",
        help="Path to the trained model"
    )
    parser.add_argument(
        "--comment",
        type=str,
        help="Code comment to classify"
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=['cuda', 'cpu', 'auto'],
        default='auto',
        help="Device to run inference on"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Use FP16 precision for faster inference (GPU only)"
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=None,
        help="Minimum confidence threshold (0.0-1.0)"
    )
    args = parser.parse_args()

    # Fail fast, before the model is loaded, on a threshold outside the
    # documented range (the chained comparison also rejects NaN).
    threshold = args.confidence_threshold
    if threshold is not None and not 0.0 <= threshold <= 1.0:
        parser.error("--confidence-threshold must be between 0.0 and 1.0")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # 'auto' is expressed to the classifier as None (auto-detect).
    device = None if args.device == 'auto' else args.device
    classifier = CommentQualityClassifier(
        args.model_path,
        device=device,
        use_fp16=args.fp16
    )

    if args.comment:
        comments = [args.comment]
    else:
        print("\nNo comment provided. Using example comments...\n")
        comments = [
            "This function calculates the Fibonacci sequence using dynamic programming to avoid redundant calculations. Time complexity: O(n), Space complexity: O(n)",
            "does stuff with numbers",
            "TODO: fix this later",
            "Calculates sum of two numbers",
            "This function adds two integers and returns the result. Parameters: a (int), b (int). Returns: int sum",
            "loop through array",
            "DEPRECATED: Use calculate_new() instead. This method will be removed in v2.0",
        ]

    print("=" * 80)
    print("Code Comment Quality Classification")
    print("=" * 80)

    for i, comment in enumerate(comments, 1):
        print(f"\n[{i}] Comment: {comment}")
        print("-" * 80)

        result = classifier.predict(
            comment,
            return_probabilities=True,
            confidence_threshold=threshold
        )

        if result['label'] is not None:
            print(f"Predicted Quality: {result['label'].upper()}")
        elif result.get('below_threshold'):
            print("Predicted Quality: LOW CONFIDENCE (below threshold)")
        else:
            # A None label without the below_threshold flag is the
            # classifier's result for an empty/whitespace-only comment.
            print("Predicted Quality: N/A (empty comment)")
        print(f"Confidence: {result['confidence']:.2%}")

        print("\nAll Probabilities:")
        ranked = sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True)
        for label, prob in ranked:
            bar = "█" * int(prob * 50)  # 50 cells correspond to 100%
            print(f" {label:10s}: {bar:50s} {prob:.2%}")

    print("\n" + "=" * 80)
|
|
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|
|