mgbam commited on
Commit
8371eb2
Β·
verified Β·
1 Parent(s): 7569002

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +383 -0
app.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Space: AST Training Dashboard
3
+ Live monitoring and model card generation
4
+ """
5
+
6
+ import gradio as gr
7
+ import json
8
+ import time
9
+ from pathlib import Path
10
+ import plotly.graph_objects as go
11
+ from plotly.subplots import make_subplots
12
+
13
+ try:
14
+ from adaptive_sparse_training import AdaptiveSparseTrainer, ASTConfig
15
+ import torch
16
+ import torchvision
17
+ import timm
18
+ HAS_DEPS = True
19
+ except ImportError:
20
+ HAS_DEPS = False
21
+
22
+
23
+ class ASTDashboard:
24
+ """Real-time AST training dashboard"""
25
+
26
+ def __init__(self):
27
+ self.active_training = None
28
+ self.training_history = []
29
+
30
+ def start_training(
31
+ self,
32
+ model_name: str,
33
+ dataset: str,
34
+ activation_rate: float,
35
+ epochs: int,
36
+ progress=gr.Progress()
37
+ ):
38
+ """Start AST training with live updates"""
39
+
40
+ if not HAS_DEPS:
41
+ return "❌ Dependencies not installed", None, None
42
+
43
+ progress(0, desc="Initializing...")
44
+
45
+ # Load dataset (CIFAR-10 for demo)
46
+ train_loader, val_loader = self._get_dataloaders(dataset)
47
+
48
+ # Create model
49
+ progress(0.1, desc="Creating model...")
50
+ if model_name == "resnet18":
51
+ model = torchvision.models.resnet18(num_classes=10)
52
+ else:
53
+ model = timm.create_model(model_name, pretrained=False, num_classes=10)
54
+
55
+ # AST Config
56
+ config = ASTConfig(
57
+ target_activation_rate=activation_rate,
58
+ entropy_weight=1.0,
59
+ use_mixed_precision=True,
60
+ )
61
+
62
+ # Start training
63
+ progress(0.2, desc="Starting training...")
64
+ trainer = AdaptiveSparseTrainer(model, train_loader, val_loader, config)
65
+
66
+ self.training_history = []
67
+
68
+ for epoch in range(epochs):
69
+ progress((epoch + 1) / epochs, desc=f"Epoch {epoch+1}/{epochs}")
70
+
71
+ # Train one epoch
72
+ epoch_stats = trainer.train_epoch(epoch)
73
+ val_acc = trainer.evaluate()
74
+
75
+ # Store history
76
+ self.training_history.append({
77
+ "epoch": epoch + 1,
78
+ "val_acc": val_acc,
79
+ "activation_rate": epoch_stats.get("activation_rate", activation_rate),
80
+ "threshold": epoch_stats.get("threshold", 1.0),
81
+ })
82
+
83
+ # Update dashboard
84
+ if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
85
+ status = self._format_status(epoch + 1, epochs, val_acc, activation_rate)
86
+ plot = self._create_plot()
87
+ yield status, plot, None
88
+
89
+ # Generate model card
90
+ model_card = self._generate_model_card(model_name, activation_rate)
91
+
92
+ final_status = f"βœ… Training complete! Best accuracy: {max([h['val_acc'] for h in self.training_history]):.2%}"
93
+
94
+ yield final_status, self._create_plot(), model_card
95
+
96
+ def _get_dataloaders(self, dataset: str):
97
+ """Get data loaders (CIFAR-10 demo)"""
98
+ import torchvision.transforms as transforms
99
+ from torch.utils.data import DataLoader
100
+
101
+ transform = transforms.Compose([
102
+ transforms.ToTensor(),
103
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
104
+ ])
105
+
106
+ train_dataset = torchvision.datasets.CIFAR10(
107
+ root='./data', train=True, download=True, transform=transform
108
+ )
109
+ val_dataset = torchvision.datasets.CIFAR10(
110
+ root='./data', train=False, download=True, transform=transform
111
+ )
112
+
113
+ train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
114
+ val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)
115
+
116
+ return train_loader, val_loader
117
+
118
+ def _format_status(self, epoch: int, total_epochs: int, accuracy: float, activation_rate: float):
119
+ """Format training status"""
120
+ return f"""
121
+ ### πŸš€ Training in Progress
122
+
123
+ **Epoch:** {epoch}/{total_epochs}
124
+ **Accuracy:** {accuracy:.2%}
125
+ **Activation Rate:** {activation_rate:.1%}
126
+ **Energy Savings:** ~{(1-activation_rate)*100:.0f}%
127
+
128
+ *Updating every 5 epochs...*
129
+ """
130
+
131
+ def _create_plot(self):
132
+ """Create live training plot"""
133
+ if not self.training_history:
134
+ return None
135
+
136
+ fig = make_subplots(
137
+ rows=2, cols=2,
138
+ subplot_titles=("Validation Accuracy", "Activation Rate", "Threshold Evolution", "Energy Savings"),
139
+ )
140
+
141
+ epochs = [h["epoch"] for h in self.training_history]
142
+ accuracies = [h["val_acc"] * 100 for h in self.training_history]
143
+ activation_rates = [h["activation_rate"] * 100 for h in self.training_history]
144
+ thresholds = [h["threshold"] for h in self.training_history]
145
+ savings = [(1 - h["activation_rate"]) * 100 for h in self.training_history]
146
+
147
+ # Accuracy plot
148
+ fig.add_trace(
149
+ go.Scatter(x=epochs, y=accuracies, mode='lines+markers', name='Val Accuracy',
150
+ line=dict(color='#3498db', width=3)),
151
+ row=1, col=1
152
+ )
153
+
154
+ # Activation rate plot
155
+ fig.add_trace(
156
+ go.Scatter(x=epochs, y=activation_rates, mode='lines+markers', name='Activation Rate',
157
+ line=dict(color='#e74c3c', width=3)),
158
+ row=1, col=2
159
+ )
160
+
161
+ # Threshold plot
162
+ fig.add_trace(
163
+ go.Scatter(x=epochs, y=thresholds, mode='lines+markers', name='Threshold',
164
+ line=dict(color='#f39c12', width=3)),
165
+ row=2, col=1
166
+ )
167
+
168
+ # Energy savings plot
169
+ fig.add_trace(
170
+ go.Scatter(x=epochs, y=savings, mode='lines+markers', name='Energy Savings',
171
+ line=dict(color='#27ae60', width=3), fill='tozeroy'),
172
+ row=2, col=2
173
+ )
174
+
175
+ fig.update_xaxes(title_text="Epoch")
176
+ fig.update_yaxes(title_text="Accuracy (%)", row=1, col=1)
177
+ fig.update_yaxes(title_text="Activation (%)", row=1, col=2)
178
+ fig.update_yaxes(title_text="Threshold", row=2, col=1)
179
+ fig.update_yaxes(title_text="Savings (%)", row=2, col=2)
180
+
181
+ fig.update_layout(height=600, showlegend=False)
182
+
183
+ return fig
184
+
185
+ def _generate_model_card(self, model_name: str, activation_rate: float):
186
+ """Generate HuggingFace model card"""
187
+
188
+ best_acc = max([h["val_acc"] for h in self.training_history])
189
+ energy_savings = (1 - activation_rate) * 100
190
+
191
+ return f"""---
192
+ tags:
193
+ - adaptive-sparse-training
194
+ - energy-efficient
195
+ - sustainability
196
+ metrics:
197
+ - accuracy
198
+ - energy_savings
199
+ ---
200
+
201
+ # {model_name} (AST-Trained)
202
+
203
+ **Trained with {energy_savings:.0f}% less energy than standard training** ⚑
204
+
205
+ ## Model Details
206
+ - **Architecture:** {model_name}
207
+ - **Dataset:** CIFAR-10
208
+ - **Training Method:** Adaptive Sparse Training (AST)
209
+ - **Target Activation Rate:** {activation_rate:.0%}
210
+
211
+ ## Performance
212
+ - **Accuracy:** {best_acc:.2%}
213
+ - **Energy Savings:** {energy_savings:.0f}%
214
+ - **Training Epochs:** {len(self.training_history)}
215
+
216
+ ## Sustainability Report
217
+ This model was trained using Adaptive Sparse Training, which dynamically selects
218
+ the most important training samples. This resulted in:
219
+
220
+ - ⚑ **{energy_savings:.0f}% energy savings** compared to standard training
221
+ - 🌍 **Lower carbon footprint**
222
+ - ⏱️ **Faster training time**
223
+ - 🎯 **Maintained accuracy** (minimal degradation)
224
+
225
+ ## How to Use
226
+
227
+ ```python
228
+ import torch
229
+ from torchvision import models
230
+
231
+ # Load model
232
+ model = models.{model_name}(num_classes=10)
233
+ model.load_state_dict(torch.load("pytorch_model.bin"))
234
+ model.eval()
235
+
236
+ # Inference
237
+ # ... (your inference code)
238
+ ```
239
+
240
+ ## Training Details
241
+
242
+ **AST Configuration:**
243
+ - Target Activation Rate: {activation_rate:.0%}
244
+ - Entropy Weight: 1.0
245
+ - PI Controller: Enabled
246
+ - Mixed Precision: Enabled
247
+
248
+ ## Reproducing This Model
249
+
250
+ ```bash
251
+ pip install adaptive-sparse-training
252
+
253
+ python -c "
254
+ from adaptive_sparse_training import AdaptiveSparseTrainer, ASTConfig
255
+ config = ASTConfig(target_activation_rate={activation_rate})
256
+ # ... (full training code)
257
+ "
258
+ ```
259
+
260
+ ## Citation
261
+
262
+ If you use this model or AST, please cite:
263
+
264
+ ```bibtex
265
+ @software{{adaptive_sparse_training,
266
+ title={{Adaptive Sparse Training}},
267
+ author={{Idiakhoa, Oluwafemi}},
268
+ year={{2024}},
269
+ url={{https://github.com/oluwafemidiakhoa/adaptive-sparse-training}}
270
+ }}
271
+ ```
272
+
273
+ ## Acknowledgments
274
+
275
+ Trained using the `adaptive-sparse-training` package. Special thanks to the PyTorch and HuggingFace communities.
276
+
277
+ ---
278
+
279
+ *This model card was auto-generated by the AST Training Dashboard.*
280
+ """
281
+
282
+
283
+ # Initialize dashboard
284
+ dashboard = ASTDashboard()
285
+
286
+
287
+ # Gradio Interface
288
+ def create_demo():
289
+ """Create Gradio demo interface"""
290
+
291
+ with gr.Blocks(title="AST Training Dashboard", theme=gr.themes.Soft()) as demo:
292
+ gr.Markdown("""
293
+ # ⚑ Adaptive Sparse Training Dashboard
294
+
295
+ Train models with **60-70% less energy** while maintaining accuracy!
296
+
297
+ This demo trains a model on CIFAR-10 using AST and generates a HuggingFace model card.
298
+ """)
299
+
300
+ with gr.Row():
301
+ with gr.Column(scale=1):
302
+ gr.Markdown("### βš™οΈ Configuration")
303
+
304
+ model_name = gr.Dropdown(
305
+ choices=["resnet18", "efficientnet_b0", "mobilenetv3_small_100"],
306
+ value="resnet18",
307
+ label="Model Architecture"
308
+ )
309
+
310
+ dataset = gr.Dropdown(
311
+ choices=["cifar10"],
312
+ value="cifar10",
313
+ label="Dataset"
314
+ )
315
+
316
+ activation_rate = gr.Slider(
317
+ minimum=0.2,
318
+ maximum=0.8,
319
+ value=0.35,
320
+ step=0.05,
321
+ label="Target Activation Rate (lower = more savings)"
322
+ )
323
+
324
+ gr.Markdown(f"**Energy Savings:** ~{(1-0.35)*100:.0f}%")
325
+
326
+ epochs = gr.Slider(
327
+ minimum=10,
328
+ maximum=100,
329
+ value=30,
330
+ step=10,
331
+ label="Training Epochs"
332
+ )
333
+
334
+ train_btn = gr.Button("πŸš€ Start Training", variant="primary", size="lg")
335
+
336
+ with gr.Column(scale=2):
337
+ gr.Markdown("### πŸ“Š Live Training Metrics")
338
+
339
+ status = gr.Markdown("*Ready to train...*")
340
+ plot = gr.Plot()
341
+
342
+ with gr.Row():
343
+ with gr.Column():
344
+ gr.Markdown("### πŸ“ Generated Model Card")
345
+ model_card = gr.Textbox(
346
+ label="HuggingFace Model Card (Markdown)",
347
+ lines=20,
348
+ max_lines=30,
349
+ )
350
+
351
+ gr.Markdown("""
352
+ **Next Steps:**
353
+ 1. Copy the model card above
354
+ 2. Create a new model on [HuggingFace Hub](https://huggingface.co/new)
355
+ 3. Paste the model card into `README.md`
356
+ 4. Upload your trained model weights
357
+ """)
358
+
359
+ # Training logic
360
+ train_btn.click(
361
+ fn=dashboard.start_training,
362
+ inputs=[model_name, dataset, activation_rate, epochs],
363
+ outputs=[status, plot, model_card],
364
+ )
365
+
366
+ gr.Markdown("""
367
+ ---
368
+
369
+ ## πŸ“š Learn More
370
+
371
+ - πŸ“¦ [PyPI Package](https://pypi.org/project/adaptive-sparse-training/)
372
+ - πŸ™ [GitHub Repo](https://github.com/oluwafemidiakhoa/adaptive-sparse-training)
373
+ - πŸ“– [Documentation](https://github.com/oluwafemidiakhoa/adaptive-sparse-training#readme)
374
+
375
+ **Made with ❀️ using Adaptive Sparse Training**
376
+ """)
377
+
378
+ return demo
379
+
380
+
381
+ if __name__ == "__main__":
382
+ demo = create_demo()
383
+ demo.launch()