feat: add TrainerCallback to stream live training logs to UI
Browse files- Implemented GradioLoggerCallback to forward Hugging Face Trainer logs to Gradio
- Replaced pre-loop simulated logging with true per-step feedback
- UI now shows step-by-step progress without freezing or blocking
- train_abuse_model.py +18 -1
train_abuse_model.py
CHANGED
|
@@ -7,6 +7,7 @@ import os
|
|
| 7 |
import time
|
| 8 |
import gradio as gr # ✅ required for progress bar
|
| 9 |
from pathlib import Path
|
|
|
|
| 10 |
|
| 11 |
# Python standard + ML packages
|
| 12 |
import pandas as pd
|
|
@@ -23,6 +24,7 @@ from huggingface_hub import hf_hub_download
|
|
| 23 |
# Hugging Face transformers
|
| 24 |
import transformers
|
| 25 |
from transformers import (
|
|
|
|
| 26 |
AutoTokenizer,
|
| 27 |
DebertaV2Tokenizer,
|
| 28 |
BertTokenizer,
|
|
@@ -66,6 +68,15 @@ logger.info(f"Transformers version: {transformers.__version__}")
|
|
| 66 |
logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
|
| 67 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
|
| 70 |
def evaluate_model_with_thresholds(trainer, test_dataset):
|
| 71 |
"""Run full evaluation with automatic threshold tuning."""
|
|
@@ -191,6 +202,7 @@ train_texts, val_texts, train_labels, val_labels = train_test_split(
|
|
| 191 |
model_name = "microsoft/deberta-v3-base"
|
| 192 |
|
| 193 |
def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
|
|
| 194 |
if os.path.exists("saved_model/"):
|
| 195 |
yield "✅ Trained model found! Skipping training...\n"
|
| 196 |
for line in evaluate_saved_model():
|
|
@@ -239,7 +251,8 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
| 239 |
model=model,
|
| 240 |
args=training_args,
|
| 241 |
train_dataset=train_dataset,
|
| 242 |
-
eval_dataset=val_dataset
|
|
|
|
| 243 |
)
|
| 244 |
|
| 245 |
logger.info("Training started with %d samples", len(train_dataset))
|
|
@@ -262,6 +275,10 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
|
|
| 262 |
# Start training!
|
| 263 |
trainer.train()
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
progress(1.0)
|
| 266 |
yield "✅ Progress: 100%\n"
|
| 267 |
|
|
|
|
| 7 |
import time
|
| 8 |
import gradio as gr # ✅ required for progress bar
|
| 9 |
from pathlib import Path
|
| 10 |
+
import queue
|
| 11 |
|
| 12 |
# Python standard + ML packages
|
| 13 |
import pandas as pd
|
|
|
|
| 24 |
# Hugging Face transformers
|
| 25 |
import transformers
|
| 26 |
from transformers import (
|
| 27 |
+
TrainerCallback,
|
| 28 |
AutoTokenizer,
|
| 29 |
DebertaV2Tokenizer,
|
| 30 |
BertTokenizer,
|
|
|
|
| 68 |
logger.info("torch.cuda.is_available(): %s", torch.cuda.is_available())
|
| 69 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 70 |
|
| 71 |
+
class GradioLoggerCallback(TrainerCallback):
|
| 72 |
+
def __init__(self, gr_queue):
|
| 73 |
+
self.gr_queue = gr_queue
|
| 74 |
+
|
| 75 |
+
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 76 |
+
if logs:
|
| 77 |
+
msg = f"📊 Step {state.global_step}: {logs}"
|
| 78 |
+
logger.info(msg)
|
| 79 |
+
self.gr_queue.put(msg)
|
| 80 |
|
| 81 |
def evaluate_model_with_thresholds(trainer, test_dataset):
|
| 82 |
"""Run full evaluation with automatic threshold tuning."""
|
|
|
|
| 202 |
model_name = "microsoft/deberta-v3-base"
|
| 203 |
|
| 204 |
def run_training(progress=gr.Progress(track_tqdm=True)):
|
| 205 |
+
log_queue = queue.Queue()
|
| 206 |
if os.path.exists("saved_model/"):
|
| 207 |
yield "✅ Trained model found! Skipping training...\n"
|
| 208 |
for line in evaluate_saved_model():
|
|
|
|
| 251 |
model=model,
|
| 252 |
args=training_args,
|
| 253 |
train_dataset=train_dataset,
|
| 254 |
+
eval_dataset=val_dataset,
|
| 255 |
+
callbacks=[GradioLoggerCallback(log_queue)]
|
| 256 |
)
|
| 257 |
|
| 258 |
logger.info("Training started with %d samples", len(train_dataset))
|
|
|
|
| 275 |
# Start training!
|
| 276 |
trainer.train()
|
| 277 |
|
| 278 |
+
# Drain queue to UI
|
| 279 |
+
while not log_queue.empty():
|
| 280 |
+
yield log_queue.get()
|
| 281 |
+
|
| 282 |
progress(1.0)
|
| 283 |
yield "✅ Progress: 100%\n"
|
| 284 |
|