Spaces: Sleeping
Try again on Gradio progress-update streaming
Browse files
- app.py +2 -2
- evaluation.py +12 -3
app.py
CHANGED
|
@@ -108,7 +108,7 @@ def run_full_evaluation_gradio():
|
|
| 108 |
for update in evaluate(model, test_dataloader_full, device):
|
| 109 |
if isinstance(update, dict):
|
| 110 |
# This is the final results dictionary
|
| 111 |
-
results_str = "--- Full Evaluation Results ---\n"
|
| 112 |
for key, value in update.items():
|
| 113 |
if isinstance(value, float):
|
| 114 |
results_str += f"{key.capitalize()}: {value:.4f}\n"
|
|
@@ -120,7 +120,7 @@ def run_full_evaluation_gradio():
|
|
| 120 |
break # Stop after getting the results dict
|
| 121 |
else:
|
| 122 |
# This is a progress string
|
| 123 |
-
yield update
|
| 124 |
|
| 125 |
# Ensure the final formatted results string is yielded if not already (e.g., if loop broke early)
|
| 126 |
# However, the logic above should yield it before breaking.
|
|
|
|
| 108 |
for update in evaluate(model, test_dataloader_full, device):
|
| 109 |
if isinstance(update, dict):
|
| 110 |
# This is the final results dictionary
|
| 111 |
+
results_str = "\n--- Full Evaluation Results ---\n" # Start with a newline
|
| 112 |
for key, value in update.items():
|
| 113 |
if isinstance(value, float):
|
| 114 |
results_str += f"{key.capitalize()}: {value:.4f}\n"
|
|
|
|
| 120 |
break # Stop after getting the results dict
|
| 121 |
else:
|
| 122 |
# This is a progress string
|
| 123 |
+
yield str(update) + "\n" # Append newline to each progress string
|
| 124 |
|
| 125 |
# Ensure the final formatted results string is yielded if not already (e.g., if loop broke early)
|
| 126 |
# However, the logic above should yield it before breaking.
|
evaluation.py
CHANGED
|
@@ -13,8 +13,6 @@ def evaluate(model, dataloader, device):
|
|
| 13 |
num_batches = len(dataloader)
|
| 14 |
processed_batches = 0
|
| 15 |
|
| 16 |
-
yield "Starting evaluation..."
|
| 17 |
-
|
| 18 |
with torch.no_grad():
|
| 19 |
for batch in dataloader: # dataloader here should not be pre-wrapped with tqdm by the caller if we yield progress
|
| 20 |
processed_batches += 1
|
|
@@ -54,8 +52,19 @@ def evaluate(model, dataloader, device):
|
|
| 54 |
all_preds.extend(preds.cpu().numpy())
|
| 55 |
|
| 56 |
all_labels.extend(labels.cpu().numpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# Yield progress update
|
| 58 |
-
|
|
|
|
| 59 |
yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"
|
| 60 |
|
| 61 |
avg_loss = total_loss / num_batches
|
|
|
|
| 13 |
num_batches = len(dataloader)
|
| 14 |
processed_batches = 0
|
| 15 |
|
|
|
|
|
|
|
| 16 |
with torch.no_grad():
|
| 17 |
for batch in dataloader: # dataloader here should not be pre-wrapped with tqdm by the caller if we yield progress
|
| 18 |
processed_batches += 1
|
|
|
|
| 52 |
all_preds.extend(preds.cpu().numpy())
|
| 53 |
|
| 54 |
all_labels.extend(labels.cpu().numpy())
|
| 55 |
+
|
| 56 |
+
# Populate probabilities for AUC calculation
|
| 57 |
+
if logits.shape[1] > 1:
|
| 58 |
+
# Multi-class or multi-label, assuming positive class is at index 1 for binary-like AUC
|
| 59 |
+
probs_for_auc = torch.softmax(logits, dim=1)[:, 1]
|
| 60 |
+
else:
|
| 61 |
+
# Binary classification with a single logit output
|
| 62 |
+
probs_for_auc = torch.sigmoid(logits).squeeze()
|
| 63 |
+
all_probs_for_auc.extend(probs_for_auc.cpu().numpy())
|
| 64 |
+
|
| 65 |
# Yield progress update
|
| 66 |
+
progress_update_frequency = max(1, num_batches // 20) # Ensure at least 1 to avoid modulo zero
|
| 67 |
+
if processed_batches % progress_update_frequency == 0 or processed_batches == num_batches: # Update roughly 20 times + final
|
| 68 |
yield f"Processed {processed_batches}/{num_batches} batches ({processed_batches/num_batches*100:.2f}%)"
|
| 69 |
|
| 70 |
avg_loss = total_loss / num_batches
|