"""
Inference script for Code Comment Quality Classifier
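
Example invocations (illustrative; assumes this file is saved as inference.py
and that a trained model exists at ./results/final_model):

    python inference.py --comment "Returns the index of the first match, or -1."
    python inference.py --fp16 --device cuda --confidence-threshold 0.7
    python inference.py   # no --comment: classifies a few built-in examples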
"""
import argparse
import logging
from typing import Dict, List, Optional, Union

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class CommentQualityClassifier:
    """Wrapper class for easy inference with optimizations."""
    
    def __init__(
        self, 
        model_path: str,
        device: Optional[str] = None,
        use_fp16: bool = False
    ):
        """
        Initialize the classifier.
        
        Args:
            model_path: Path to the trained model or Hugging Face model ID
            device: Device to run inference on ('cuda', 'cpu', or None for auto-detect)
            use_fp16: Whether to use half precision for faster inference (GPU only)
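
        Example (illustrative; the path is a placeholder for a model
        fine-tuned on this task):
            classifier = CommentQualityClassifier(
                "./results/final_model", device="cuda", use_fp16=True
            )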
        """
        logging.info(f"Loading model from {model_path}...")
        
        # Auto-detect device
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        
        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        
        # Move model to device
        self.model.to(self.device)
        
        # Enable half precision if requested and on GPU
        if use_fp16 and self.device.type == 'cuda':
            self.model.half()
            logging.info("Using FP16 precision for inference")
        
        self.model.eval()
        
        # Get label mapping
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id
        
        logging.info(f"Model loaded successfully on {self.device}")
        logging.info(f"Labels: {list(self.id2label.values())}")
        print(f"✓ Model loaded successfully on {self.device}")
        print(f"✓ Labels: {list(self.id2label.values())}")
    
    def predict(
        self, 
        comment: str, 
        return_probabilities: bool = False,
        confidence_threshold: Optional[float] = None
    ) -> Union[str, Dict]:
        """
        Predict the quality of a code comment.
        
        Args:
            comment: The code comment text
            return_probabilities: If True, return probabilities for all classes
            confidence_threshold: Optional minimum confidence threshold (returns None if below)
            
        Returns:
            If return_probabilities is False: predicted label (str) or None if below threshold
            If return_probabilities is True: dict with label, confidence, and probabilities
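
        Example (illustrative; assumes classifier is an initialized instance and
        label names come from the model's id2label mapping):
            result = classifier.predict(
                "Adds two integers and returns the sum.",
                return_probabilities=True,
            )
            # result -> {'label': <str>, 'confidence': <float in [0, 1]>,
            #            'probabilities': {<label>: <float>, ...}}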
        """
        if not comment or not comment.strip():
            logging.warning("Empty comment provided")
            if return_probabilities:
                return {
                    'label': None,
                    'confidence': 0.0,
                    'probabilities': {}
                }
            return None
        
        # Tokenize input
        inputs = self.tokenizer(
            comment,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )
        
        # Move inputs to device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        
        # Get predictions
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()
        
        confidence = probabilities[0][predicted_class].item()
        predicted_label = self.id2label[predicted_class]
        
        # Check confidence threshold
        if confidence_threshold is not None and confidence < confidence_threshold:
            if return_probabilities:
                return {
                    'label': None,
                    'confidence': confidence,
                    'probabilities': {
                        self.id2label[i]: prob.item()
                        for i, prob in enumerate(probabilities[0])
                    },
                    'below_threshold': True
                }
            return None
        
        if return_probabilities:
            prob_dict = {
                self.id2label[i]: prob.item()
                for i, prob in enumerate(probabilities[0])
            }
            return {
                'label': predicted_label,
                'confidence': confidence,
                'probabilities': prob_dict
            }
        
        return predicted_label
    
    def predict_batch(
        self, 
        comments: List[str],
        batch_size: int = 32,
        return_probabilities: bool = False
    ) -> Union[List[str], List[Dict]]:
        """
        Predict quality for multiple comments with batching support.
        
        Args:
            comments: List of code comment texts
            batch_size: Batch size for processing
            return_probabilities: If True, return full probability dicts
            
        Returns:
            List of predicted labels or dicts with probabilities
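
        Example (illustrative; assumes classifier is an initialized instance):
            labels = classifier.predict_batch(
                ["does stuff", "Computes the SHA-256 digest of the payload."],
                batch_size=16,
            )
            # labels -> one predicted label string per input comment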
        """
        if not comments:
            return []
        
        all_results = []
        
        # Process in batches
        for i in range(0, len(comments), batch_size):
            batch = comments[i:i + batch_size]
            
            # Tokenize batch
            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            )
            
            # Move inputs to device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = self.model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predictions = torch.argmax(probabilities, dim=-1)
            
            if return_probabilities:
                for pred, probs in zip(predictions, probabilities):
                    prob_dict = {
                        self.id2label[k]: prob.item()
                        for k, prob in enumerate(probs)
                    }
                    all_results.append({
                        'label': self.id2label[pred.item()],
                        'confidence': probs[pred.item()].item(),
                        'probabilities': prob_dict
                    })
            else:
                all_results.extend([self.id2label[pred.item()] for pred in predictions])
        
        return all_results


def main():
    """Main inference function with example usage."""
    parser = argparse.ArgumentParser(description="Classify code comment quality")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./results/final_model",
        help="Path to the trained model"
    )
    parser.add_argument(
        "--comment",
        type=str,
        help="Code comment to classify"
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=['cuda', 'cpu', 'auto'],
        default='auto',
        help="Device to run inference on"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Use FP16 precision for faster inference (GPU only)"
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=None,
        help="Minimum confidence threshold (0.0-1.0)"
    )
    args = parser.parse_args()
    
    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    
    # Initialize classifier
    device = None if args.device == 'auto' else args.device
    classifier = CommentQualityClassifier(
        args.model_path,
        device=device,
        use_fp16=args.fp16
    )
    
    # Example comments if none provided
    if args.comment:
        comments = [args.comment]
    else:
        print("\nNo comment provided. Using example comments...\n")
        comments = [
            "This function calculates the Fibonacci sequence using dynamic programming to avoid redundant calculations. Time complexity: O(n), Space complexity: O(n)",
            "does stuff with numbers",
            "TODO: fix this later",
            "Calculates sum of two numbers",
            "This function adds two integers and returns the result. Parameters: a (int), b (int). Returns: int sum",
            "loop through array",
            "DEPRECATED: Use calculate_new() instead. This method will be removed in v2.0",
        ]
    
    print("=" * 80)
    print("Code Comment Quality Classification")
    print("=" * 80)
    
    for i, comment in enumerate(comments, 1):
        print(f"\n[{i}] Comment: {comment}")
        print("-" * 80)
        
        result = classifier.predict(
            comment, 
            return_probabilities=True,
            confidence_threshold=args.confidence_threshold
        )
        
        if result['label'] is None:
            print(f"Predicted Quality: LOW CONFIDENCE (below threshold)")
            print(f"Confidence: {result['confidence']:.2%}")
        else:
            print(f"Predicted Quality: {result['label'].upper()}")
            print(f"Confidence: {result['confidence']:.2%}")
        
        print("\nAll Probabilities:")
        for label, prob in sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True):
            bar = "█" * int(prob * 50)
            print(f"  {label:10s}: {bar:50s} {prob:.2%}")
    
    print("\n" + "=" * 80)


if __name__ == "__main__":
    main()