mckell commited on
Commit
207930a
·
verified ·
1 Parent(s): cfb4c42

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -25
app.py CHANGED
@@ -257,11 +257,12 @@ def get_device() -> str:
257
 
258
  @spaces.GPU(duration=120)
259
  def generate_on_gpu(
260
- visualizer, model_name, all_neighbors, class_label,
261
  n_steps, m_steps, s_max, s_min, guidance, noise_mode,
262
  extract_layers, can_project
263
  ):
264
  """Run masked generation on GPU. Must live in app_file for ZeroGPU detection."""
 
265
  from diffviews.core.masking import ActivationMasker
266
  from diffviews.core.generator import generate_with_mask_multistep
267
 
@@ -304,19 +305,18 @@ def generate_on_gpu(
304
 
305
 
306
  @spaces.GPU(duration=180)
307
- def extract_layer_on_gpu(visualizer, model_name, layer_name, batch_size=32):
308
  """Extract layer activations on GPU. Must live in app_file for ZeroGPU detection."""
 
309
  return visualizer.extract_layer_activations(model_name, layer_name, batch_size)
310
 
311
 
312
- def main():
313
- """Main entry point for HF Spaces."""
314
- # Configuration from environment
315
  data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
316
- checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all") # Download all by default
317
  device = get_device()
318
 
319
- # Parse checkpoint config
320
  if checkpoint_config == "all":
321
  checkpoints = list(CHECKPOINT_URLS.keys())
322
  elif checkpoint_config == "none":
@@ -332,44 +332,42 @@ def main():
332
  print(f"Checkpoints: {checkpoints}")
333
  print("=" * 50)
334
 
335
- # Ensure data is ready
336
  ensure_data_ready(data_dir, checkpoints)
337
 
338
- # Import and launch visualizer
339
- import gradio as gr
340
  from diffviews.visualization.app import (
341
  GradioVisualizer,
342
  create_gradio_app,
343
- CUSTOM_CSS,
344
- PLOTLY_HANDLER_JS,
345
  )
346
 
347
  # Inject ZeroGPU-decorated functions into visualization module
348
  # so Gradio callbacks use the versions codefind can detect
349
- import diffviews.visualization.app as viz_mod
350
  viz_mod._generate_on_gpu = generate_on_gpu
351
  viz_mod._extract_layer_on_gpu = extract_layer_on_gpu
352
 
353
  print("\nInitializing visualizer...")
354
- visualizer = GradioVisualizer(
355
- data_dir=data_dir,
356
- device=device,
357
- )
358
 
359
  print("Creating Gradio app...")
360
  app = create_gradio_app(visualizer)
 
 
 
 
 
 
 
 
361
 
362
- print("Launching...")
363
- # HF Spaces expects server on 0.0.0.0:7860
364
- app.queue(max_size=20).launch(
 
 
365
  server_name="0.0.0.0",
366
  server_port=7860,
367
- share=False, # Spaces handles public URL
368
  theme=gr.themes.Soft(),
369
  css=CUSTOM_CSS,
370
  js=PLOTLY_HANDLER_JS,
371
  )
372
-
373
-
374
- if __name__ == "__main__":
375
- main()
 
257
 
258
  @spaces.GPU(duration=120)
259
  def generate_on_gpu(
260
+ model_name, all_neighbors, class_label,
261
  n_steps, m_steps, s_max, s_min, guidance, noise_mode,
262
  extract_layers, can_project
263
  ):
264
  """Run masked generation on GPU. Must live in app_file for ZeroGPU detection."""
265
+ from diffviews.visualization.app import _app_visualizer as visualizer
266
  from diffviews.core.masking import ActivationMasker
267
  from diffviews.core.generator import generate_with_mask_multistep
268
 
 
305
 
306
 
307
  @spaces.GPU(duration=180)
308
+ def extract_layer_on_gpu(model_name, layer_name, batch_size=32):
309
  """Extract layer activations on GPU. Must live in app_file for ZeroGPU detection."""
310
+ from diffviews.visualization.app import _app_visualizer as visualizer
311
  return visualizer.extract_layer_activations(model_name, layer_name, batch_size)
312
 
313
 
314
+ def _setup():
315
+ """Initialize data, visualizer, and Gradio app."""
 
316
  data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
317
+ checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all")
318
  device = get_device()
319
 
 
320
  if checkpoint_config == "all":
321
  checkpoints = list(CHECKPOINT_URLS.keys())
322
  elif checkpoint_config == "none":
 
332
  print(f"Checkpoints: {checkpoints}")
333
  print("=" * 50)
334
 
 
335
  ensure_data_ready(data_dir, checkpoints)
336
 
337
+ import diffviews.visualization.app as viz_mod
 
338
  from diffviews.visualization.app import (
339
  GradioVisualizer,
340
  create_gradio_app,
 
 
341
  )
342
 
343
  # Inject ZeroGPU-decorated functions into visualization module
344
  # so Gradio callbacks use the versions codefind can detect
 
345
  viz_mod._generate_on_gpu = generate_on_gpu
346
  viz_mod._extract_layer_on_gpu = extract_layer_on_gpu
347
 
348
  print("\nInitializing visualizer...")
349
+ visualizer = GradioVisualizer(data_dir=data_dir, device=device)
 
 
 
350
 
351
  print("Creating Gradio app...")
352
  app = create_gradio_app(visualizer)
353
+ app.queue(max_size=20)
354
+ return app
355
+
356
+
357
+ # Module-level setup so Gradio hot-reload (which imports but doesn't call main)
358
+ # still initializes everything and finds the app as `demo`.
359
+ demo = _setup()
360
+
361
 
362
+ if __name__ == "__main__":
363
+ import gradio as gr
364
+ from diffviews.visualization.app import CUSTOM_CSS, PLOTLY_HANDLER_JS
365
+
366
+ demo.launch(
367
  server_name="0.0.0.0",
368
  server_port=7860,
369
+ share=False,
370
  theme=gr.themes.Soft(),
371
  css=CUSTOM_CSS,
372
  js=PLOTLY_HANDLER_JS,
373
  )