mckell commited on
Commit
9a3535b
·
verified ·
1 Parent(s): cc18481

Upload app.py

Browse files

Moved zero decorators to root level app.py load.

Files changed (1) hide show
  1. app.py +62 -0
app.py CHANGED
@@ -17,6 +17,8 @@ Environment variables:
17
  import os
18
  from pathlib import Path
19
 
 
 
20
  # Data source configuration
21
  DATA_REPO_ID = "mckell/diffviews_demo_data"
22
  CHECKPOINT_URLS = {
@@ -253,6 +255,60 @@ def get_device() -> str:
253
  return "cpu"
254
 
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  def main():
257
  """Main entry point for HF Spaces."""
258
  # Configuration from environment
@@ -288,6 +344,12 @@ def main():
288
  PLOTLY_HANDLER_JS,
289
  )
290
 
 
 
 
 
 
 
291
  print("\nInitializing visualizer...")
292
  visualizer = GradioVisualizer(
293
  data_dir=data_dir,
 
17
  import os
18
  from pathlib import Path
19
 
20
+ import spaces
21
+
22
  # Data source configuration
23
  DATA_REPO_ID = "mckell/diffviews_demo_data"
24
  CHECKPOINT_URLS = {
 
255
  return "cpu"
256
 
257
 
258
+ @spaces.GPU(duration=120)
259
+ def generate_on_gpu(
260
+ visualizer, model_name, all_neighbors, class_label,
261
+ n_steps, m_steps, s_max, s_min, guidance, noise_mode,
262
+ extract_layers, can_project
263
+ ):
264
+ """Run masked generation on GPU. Must live in app_file for ZeroGPU detection."""
265
+ from diffviews.core.masking import ActivationMasker
266
+ from diffviews.core.generator import generate_with_mask_multistep
267
+
268
+ with visualizer._generation_lock:
269
+ adapter = visualizer.load_adapter(model_name)
270
+ if adapter is None:
271
+ return None
272
+
273
+ activation_dict = visualizer.prepare_activation_dict(model_name, all_neighbors)
274
+ if activation_dict is None:
275
+ return None
276
+
277
+ masker = ActivationMasker(adapter)
278
+ for layer_name, activation in activation_dict.items():
279
+ masker.set_mask(layer_name, activation)
280
+ masker.register_hooks(list(activation_dict.keys()))
281
+
282
+ try:
283
+ result = generate_with_mask_multistep(
284
+ adapter,
285
+ masker,
286
+ class_label=class_label,
287
+ num_steps=int(n_steps),
288
+ mask_steps=int(m_steps),
289
+ sigma_max=float(s_max),
290
+ sigma_min=float(s_min),
291
+ guidance_scale=float(guidance),
292
+ noise_mode=(noise_mode or "stochastic noise").replace(" noise", ""),
293
+ num_samples=1,
294
+ device=visualizer.device,
295
+ extract_layers=extract_layers if can_project else None,
296
+ return_trajectory=can_project,
297
+ return_intermediates=True,
298
+ return_noised_inputs=True,
299
+ )
300
+ finally:
301
+ masker.remove_hooks()
302
+
303
+ return result
304
+
305
+
306
+ @spaces.GPU(duration=180)
307
+ def extract_layer_on_gpu(visualizer, model_name, layer_name, batch_size=32):
308
+ """Extract layer activations on GPU. Must live in app_file for ZeroGPU detection."""
309
+ return visualizer.extract_layer_activations(model_name, layer_name, batch_size)
310
+
311
+
312
  def main():
313
  """Main entry point for HF Spaces."""
314
  # Configuration from environment
 
344
  PLOTLY_HANDLER_JS,
345
  )
346
 
347
+ # Inject ZeroGPU-decorated functions into visualization module
348
+ # so Gradio callbacks use the versions codefind can detect
349
+ import diffviews.visualization.app as viz_mod
350
+ viz_mod._generate_on_gpu = generate_on_gpu
351
+ viz_mod._extract_layer_on_gpu = extract_layer_on_gpu
352
+
353
  print("\nInitializing visualizer...")
354
  visualizer = GradioVisualizer(
355
  data_dir=data_dir,