Upload app.py
Browse filesMoved zero decorators to root level app.py load.
app.py
CHANGED
|
@@ -17,6 +17,8 @@ Environment variables:
|
|
| 17 |
import os
|
| 18 |
from pathlib import Path
|
| 19 |
|
|
|
|
|
|
|
| 20 |
# Data source configuration
|
| 21 |
DATA_REPO_ID = "mckell/diffviews_demo_data"
|
| 22 |
CHECKPOINT_URLS = {
|
|
@@ -253,6 +255,60 @@ def get_device() -> str:
|
|
| 253 |
return "cpu"
|
| 254 |
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
def main():
|
| 257 |
"""Main entry point for HF Spaces."""
|
| 258 |
# Configuration from environment
|
|
@@ -288,6 +344,12 @@ def main():
|
|
| 288 |
PLOTLY_HANDLER_JS,
|
| 289 |
)
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
print("\nInitializing visualizer...")
|
| 292 |
visualizer = GradioVisualizer(
|
| 293 |
data_dir=data_dir,
|
|
|
|
| 17 |
import os
|
| 18 |
from pathlib import Path
|
| 19 |
|
| 20 |
+
import spaces
|
| 21 |
+
|
| 22 |
# Data source configuration
|
| 23 |
DATA_REPO_ID = "mckell/diffviews_demo_data"
|
| 24 |
CHECKPOINT_URLS = {
|
|
|
|
| 255 |
return "cpu"
|
| 256 |
|
| 257 |
|
| 258 |
+
@spaces.GPU(duration=120)
|
| 259 |
+
def generate_on_gpu(
|
| 260 |
+
visualizer, model_name, all_neighbors, class_label,
|
| 261 |
+
n_steps, m_steps, s_max, s_min, guidance, noise_mode,
|
| 262 |
+
extract_layers, can_project
|
| 263 |
+
):
|
| 264 |
+
"""Run masked generation on GPU. Must live in app_file for ZeroGPU detection."""
|
| 265 |
+
from diffviews.core.masking import ActivationMasker
|
| 266 |
+
from diffviews.core.generator import generate_with_mask_multistep
|
| 267 |
+
|
| 268 |
+
with visualizer._generation_lock:
|
| 269 |
+
adapter = visualizer.load_adapter(model_name)
|
| 270 |
+
if adapter is None:
|
| 271 |
+
return None
|
| 272 |
+
|
| 273 |
+
activation_dict = visualizer.prepare_activation_dict(model_name, all_neighbors)
|
| 274 |
+
if activation_dict is None:
|
| 275 |
+
return None
|
| 276 |
+
|
| 277 |
+
masker = ActivationMasker(adapter)
|
| 278 |
+
for layer_name, activation in activation_dict.items():
|
| 279 |
+
masker.set_mask(layer_name, activation)
|
| 280 |
+
masker.register_hooks(list(activation_dict.keys()))
|
| 281 |
+
|
| 282 |
+
try:
|
| 283 |
+
result = generate_with_mask_multistep(
|
| 284 |
+
adapter,
|
| 285 |
+
masker,
|
| 286 |
+
class_label=class_label,
|
| 287 |
+
num_steps=int(n_steps),
|
| 288 |
+
mask_steps=int(m_steps),
|
| 289 |
+
sigma_max=float(s_max),
|
| 290 |
+
sigma_min=float(s_min),
|
| 291 |
+
guidance_scale=float(guidance),
|
| 292 |
+
noise_mode=(noise_mode or "stochastic noise").replace(" noise", ""),
|
| 293 |
+
num_samples=1,
|
| 294 |
+
device=visualizer.device,
|
| 295 |
+
extract_layers=extract_layers if can_project else None,
|
| 296 |
+
return_trajectory=can_project,
|
| 297 |
+
return_intermediates=True,
|
| 298 |
+
return_noised_inputs=True,
|
| 299 |
+
)
|
| 300 |
+
finally:
|
| 301 |
+
masker.remove_hooks()
|
| 302 |
+
|
| 303 |
+
return result
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
@spaces.GPU(duration=180)
|
| 307 |
+
def extract_layer_on_gpu(visualizer, model_name, layer_name, batch_size=32):
|
| 308 |
+
"""Extract layer activations on GPU. Must live in app_file for ZeroGPU detection."""
|
| 309 |
+
return visualizer.extract_layer_activations(model_name, layer_name, batch_size)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
def main():
|
| 313 |
"""Main entry point for HF Spaces."""
|
| 314 |
# Configuration from environment
|
|
|
|
| 344 |
PLOTLY_HANDLER_JS,
|
| 345 |
)
|
| 346 |
|
| 347 |
+
# Inject ZeroGPU-decorated functions into visualization module
|
| 348 |
+
# so Gradio callbacks use the versions codefind can detect
|
| 349 |
+
import diffviews.visualization.app as viz_mod
|
| 350 |
+
viz_mod._generate_on_gpu = generate_on_gpu
|
| 351 |
+
viz_mod._extract_layer_on_gpu = extract_layer_on_gpu
|
| 352 |
+
|
| 353 |
print("\nInitializing visualizer...")
|
| 354 |
visualizer = GradioVisualizer(
|
| 355 |
data_dir=data_dir,
|