Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -52,14 +52,16 @@ class ASTDashboard:
|
|
| 52 |
else:
|
| 53 |
model = timm.create_model(model_name, pretrained=False, num_classes=10)
|
| 54 |
|
| 55 |
-
# AST Config
|
| 56 |
config = ASTConfig(
|
| 57 |
target_activation_rate=activation_rate,
|
| 58 |
-
use_amp=
|
|
|
|
| 59 |
)
|
| 60 |
|
| 61 |
# Start training
|
| 62 |
progress(0.2, desc="Starting training...")
|
|
|
|
| 63 |
trainer = AdaptiveSparseTrainer(model, train_loader, val_loader, config)
|
| 64 |
|
| 65 |
self.training_history = []
|
|
@@ -109,8 +111,8 @@ class ASTDashboard:
|
|
| 109 |
root='./data', train=False, download=True, transform=transform
|
| 110 |
)
|
| 111 |
|
| 112 |
-
train_loader = DataLoader(train_dataset, batch_size=
|
| 113 |
-
val_loader = DataLoader(val_dataset, batch_size=
|
| 114 |
|
| 115 |
return train_loader, val_loader
|
| 116 |
|
|
@@ -322,10 +324,10 @@ def create_demo():
|
|
| 322 |
gr.Markdown(f"**Energy Savings:** ~{(1-0.35)*100:.0f}%")
|
| 323 |
|
| 324 |
epochs = gr.Slider(
|
| 325 |
-
minimum=
|
| 326 |
-
maximum=
|
| 327 |
-
value=
|
| 328 |
-
step=
|
| 329 |
label="Training Epochs"
|
| 330 |
)
|
| 331 |
|
|
|
|
| 52 |
else:
|
| 53 |
model = timm.create_model(model_name, pretrained=False, num_classes=10)
|
| 54 |
|
| 55 |
+
# AST Config (CPU mode for HuggingFace free tier)
|
| 56 |
config = ASTConfig(
|
| 57 |
target_activation_rate=activation_rate,
|
| 58 |
+
use_amp=False, # Disable AMP on CPU
|
| 59 |
+
device='cpu'
|
| 60 |
)
|
| 61 |
|
| 62 |
# Start training
|
| 63 |
progress(0.2, desc="Starting training...")
|
| 64 |
+
model = model.to('cpu')
|
| 65 |
trainer = AdaptiveSparseTrainer(model, train_loader, val_loader, config)
|
| 66 |
|
| 67 |
self.training_history = []
|
|
|
|
| 111 |
root='./data', train=False, download=True, transform=transform
|
| 112 |
)
|
| 113 |
|
| 114 |
+
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
|
| 115 |
+
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)
|
| 116 |
|
| 117 |
return train_loader, val_loader
|
| 118 |
|
|
|
|
| 324 |
gr.Markdown(f"**Energy Savings:** ~{(1-0.35)*100:.0f}%")
|
| 325 |
|
| 326 |
epochs = gr.Slider(
|
| 327 |
+
minimum=5,
|
| 328 |
+
maximum=50,
|
| 329 |
+
value=10,
|
| 330 |
+
step=5,
|
| 331 |
label="Training Epochs"
|
| 332 |
)
|
| 333 |
|