Snaseem2026 committed on
Commit
e4495a9
·
verified ·
1 Parent(s): 3ab633a

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +211 -0
train.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main training script for Code Comment Quality Classifier
3
+ """
4
+ import os
5
+ import argparse
6
+ import logging
7
+ from pathlib import Path
8
+ from transformers import (
9
+ Trainer,
10
+ TrainingArguments,
11
+ EarlyStoppingCallback
12
+ )
13
+ from src import (
14
+ load_config,
15
+ prepare_datasets_for_training,
16
+ create_model,
17
+ get_model_size,
18
+ get_trainable_params,
19
+ compute_metrics_factory
20
+ )
21
+
22
+
23
def setup_logging(config: dict) -> None:
    """Configure root logging to write to both a log file and the console.

    Reads the optional ``logging`` section of ``config``:

    * ``level`` -- a level name such as ``"INFO"`` (case-insensitive) or a
      numeric level. Defaults to ``"INFO"``.
    * ``log_file`` -- path of the log file. Defaults to
      ``./results/training.log``; its parent directory is created if needed.

    Args:
        config: Parsed configuration mapping.

    Raises:
        ValueError: If ``level`` does not name a known logging level.
    """
    # ``or {}`` also covers a section that is present but empty (None).
    log_config = config.get('logging') or {}
    log_file = log_config.get('log_file', './results/training.log')

    # Resolve the level before touching the filesystem so that a bad value
    # fails fast without leaving a half-created log directory behind.
    raw_level = log_config.get('level', 'INFO')
    if isinstance(raw_level, int) and not isinstance(raw_level, bool):
        log_level = raw_level
    else:
        # getLevelName() maps a known name to its number and returns a
        # "Level X" string for anything else. A plain getattr(logging, name)
        # would instead hand back e.g. the *function* logging.info for "info".
        log_level = logging.getLevelName(str(raw_level).strip().upper())
        if not isinstance(log_level, int):
            raise ValueError(
                f"Invalid logging level {raw_level!r}; expected one of "
                "DEBUG, INFO, WARNING, ERROR, CRITICAL or a numeric level."
            )

    # Create log directory if needed
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Configure logging.
    # NOTE(review): basicConfig() does nothing when the root logger already
    # has handlers -- confirm nothing imported earlier configures logging.
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(),
        ],
    )
43
+
44
+
45
def _build_training_arguments(training_cfg: dict, output_dir: str) -> "TrainingArguments":
    """Build ``TrainingArguments`` from the ``training`` section of the config.

    Keys read with ``[]`` are required and raise ``KeyError`` when missing;
    keys read with ``.get`` are optional and use the fallback shown here.
    """
    kwargs = dict(
        output_dir=output_dir,
        num_train_epochs=training_cfg['num_train_epochs'],
        per_device_train_batch_size=training_cfg['per_device_train_batch_size'],
        per_device_eval_batch_size=training_cfg['per_device_eval_batch_size'],
        gradient_accumulation_steps=training_cfg.get('gradient_accumulation_steps', 1),
        learning_rate=training_cfg['learning_rate'],
        lr_scheduler_type=training_cfg.get('lr_scheduler_type', 'linear'),
        weight_decay=training_cfg['weight_decay'],
        logging_dir=os.path.join(output_dir, 'logs'),
        logging_steps=training_cfg['logging_steps'],
        eval_steps=training_cfg['eval_steps'],
        save_steps=training_cfg['save_steps'],
        save_total_limit=training_cfg.get('save_total_limit', 3),
        # The config file uses the older key name 'evaluation_strategy'.
        eval_strategy=training_cfg['evaluation_strategy'],
        save_strategy=training_cfg['save_strategy'],
        load_best_model_at_end=training_cfg['load_best_model_at_end'],
        metric_for_best_model=training_cfg['metric_for_best_model'],
        greater_is_better=training_cfg.get('greater_is_better', True),
        seed=training_cfg['seed'],
        fp16=training_cfg.get('fp16', False),
        dataloader_num_workers=training_cfg.get('dataloader_num_workers', 4),
        dataloader_pin_memory=training_cfg.get('dataloader_pin_memory', True),
        remove_unused_columns=training_cfg.get('remove_unused_columns', True),
        report_to=training_cfg.get('report_to', ['none']),
        push_to_hub=False,
    )

    # Forward the warmup settings only when they are actually configured, so
    # that an absent key leaves the library default in place instead of
    # overriding it with an explicit None.
    for key in ('warmup_steps', 'warmup_ratio'):
        value = training_cfg.get(key)
        if value is not None:
            kwargs[key] = value

    return TrainingArguments(**kwargs)


def main(config_path: str = "config.yaml"):
    """
    Main training function.

    Runs the whole pipeline: load and validate the configuration, set up
    logging, prepare the datasets, create the model, train it, save the
    final model and tokenizer, and evaluate on the test split.

    Args:
        config_path: Path to configuration file

    Raises:
        ValueError: If ``validate_config`` reports any configuration errors.
    """
    banner = "=" * 60
    print(banner)
    print("Code Comment Quality Classifier - Training")
    print(banner)

    # Load configuration
    print("\n[1/7] Loading configuration...")
    config = load_config(config_path)
    print(f"✓ Configuration loaded from {config_path}")

    # Validate configuration
    from src.validation import validate_config
    config_errors = validate_config(config)
    if config_errors:
        print("\n✗ Configuration validation errors:")
        for error in config_errors:
            print(f" - {error}")
        raise ValueError("Invalid configuration. Please fix the errors above.")

    # Setup logging
    setup_logging(config)
    logging.info("Starting training process")

    # Prepare datasets
    print("\n[2/7] Preparing datasets...")
    tokenized_datasets, label2id, id2label, tokenizer = prepare_datasets_for_training(config_path)
    n_train = len(tokenized_datasets['train'])
    n_val = len(tokenized_datasets['validation'])
    n_test = len(tokenized_datasets['test'])
    print(f"✓ Train samples: {n_train}")
    print(f"✓ Validation samples: {n_val}")
    print(f"✓ Test samples: {n_test}")
    logging.info(f"Dataset sizes - Train: {n_train}, Val: {n_val}, Test: {n_test}")

    # Create model
    print("\n[3/7] Loading model...")
    model_cfg = config['model']
    model = create_model(
        model_name=model_cfg['name'],
        num_labels=model_cfg['num_labels'],
        label2id=label2id,
        id2label=id2label,
        dropout=model_cfg.get('dropout'),
    )
    model_size = get_model_size(model)
    params_info = get_trainable_params(model)
    print(f"✓ Model: {model_cfg['name']}")
    print(f"✓ Total Parameters: {model_size:.2f}M")
    print(f"✓ Trainable Parameters: {params_info['trainable'] / 1e6:.2f}M")
    logging.info(f"Model: {model_cfg['name']}, Size: {model_size:.2f}M parameters")

    # Setup training arguments
    print("\n[4/7] Setting up training...")
    training_cfg = config['training']
    output_dir = training_cfg['output_dir']
    os.makedirs(output_dir, exist_ok=True)
    training_args = _build_training_arguments(training_cfg, output_dir)

    # Create compute_metrics function with label mapping
    compute_metrics_fn = compute_metrics_factory(id2label)

    # Setup callbacks: early stopping is enabled only when a non-zero
    # patience is configured.
    callbacks = []
    patience = training_cfg.get('early_stopping_patience')
    if patience:
        callbacks.append(
            EarlyStoppingCallback(
                early_stopping_patience=patience,
                early_stopping_threshold=training_cfg.get('early_stopping_threshold', 0.0),
            )
        )
        logging.info(f"Early stopping enabled with patience={patience}")

    # Create trainer
    # NOTE(review): recent transformers releases rename Trainer's `tokenizer`
    # argument to `processing_class` -- confirm against the pinned version.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['validation'],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
        callbacks=callbacks,
    )

    print("✓ Trainer initialized")
    logging.info("Trainer initialized with all configurations")

    # Train model
    print("\n[5/7] Training model...")
    print("-" * 60)
    logging.info("Starting training")
    train_result = trainer.train()
    logging.info(f"Training completed. Train loss: {train_result.training_loss:.4f}")

    # Save final model together with its tokenizer in one directory
    print("\n[6/7] Saving model...")
    final_model_path = os.path.join(output_dir, 'final_model')
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✓ Model saved to {final_model_path}")
    logging.info(f"Model saved to {final_model_path}")

    # Evaluate on test set
    print("\n[7/7] Evaluating on test set...")
    print(banner)
    print("Final Evaluation on Test Set")
    print(banner)
    test_results = trainer.evaluate(tokenized_datasets['test'], metric_key_prefix='test')

    print("\nTest Results:")
    for key, value in sorted(test_results.items()):
        # Only float-valued entries are printed; other entries are skipped.
        if isinstance(value, float):
            print(f" {key}: {value:.4f}")
    logging.info("Test evaluation completed")

    print("\n" + banner)
    print("Training Complete! 🎉")
    print(banner)
    print(f"\nModel location: {final_model_path}")
    print("\nNext steps:")
    print("1. Run evaluation: python scripts/evaluate.py")
    print("2. Test inference: python inference.py")
    print("3. Upload to Hub: python scripts/upload_to_hub.py")
199
+
200
+
201
if __name__ == "__main__":
    # Command-line entry point: the only option is the path to the
    # configuration file, which is handed straight to main().
    cli = argparse.ArgumentParser(description="Train Code Comment Quality Classifier")
    cli.add_argument("--config", type=str, default="config.yaml", help="Path to configuration file")
    main(cli.parse_args().config)