Upload app.py
Browse files
app.py
CHANGED
|
@@ -257,11 +257,12 @@ def get_device() -> str:
|
|
| 257 |
|
| 258 |
@spaces.GPU(duration=120)
|
| 259 |
def generate_on_gpu(
|
| 260 |
-
|
| 261 |
n_steps, m_steps, s_max, s_min, guidance, noise_mode,
|
| 262 |
extract_layers, can_project
|
| 263 |
):
|
| 264 |
"""Run masked generation on GPU. Must live in app_file for ZeroGPU detection."""
|
|
|
|
| 265 |
from diffviews.core.masking import ActivationMasker
|
| 266 |
from diffviews.core.generator import generate_with_mask_multistep
|
| 267 |
|
|
@@ -304,19 +305,18 @@ def generate_on_gpu(
|
|
| 304 |
|
| 305 |
|
| 306 |
@spaces.GPU(duration=180)
|
| 307 |
-
def extract_layer_on_gpu(
|
| 308 |
"""Extract layer activations on GPU. Must live in app_file for ZeroGPU detection."""
|
|
|
|
| 309 |
return visualizer.extract_layer_activations(model_name, layer_name, batch_size)
|
| 310 |
|
| 311 |
|
| 312 |
-
def
|
| 313 |
-
"""
|
| 314 |
-
# Configuration from environment
|
| 315 |
data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
|
| 316 |
-
checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all")
|
| 317 |
device = get_device()
|
| 318 |
|
| 319 |
-
# Parse checkpoint config
|
| 320 |
if checkpoint_config == "all":
|
| 321 |
checkpoints = list(CHECKPOINT_URLS.keys())
|
| 322 |
elif checkpoint_config == "none":
|
|
@@ -332,44 +332,42 @@ def main():
|
|
| 332 |
print(f"Checkpoints: {checkpoints}")
|
| 333 |
print("=" * 50)
|
| 334 |
|
| 335 |
-
# Ensure data is ready
|
| 336 |
ensure_data_ready(data_dir, checkpoints)
|
| 337 |
|
| 338 |
-
|
| 339 |
-
import gradio as gr
|
| 340 |
from diffviews.visualization.app import (
|
| 341 |
GradioVisualizer,
|
| 342 |
create_gradio_app,
|
| 343 |
-
CUSTOM_CSS,
|
| 344 |
-
PLOTLY_HANDLER_JS,
|
| 345 |
)
|
| 346 |
|
| 347 |
# Inject ZeroGPU-decorated functions into visualization module
|
| 348 |
# so Gradio callbacks use the versions codefind can detect
|
| 349 |
-
import diffviews.visualization.app as viz_mod
|
| 350 |
viz_mod._generate_on_gpu = generate_on_gpu
|
| 351 |
viz_mod._extract_layer_on_gpu = extract_layer_on_gpu
|
| 352 |
|
| 353 |
print("\nInitializing visualizer...")
|
| 354 |
-
visualizer = GradioVisualizer(
|
| 355 |
-
data_dir=data_dir,
|
| 356 |
-
device=device,
|
| 357 |
-
)
|
| 358 |
|
| 359 |
print("Creating Gradio app...")
|
| 360 |
app = create_gradio_app(visualizer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
| 365 |
server_name="0.0.0.0",
|
| 366 |
server_port=7860,
|
| 367 |
-
share=False,
|
| 368 |
theme=gr.themes.Soft(),
|
| 369 |
css=CUSTOM_CSS,
|
| 370 |
js=PLOTLY_HANDLER_JS,
|
| 371 |
)
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
if __name__ == "__main__":
|
| 375 |
-
main()
|
|
|
|
| 257 |
|
| 258 |
@spaces.GPU(duration=120)
|
| 259 |
def generate_on_gpu(
|
| 260 |
+
model_name, all_neighbors, class_label,
|
| 261 |
n_steps, m_steps, s_max, s_min, guidance, noise_mode,
|
| 262 |
extract_layers, can_project
|
| 263 |
):
|
| 264 |
"""Run masked generation on GPU. Must live in app_file for ZeroGPU detection."""
|
| 265 |
+
from diffviews.visualization.app import _app_visualizer as visualizer
|
| 266 |
from diffviews.core.masking import ActivationMasker
|
| 267 |
from diffviews.core.generator import generate_with_mask_multistep
|
| 268 |
|
|
|
|
| 305 |
|
| 306 |
|
| 307 |
@spaces.GPU(duration=180)
|
| 308 |
+
def extract_layer_on_gpu(model_name, layer_name, batch_size=32):
|
| 309 |
"""Extract layer activations on GPU. Must live in app_file for ZeroGPU detection."""
|
| 310 |
+
from diffviews.visualization.app import _app_visualizer as visualizer
|
| 311 |
return visualizer.extract_layer_activations(model_name, layer_name, batch_size)
|
| 312 |
|
| 313 |
|
| 314 |
+
def _setup():
|
| 315 |
+
"""Initialize data, visualizer, and Gradio app."""
|
|
|
|
| 316 |
data_dir = Path(os.environ.get("DIFFVIEWS_DATA_DIR", "data"))
|
| 317 |
+
checkpoint_config = os.environ.get("DIFFVIEWS_CHECKPOINT", "all")
|
| 318 |
device = get_device()
|
| 319 |
|
|
|
|
| 320 |
if checkpoint_config == "all":
|
| 321 |
checkpoints = list(CHECKPOINT_URLS.keys())
|
| 322 |
elif checkpoint_config == "none":
|
|
|
|
| 332 |
print(f"Checkpoints: {checkpoints}")
|
| 333 |
print("=" * 50)
|
| 334 |
|
|
|
|
| 335 |
ensure_data_ready(data_dir, checkpoints)
|
| 336 |
|
| 337 |
+
import diffviews.visualization.app as viz_mod
|
|
|
|
| 338 |
from diffviews.visualization.app import (
|
| 339 |
GradioVisualizer,
|
| 340 |
create_gradio_app,
|
|
|
|
|
|
|
| 341 |
)
|
| 342 |
|
| 343 |
# Inject ZeroGPU-decorated functions into visualization module
|
| 344 |
# so Gradio callbacks use the versions codefind can detect
|
|
|
|
| 345 |
viz_mod._generate_on_gpu = generate_on_gpu
|
| 346 |
viz_mod._extract_layer_on_gpu = extract_layer_on_gpu
|
| 347 |
|
| 348 |
print("\nInitializing visualizer...")
|
| 349 |
+
visualizer = GradioVisualizer(data_dir=data_dir, device=device)
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
print("Creating Gradio app...")
|
| 352 |
app = create_gradio_app(visualizer)
|
| 353 |
+
app.queue(max_size=20)
|
| 354 |
+
return app
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
# Module-level setup so Gradio hot-reload (which imports but doesn't call main)
|
| 358 |
+
# still initializes everything and finds the app as `demo`.
|
| 359 |
+
demo = _setup()
|
| 360 |
+
|
| 361 |
|
| 362 |
+
if __name__ == "__main__":
|
| 363 |
+
import gradio as gr
|
| 364 |
+
from diffviews.visualization.app import CUSTOM_CSS, PLOTLY_HANDLER_JS
|
| 365 |
+
|
| 366 |
+
demo.launch(
|
| 367 |
server_name="0.0.0.0",
|
| 368 |
server_port=7860,
|
| 369 |
+
share=False,
|
| 370 |
theme=gr.themes.Soft(),
|
| 371 |
css=CUSTOM_CSS,
|
| 372 |
js=PLOTLY_HANDLER_JS,
|
| 373 |
)
|
|
|
|
|
|
|
|
|
|
|
|