Spaces:
Sleeping
Sleeping
Commit
·
227a9e0
1
Parent(s):
d54b5ce
attempting RT speedup for L4
Browse files
app.py
CHANGED
|
@@ -24,6 +24,28 @@ from typing import Optional
|
|
| 24 |
|
| 25 |
import json, asyncio, base64
|
| 26 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
from starlette.websockets import WebSocketState
|
| 28 |
try:
|
| 29 |
from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
|
|
|
|
| 24 |
|
| 25 |
import json, asyncio, base64
|
| 26 |
import time
|
| 27 |
+
|
| 28 |
+
# ---- Perf knobs (environment variables here must be set before `import jax`) ----
|
| 29 |
+
os.environ.setdefault("JAX_PLATFORMS", "cuda") # prefer GPU
|
| 30 |
+
os.environ.setdefault("XLA_FLAGS",
|
| 31 |
+
"--xla_gpu_enable_triton_gemm=true "
|
| 32 |
+
"--xla_gpu_enable_latency_hiding_scheduler=true "
|
| 33 |
+
"--xla_gpu_autotune_level=2")
|
| 34 |
+
# TF32 is enabled by default on Ampere/Ada for matmul. NOTE: NVIDIA_TF32_OVERRIDE=0 force-disables TF32 in NVIDIA libraries, so it must not be "0" if TF32 is wanted:
|
| 35 |
+
# Deliberately do NOT default NVIDIA_TF32_OVERRIDE to "0": that value
# force-disables TF32 in NVIDIA libraries, which contradicts the "fastest"
# matmul precision requested after `import jax`. Leaving the variable unset
# keeps the library default; an operator can still export
# NVIDIA_TF32_OVERRIDE=0 in the environment to opt out of TF32.
|
| 36 |
+
|
| 37 |
+
import jax
|
| 38 |
+
jax.config.update("jax_default_matmul_precision", "fastest") # allow TF32
|
| 39 |
+
# Optional: persist XLA compile artifacts across restarts (saves warmup time)
|
| 40 |
+
try:
|
| 41 |
+
from jax.experimental.compilation_cache import compilation_cache as cc
|
| 42 |
+
cc.initialize_cache(os.environ.get("JAX_CACHE_DIR", "/home/appuser/.cache/jax"))
|
| 43 |
+
except Exception:
|
| 44 |
+
pass
|
| 45 |
+
# --------------------------------------------
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
from starlette.websockets import WebSocketState
|
| 50 |
try:
|
| 51 |
from uvicorn.protocols.utils import ClientDisconnected # uvicorn >= 0.20
|