mgbam commited on
Commit
dba4dcb
·
verified ·
1 Parent(s): aad72ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -52,14 +52,16 @@ class ASTDashboard:
52
  else:
53
  model = timm.create_model(model_name, pretrained=False, num_classes=10)
54
 
55
- # AST Config
56
  config = ASTConfig(
57
  target_activation_rate=activation_rate,
58
- use_amp=True,
 
59
  )
60
 
61
  # Start training
62
  progress(0.2, desc="Starting training...")
 
63
  trainer = AdaptiveSparseTrainer(model, train_loader, val_loader, config)
64
 
65
  self.training_history = []
@@ -109,8 +111,8 @@ class ASTDashboard:
109
  root='./data', train=False, download=True, transform=transform
110
  )
111
 
112
- train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
113
- val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)
114
 
115
  return train_loader, val_loader
116
 
@@ -322,10 +324,10 @@ def create_demo():
322
  gr.Markdown(f"**Energy Savings:** ~{(1-0.35)*100:.0f}%")
323
 
324
  epochs = gr.Slider(
325
- minimum=10,
326
- maximum=100,
327
- value=30,
328
- step=10,
329
  label="Training Epochs"
330
  )
331
 
 
52
  else:
53
  model = timm.create_model(model_name, pretrained=False, num_classes=10)
54
 
55
+ # AST Config (CPU mode for HuggingFace free tier)
56
  config = ASTConfig(
57
  target_activation_rate=activation_rate,
58
+ use_amp=False, # Disable AMP on CPU
59
+ device='cpu'
60
  )
61
 
62
  # Start training
63
  progress(0.2, desc="Starting training...")
64
+ model = model.to('cpu')
65
  trainer = AdaptiveSparseTrainer(model, train_loader, val_loader, config)
66
 
67
  self.training_history = []
 
111
  root='./data', train=False, download=True, transform=transform
112
  )
113
 
114
+ train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
115
+ val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=0)
116
 
117
  return train_loader, val_loader
118
 
 
324
  gr.Markdown(f"**Energy Savings:** ~{(1-0.35)*100:.0f}%")
325
 
326
  epochs = gr.Slider(
327
+ minimum=5,
328
+ maximum=50,
329
+ value=10,
330
+ step=5,
331
  label="Training Epochs"
332
  )
333