Spaces:
Runtime error
Runtime error
Mateo Fidabel
commited on
Commit
·
42d64c8
1
Parent(s):
ddb9f2a
Added Segmentation Map Notebook
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ from PIL import Image
|
|
| 4 |
from flax.jax_utils import replicate
|
| 5 |
from flax.training.common_utils import shard
|
| 6 |
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
|
| 7 |
-
|
| 8 |
import jax.numpy as jnp
|
| 9 |
import numpy as np
|
| 10 |
import gc
|
|
@@ -22,6 +22,8 @@ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
|
| 22 |
params["controlnet"] = controlnet_params
|
| 23 |
p_params = replicate(params)
|
| 24 |
|
|
|
|
|
|
|
| 25 |
# Description
|
| 26 |
title = "# 🧨 ControlNet on Segment Anything 🤗"
|
| 27 |
description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).
|
|
@@ -30,6 +32,9 @@ description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anyt
|
|
| 30 |
|
| 31 |
⌛️ It takes about 30~ seconds to generate 4 samples, to get faster results, don't forget to reduce the Nº Samples to 1.
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
A huge thanks goes out to @Google Cloud, for providing us with powerful TPUs that enabled us to train this model; and to the @HuggingFace Team for organizing the sprint.
|
| 34 |
"""
|
| 35 |
|
|
|
|
| 4 |
from flax.jax_utils import replicate
|
| 5 |
from flax.training.common_utils import shard
|
| 6 |
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline
|
| 7 |
+
import jax.profiler
|
| 8 |
import jax.numpy as jnp
|
| 9 |
import numpy as np
|
| 10 |
import gc
|
|
|
|
| 22 |
params["controlnet"] = controlnet_params
|
| 23 |
p_params = replicate(params)
|
| 24 |
|
| 25 |
+
jax.profiler.save_device_memory_profile("memory.prof")
|
| 26 |
+
|
| 27 |
# Description
|
| 28 |
title = "# 🧨 ControlNet on Segment Anything 🤗"
|
| 29 |
description = """This is a demo on 🧨 ControlNet based on Meta's [Segment Anything Model](https://segment-anything.com/).
|
|
|
|
| 32 |
|
| 33 |
⌛️ It takes about 30~ seconds to generate 4 samples, to get faster results, don't forget to reduce the Nº Samples to 1.
|
| 34 |
|
| 35 |
+
You can obtain the Segmentation Map of any Image through this Colab: [](https://colab.research.google.com/github/mfidabel/JAX_SPRINT_2023/blob/main/Segment_Anything_JAX_SPRINT.ipynb)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
A huge thanks goes out to @Google Cloud, for providing us with powerful TPUs that enabled us to train this model; and to the @HuggingFace Team for organizing the sprint.
|
| 39 |
"""
|
| 40 |
|