Snaseem2026 committed
Commit aa71938 · verified · 1 Parent(s): e4495a9

Upload inference.py with huggingface_hub
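The commit message points at huggingface_hub as the upload mechanism. A minimal sketch of such an upload (the target repo ID is a placeholder, not taken from this commit):

from huggingface_hub import HfApi

api = HfApi()  # uses the token stored by `huggingface-cli login` by default
api.upload_file(
    path_or_fileobj="inference.py",
    path_in_repo="inference.py",
    repo_id="Snaseem2026/<repo-name>",  # placeholder: the destination repo is not shown on this page
    commit_message="Upload inference.py with huggingface_hub",
)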

Files changed (1)
  1. inference.py +287 -0
inference.py ADDED
@@ -0,0 +1,287 @@
"""
Inference script for the Code Comment Quality Classifier.
"""
import argparse
import logging
from typing import List, Dict, Union, Optional

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class CommentQualityClassifier:
    """Wrapper class for easy inference with optimizations."""

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        use_fp16: bool = False
    ):
        """
        Initialize the classifier.

        Args:
            model_path: Path to the trained model or a Hugging Face model ID
            device: Device to run inference on ('cuda', 'cpu', or None to auto-detect)
            use_fp16: Whether to use half precision for faster inference (GPU only)
        """
        logging.info(f"Loading model from {model_path}...")

        # Auto-detect the device if none was requested
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # Load tokenizer and model, then move the model to the target device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.to(self.device)

        # Enable half precision only when requested and running on a GPU
        if use_fp16 and self.device.type == 'cuda':
            self.model.half()
            logging.info("Using FP16 precision for inference")

        self.model.eval()

        # Label mappings come from the model config
        self.id2label = self.model.config.id2label
        self.label2id = self.model.config.label2id

        logging.info(f"Model loaded successfully on {self.device}")
        logging.info(f"Labels: {list(self.id2label.values())}")

    def predict(
        self,
        comment: str,
        return_probabilities: bool = False,
        confidence_threshold: Optional[float] = None
    ) -> Union[str, Dict, None]:
        """
        Predict the quality of a code comment.

        Args:
            comment: The code comment text
            return_probabilities: If True, return probabilities for all classes
            confidence_threshold: Optional minimum confidence; predictions below it come back as None

        Returns:
            If return_probabilities is False: the predicted label (str), or None for empty input or a prediction below the threshold
            If return_probabilities is True: dict with label, confidence, and per-class probabilities
        """
        if not comment or not comment.strip():
            logging.warning("Empty comment provided")
            if return_probabilities:
                return {'label': None, 'confidence': 0.0, 'probabilities': {}}
            return None

        # Tokenize the input
        inputs = self.tokenizer(
            comment,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True
        )

        # Move inputs to the model's device
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Forward pass without gradient tracking
        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        confidence = probabilities[0][predicted_class].item()
        predicted_label = self.id2label[predicted_class]

        # Build the per-class probability dict once; both branches below reuse it
        prob_dict = {
            self.id2label[i]: prob.item()
            for i, prob in enumerate(probabilities[0])
        }

        # Compare against None explicitly so a threshold of 0.0 is still honored
        if confidence_threshold is not None and confidence < confidence_threshold:
            if return_probabilities:
                return {
                    'label': None,
                    'confidence': confidence,
                    'probabilities': prob_dict,
                    'below_threshold': True
                }
            return None

        if return_probabilities:
            return {
                'label': predicted_label,
                'confidence': confidence,
                'probabilities': prob_dict
            }

        return predicted_label

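    # Illustrative return shapes for predict() above. The label names here are
    # hypothetical; the real ones come from the trained model's id2label config:
    #   clf.predict("does stuff")                            -> "low"
    #   clf.predict("does stuff", return_probabilities=True) ->
    #       {'label': 'low', 'confidence': 0.91,
    #        'probabilities': {'low': 0.91, 'medium': 0.07, 'high': 0.02}}
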
    def predict_batch(
        self,
        comments: List[str],
        batch_size: int = 32,
        return_probabilities: bool = False
    ) -> Union[List[str], List[Dict]]:
        """
        Predict quality for multiple comments with batching support.

        Args:
            comments: List of code comment texts
            batch_size: Batch size for processing
            return_probabilities: If True, return full probability dicts

        Returns:
            List of predicted labels or dicts with probabilities
        """
        if not comments:
            return []

        all_results = []

        # Process the comments in fixed-size batches
        for i in range(0, len(comments), batch_size):
            batch = comments[i:i + batch_size]

            # Tokenize the batch, padding to the longest sequence in it
            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                max_length=512,
                padding=True
            )

            # Move inputs to the model's device
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)
                probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predictions = torch.argmax(probabilities, dim=-1)

            if return_probabilities:
                for pred, probs in zip(predictions, probabilities):
                    prob_dict = {
                        self.id2label[k]: prob.item()
                        for k, prob in enumerate(probs)
                    }
                    all_results.append({
                        'label': self.id2label[pred.item()],
                        'confidence': probs[pred.item()].item(),
                        'probabilities': prob_dict
                    })
            else:
                all_results.extend(self.id2label[pred.item()] for pred in predictions)

        return all_results


def main():
    """Main inference function with example usage."""
    parser = argparse.ArgumentParser(description="Classify code comment quality")
    parser.add_argument(
        "--model-path",
        type=str,
        default="./results/final_model",
        help="Path to the trained model"
    )
    parser.add_argument(
        "--comment",
        type=str,
        help="Code comment to classify"
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=['cuda', 'cpu', 'auto'],
        default='auto',
        help="Device to run inference on"
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Use FP16 precision for faster inference (GPU only)"
    )
    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=None,
        help="Minimum confidence threshold (0.0-1.0)"
    )
    args = parser.parse_args()

    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Initialize the classifier
    device = None if args.device == 'auto' else args.device
    classifier = CommentQualityClassifier(
        args.model_path,
        device=device,
        use_fp16=args.fp16
    )

    # Fall back to example comments if none was provided
    if args.comment:
        comments = [args.comment]
    else:
        print("\nNo comment provided. Using example comments...\n")
        comments = [
            "This function calculates the Fibonacci sequence using dynamic programming to avoid redundant calculations. Time complexity: O(n), Space complexity: O(n)",
            "does stuff with numbers",
            "TODO: fix this later",
            "Calculates sum of two numbers",
            "This function adds two integers and returns the result. Parameters: a (int), b (int). Returns: int sum",
            "loop through array",
            "DEPRECATED: Use calculate_new() instead. This method will be removed in v2.0",
        ]

    print("=" * 80)
    print("Code Comment Quality Classification")
    print("=" * 80)

    for i, comment in enumerate(comments, 1):
        print(f"\n[{i}] Comment: {comment}")
        print("-" * 80)

        result = classifier.predict(
            comment,
            return_probabilities=True,
            confidence_threshold=args.confidence_threshold
        )

        if result['label'] is None:
            print("Predicted Quality: LOW CONFIDENCE (below threshold)")
        else:
            print(f"Predicted Quality: {result['label'].upper()}")
        print(f"Confidence: {result['confidence']:.2%}")

        print("\nAll Probabilities:")
        for label, prob in sorted(result['probabilities'].items(), key=lambda x: x[1], reverse=True):
            bar = "█" * int(prob * 50)
            print(f"  {label:10s}: {bar:50s} {prob:.2%}")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    main()
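For downstream use, the class can also be imported rather than run as a script. A short sketch, assuming the script's default model path and the method signatures shown above:

from inference import CommentQualityClassifier

# Load from the script's default location (adjust the path, or pass a Hub model ID)
clf = CommentQualityClassifier("./results/final_model")

# Single prediction: returns just the label string
label = clf.predict("TODO: fix this later")

# Batched prediction over many comments
labels = clf.predict_batch(
    ["does stuff with numbers", "Calculates sum of two numbers"],
    batch_size=32,
)

# Full detail, skipping anything the model is less than 70% confident about
detail = clf.predict(
    "loop through array",
    return_probabilities=True,
    confidence_threshold=0.7,
)

The CLI equivalent of the last call is: python inference.py --comment "loop through array" --confidence-threshold 0.7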