Spaces:
Paused
Paused
Upload 37 files
Browse files- .gitattributes +2 -0
- html/circular.html +32 -0
- html/denoising.html +16 -0
- html/embeddings.html +75 -0
- html/guidance.html +17 -0
- html/inpainting.html +14 -0
- html/interpolate.html +24 -0
- html/negative.html +15 -0
- html/perturbations.html +35 -0
- html/poke.html +21 -0
- html/seeds.html +25 -0
- images/circular.gif +3 -0
- images/circular.png +0 -0
- images/denoising.png +0 -0
- images/guidance.png +0 -0
- images/inpainting.png +0 -0
- images/interpolate.gif +3 -0
- images/interpolate.png +0 -0
- images/negative.png +0 -0
- images/perturbations.png +0 -0
- images/poke.png +0 -0
- images/seeds.png +0 -0
- run.py +1029 -0
- src/__init__.py +2 -0
- src/pipelines/__init__.py +9 -0
- src/pipelines/circular.py +52 -0
- src/pipelines/embeddings.py +196 -0
- src/pipelines/guidance.py +39 -0
- src/pipelines/inpainting.py +41 -0
- src/pipelines/interpolate.py +51 -0
- src/pipelines/negative.py +37 -0
- src/pipelines/perturbations.py +62 -0
- src/pipelines/poke.py +83 -0
- src/pipelines/seed.py +32 -0
- src/util/__init__.py +3 -0
- src/util/base.py +304 -0
- src/util/clip_config.py +114 -0
- src/util/params.py +96 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
images/circular.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
images/interpolate.gif filter=lfs diff=lfs merge=lfs -text
|
html/circular.html
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
This tab generates a circular trajectory through latent space that begins and ends with the same image.
|
| 9 |
+
If we specify a large number of steps around the circle, the successive images will be closely related, resulting in a gradual deformation that produces a nice animation.
|
| 10 |
+
</p>
|
| 11 |
+
<p style="font-weight: bold;">
|
| 12 |
+
Additional Controls:
|
| 13 |
+
</p>
|
| 14 |
+
<p style="font-weight: bold;">
|
| 15 |
+
Number of Steps around the Circle:
|
| 16 |
+
</p>
|
| 17 |
+
<p>
|
| 18 |
+
Specify the number of images to produce along the circular path.
|
| 19 |
+
</p>
|
| 20 |
+
<p style="font-weight: bold;">
|
| 21 |
+
Proportion of Circle:
|
| 22 |
+
</p>
|
| 23 |
+
<p>
|
| 24 |
+
Sets the proportion of the circle to cover during image generation.
|
| 25 |
+
Ranges from 0 to 360 degrees.
|
| 26 |
+
Using a high step count with a small number of degrees allows you to explore very subtle image transformations.
|
| 27 |
+
</p>
|
| 28 |
+
</div>
|
| 29 |
+
<div style="flex: 1; align-content: center;">
|
| 30 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/circular.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 31 |
+
</div>
|
| 32 |
+
</div>
|
html/denoising.html
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
This tab displays the intermediate images generated during the denoising process.
|
| 9 |
+
Seeing these intermediate images provides insight into how the diffusion model progressively adds detail at each step.
|
| 10 |
+
</p>
|
| 11 |
+
</div>
|
| 12 |
+
<div style="flex: 1; align-content: center;">
|
| 13 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/denoising.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 14 |
+
</div>
|
| 15 |
+
</div>
|
| 16 |
+
</details>
|
html/embeddings.html
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<head>
|
| 2 |
+
<link rel="stylesheet" type="text/css" href="styles.css">
|
| 3 |
+
</head>
|
| 4 |
+
|
| 5 |
+
<details open>
|
| 6 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 7 |
+
About
|
| 8 |
+
</summary>
|
| 9 |
+
<div style="background-color: #D87F2B; padding-left: 10px;">
|
| 10 |
+
<p style="font-weight: bold;">
|
| 11 |
+
Basic Exploration
|
| 12 |
+
</p>
|
| 13 |
+
The top part of the embeddings tab is the 3D plot of semantic feature space.
|
| 14 |
+
At the bottom of the tab there are expandable panels that can be opened to reveal more advanced features
|
| 15 |
+
|
| 16 |
+
<ul>
|
| 17 |
+
<li>
|
| 18 |
+
<strong>
|
| 19 |
+
Explore the 3D semantic feature space:
|
| 20 |
+
</strong>
|
| 21 |
+
Click and drag in the 3D semantic feature space to rotate the view.
|
| 22 |
+
Use the scroll wheel to zoom in and out.
|
| 23 |
+
Hold down the control key and click and drag to pan the view.
|
| 24 |
+
</li>
|
| 25 |
+
<li>
|
| 26 |
+
<strong>
|
| 27 |
+
Find the generated image:
|
| 28 |
+
</strong>
|
| 29 |
+
Hover over a point in the semantic feature space, and a window will pop up showing a generated image from this one-word prompt.
|
| 30 |
+
On left click, the image will be downloaded.
|
| 31 |
+
</li>
|
| 32 |
+
<li>
|
| 33 |
+
<strong>
|
| 34 |
+
Find the embedding vector display:
|
| 35 |
+
</strong>
|
| 36 |
+
Hover over a word in the 3D semantic feature space, and an embedding vector display at the bottom of the tab shows the corresponding embedding vector.
|
| 37 |
+
</li>
|
| 38 |
+
<li>
|
| 39 |
+
<strong>
|
| 40 |
+
Add/remove words from the 3D plot:
|
| 41 |
+
</strong>
|
| 42 |
+
Type a word in the Add/Remove word text box below the 3D plot to add a word to the plot, or if the word is already present, remove it from the plot.
|
| 43 |
+
You can also type multiple words separated by spaces or commas.
|
| 44 |
+
</li>
|
| 45 |
+
<li>
|
| 46 |
+
<strong>
|
| 47 |
+
Change image for word in the 3D plot:
|
| 48 |
+
</strong>
|
| 49 |
+
Type a word in the Change image for word text box below the 3D plot to generate a new image for the corresponding word in the plot.
|
| 50 |
+
</li>
|
| 51 |
+
</ul>
|
| 52 |
+
|
| 53 |
+
<p style="font-weight: bold; margin-top: 10px;">
|
| 54 |
+
Semantic Dimensions
|
| 55 |
+
</p>
|
| 56 |
+
<ul>
|
| 57 |
+
<li>
|
| 58 |
+
<strong>Select a different semantic dimension.</strong><br>
|
| 59 |
+
Open the Custom Semantic Dimensions panel and choose another dimension for the X or Y or Z axis.
|
| 60 |
+
See how the display changes.
|
| 61 |
+
</li>
|
| 62 |
+
<li>
|
| 63 |
+
<strong>Alter a semantic dimension.</strong><br>
|
| 64 |
+
Examine the positive and negative word pairs used to define the semantic dimension.
|
| 65 |
+
You can change these pairs to alter the semantic dimension.
|
| 66 |
+
</li>
|
| 67 |
+
<li>
|
| 68 |
+
<strong>Define a new semantic dimension.</strong><br>
|
| 69 |
+
Pick a new semantic dimension that you can define using pairs of opposed words.
|
| 70 |
+
For example, you could define a "tense" dimension with pairs such as eat/ate, go/went, see/saw, and is/was to contrast present and past tense forms of verbs.
|
| 71 |
+
</li>
|
| 72 |
+
</ul>
|
| 73 |
+
</div>
|
| 74 |
+
</details>
|
| 75 |
+
|
html/guidance.html
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
Guidance is responsible for making the target image adhere to the prompt.
|
| 9 |
+
A higher value enforces this relation, whereas a lower value does not.
|
| 10 |
+
For example, a guidance scale of 1 produces a distorted grayscale image, whereas 50 produces a distorted, oversaturated image.
|
| 11 |
+
The default value of 8 produces normal-looking images that reasonably adhere to the prompt.
|
| 12 |
+
</p>
|
| 13 |
+
</div>
|
| 14 |
+
<div style="flex: 1; align-content: center;">
|
| 15 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/guidance.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 16 |
+
</div>
|
| 17 |
+
</div>
|
html/inpainting.html
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
Unlike poke, which globally alters the target image via a perturbation in the initial latent noise, inpainting alters just the region of the perturbation and allows us to specify the change we want to make.
|
| 9 |
+
</p>
|
| 10 |
+
</div>
|
| 11 |
+
<div style="flex: 1; align-content: center;">
|
| 12 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/inpainting.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 13 |
+
</div>
|
| 14 |
+
</div>
|
html/interpolate.html
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
This tab generates noise patterns for two text prompts and then interpolates between them, gradually transforming from the first to the second.
|
| 9 |
+
With a large number of perturbation steps the transformation is very gradual and makes a nice animation.
|
| 10 |
+
</p>
|
| 11 |
+
<p style="font-weight: bold;">
|
| 12 |
+
Additional Controls:
|
| 13 |
+
</p>
|
| 14 |
+
<p style="font-weight: bold;">
|
| 15 |
+
Number of Interpolation Steps:
|
| 16 |
+
</p>
|
| 17 |
+
<p>
|
| 18 |
+
Defines the number of intermediate images to generate between the two prompts.
|
| 19 |
+
</p>
|
| 20 |
+
</div>
|
| 21 |
+
<div style="flex: 1; align-content: center;">
|
| 22 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/interpolate.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 23 |
+
</div>
|
| 24 |
+
</div>
|
html/negative.html
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
Negative prompts steer images away from unwanted features.
|
| 9 |
+
For example, “red” as a negative prompt makes the generated image unlikely to have reddish hues.
|
| 10 |
+
</p>
|
| 11 |
+
</div>
|
| 12 |
+
<div style="flex: 1; align-content: center;">
|
| 13 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/negative.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 14 |
+
</div>
|
| 15 |
+
</div>
|
html/perturbations.html
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
Perturbations enables the exploration of the latent space around a seed.
|
| 9 |
+
Perturbing the noise from an initial seed towards the noise from a different seed illustrates the variations in images obtainable from a local region of latent space.
|
| 10 |
+
Using a small perturbation size produces target images that closely resemble the one from the initial seed.
|
| 11 |
+
Larger perturbations traverse more distance in latent space towards the second seed, resulting in greater variation in the generated images.
|
| 12 |
+
</p>
|
| 13 |
+
<p style="font-weight: bold;">
|
| 14 |
+
Additional Controls:
|
| 15 |
+
</p>
|
| 16 |
+
<p style="font-weight: bold;">
|
| 17 |
+
Number of Perturbations:
|
| 18 |
+
</p>
|
| 19 |
+
<p>
|
| 20 |
+
Specify the number of perturbations to create, i.e., the number of seeds to use. More perturbations produce more images.
|
| 21 |
+
</p>
|
| 22 |
+
<p style="font-weight: bold;">
|
| 23 |
+
Perturbation Size:
|
| 24 |
+
</p>
|
| 25 |
+
<p>
|
| 26 |
+
Controls the perturbation magnitude, ranging from 0 to 1.
|
| 27 |
+
With a value of 0, all images will match the one from the initial seed.
|
| 28 |
+
With a value of 1, images will have no connection to the initial seed.
|
| 29 |
+
A value such as 0.1 is recommended.
|
| 30 |
+
</p>
|
| 31 |
+
</div>
|
| 32 |
+
<div style="flex: 1; align-content: center;">
|
| 33 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/perturbations.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 34 |
+
</div>
|
| 35 |
+
</div>
|
html/poke.html
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
Poke explores how perturbations in a local region of the initial latent noise impact the target image.
|
| 9 |
+
A small perturbation to the initial latent noise gets carried through the denoising process, demonstrating the global effect it can produce.
|
| 10 |
+
</p>
|
| 11 |
+
<p style="font-weight: bold;">
|
| 12 |
+
Additional Controls:
|
| 13 |
+
</p>
|
| 14 |
+
<p>
|
| 15 |
+
You can adjust the perturbation through the X, Y, height, and width controls.
|
| 16 |
+
</p>
|
| 17 |
+
</div>
|
| 18 |
+
<div style="flex: 1; align-content: center;">
|
| 19 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/poke.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 20 |
+
</div>
|
| 21 |
+
</div>
|
html/seeds.html
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<details open>
|
| 2 |
+
<summary style="background-color: #CE6400; padding-left: 10px;">
|
| 3 |
+
About
|
| 4 |
+
</summary>
|
| 5 |
+
<div style="display: flex; flex-direction: row; background-color: #D87F2B; padding-left: 10px;">
|
| 6 |
+
<div style="flex: 1;">
|
| 7 |
+
<p style="margin-top: 10px">
|
| 8 |
+
Seeds create the initial noise that gets refined into the target image.
|
| 9 |
+
Different seeds produce different noise patterns, hence the target image will differ even when prompted by the same text.
|
| 10 |
+
This tab produces multiple target images from the same text prompt to showcase how changing the seed changes the target image.
|
| 11 |
+
</p>
|
| 12 |
+
<p style="font-weight: bold;">
|
| 13 |
+
Additional Controls:
|
| 14 |
+
</p>
|
| 15 |
+
<p style="font-weight: bold;">
|
| 16 |
+
Number of Seeds:
|
| 17 |
+
</p>
|
| 18 |
+
<p>
|
| 19 |
+
Specify how many seed values to use.
|
| 20 |
+
</p>
|
| 21 |
+
</div>
|
| 22 |
+
<div style="flex: 1; align-content: center;">
|
| 23 |
+
<img src="https://raw.githubusercontent.com/touretzkyds/DiffusionDemo/master/images/seeds.png" style="max-width: 100%; height: auto; margin-top: 10px; margin-bottom: 10px; padding-left: 10px;">
|
| 24 |
+
</div>
|
| 25 |
+
</div>
|
images/circular.gif
ADDED
|
Git LFS Details
|
images/circular.png
ADDED
|
images/denoising.png
ADDED
|
images/guidance.png
ADDED
|
images/inpainting.png
ADDED
|
images/interpolate.gif
ADDED
|
Git LFS Details
|
images/interpolate.png
ADDED
|
images/negative.png
ADDED
|
images/perturbations.png
ADDED
|
images/poke.png
ADDED
|
images/seeds.png
ADDED
|
run.py
ADDED
|
@@ -0,0 +1,1029 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from src.util import *
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from src.pipelines import *
|
| 7 |
+
from threading import Thread
|
| 8 |
+
from dash import Dash, dcc, html, Input, Output, no_update, callback
|
| 9 |
+
|
| 10 |
+
app = Dash(__name__)
|
| 11 |
+
|
| 12 |
+
app.layout = html.Div(
|
| 13 |
+
className="container",
|
| 14 |
+
children=[
|
| 15 |
+
dcc.Graph(
|
| 16 |
+
id="graph", figure=fig, clear_on_unhover=True, style={"height": "90vh"}
|
| 17 |
+
),
|
| 18 |
+
dcc.Tooltip(id="tooltip"),
|
| 19 |
+
html.Div(id="word-emb-txt", style={"background-color": "white"}),
|
| 20 |
+
html.Div(id="word-emb-vis"),
|
| 21 |
+
html.Div(
|
| 22 |
+
[
|
| 23 |
+
html.Button(id="btn-download-image", hidden=True),
|
| 24 |
+
dcc.Download(id="download-image"),
|
| 25 |
+
]
|
| 26 |
+
),
|
| 27 |
+
],
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@callback(
|
| 32 |
+
Output("tooltip", "show"),
|
| 33 |
+
Output("tooltip", "bbox"),
|
| 34 |
+
Output("tooltip", "children"),
|
| 35 |
+
Output("tooltip", "direction"),
|
| 36 |
+
Output("word-emb-txt", "children"),
|
| 37 |
+
Output("word-emb-vis", "children"),
|
| 38 |
+
Input("graph", "hoverData"),
|
| 39 |
+
)
|
| 40 |
+
def display_hover(hoverData):
|
| 41 |
+
if hoverData is None:
|
| 42 |
+
return False, no_update, no_update, no_update, no_update, no_update
|
| 43 |
+
|
| 44 |
+
hover_data = hoverData["points"][0]
|
| 45 |
+
bbox = hover_data["bbox"]
|
| 46 |
+
direction = "left"
|
| 47 |
+
index = hover_data["pointNumber"]
|
| 48 |
+
|
| 49 |
+
children = [
|
| 50 |
+
html.Img(
|
| 51 |
+
src=images[index],
|
| 52 |
+
style={"width": "250px"},
|
| 53 |
+
),
|
| 54 |
+
html.P(
|
| 55 |
+
hover_data["text"],
|
| 56 |
+
style={
|
| 57 |
+
"color": "black",
|
| 58 |
+
"font-size": "20px",
|
| 59 |
+
"text-align": "center",
|
| 60 |
+
"background-color": "white",
|
| 61 |
+
"margin": "5px",
|
| 62 |
+
},
|
| 63 |
+
),
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
emb_children = [
|
| 67 |
+
html.Img(
|
| 68 |
+
src=generate_word_emb_vis(hover_data["text"]),
|
| 69 |
+
style={"width": "100%", "height": "25px"},
|
| 70 |
+
),
|
| 71 |
+
]
|
| 72 |
+
|
| 73 |
+
return True, bbox, children, direction, hover_data["text"], emb_children
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@callback(
|
| 77 |
+
Output("download-image", "data"),
|
| 78 |
+
Input("graph", "clickData"),
|
| 79 |
+
)
|
| 80 |
+
def download_image(clickData):
|
| 81 |
+
|
| 82 |
+
if clickData is None:
|
| 83 |
+
return no_update
|
| 84 |
+
|
| 85 |
+
click_data = clickData["points"][0]
|
| 86 |
+
index = click_data["pointNumber"]
|
| 87 |
+
txt = click_data["text"]
|
| 88 |
+
|
| 89 |
+
img_encoded = images[index]
|
| 90 |
+
img_decoded = base64.b64decode(img_encoded.split(",")[1])
|
| 91 |
+
img = Image.open(BytesIO(img_decoded))
|
| 92 |
+
img.save(f"{txt}.png")
|
| 93 |
+
return dcc.send_file(f"{txt}.png")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
with gr.Blocks() as demo:
|
| 97 |
+
gr.Markdown("## Stable Diffusion Demo")
|
| 98 |
+
|
| 99 |
+
with gr.Tab("Latent Space"):
|
| 100 |
+
|
| 101 |
+
with gr.TabItem("Denoising"):
|
| 102 |
+
gr.Markdown("Observe the intermediate images during denoising.")
|
| 103 |
+
gr.HTML(read_html("DiffusionDemo/html/denoising.html"))
|
| 104 |
+
|
| 105 |
+
with gr.Row():
|
| 106 |
+
with gr.Column():
|
| 107 |
+
prompt_denoise = gr.Textbox(
|
| 108 |
+
lines=1,
|
| 109 |
+
label="Prompt",
|
| 110 |
+
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 111 |
+
)
|
| 112 |
+
num_inference_steps_denoise = gr.Slider(
|
| 113 |
+
minimum=2,
|
| 114 |
+
maximum=100,
|
| 115 |
+
step=1,
|
| 116 |
+
value=8,
|
| 117 |
+
label="Number of Inference Steps",
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
with gr.Row():
|
| 121 |
+
seed_denoise = gr.Slider(
|
| 122 |
+
minimum=0, maximum=100, step=1, value=14, label="Seed"
|
| 123 |
+
)
|
| 124 |
+
seed_vis_denoise = gr.Plot(
|
| 125 |
+
value=generate_seed_vis(14), label="Seed"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
generate_images_button_denoise = gr.Button("Generate Images")
|
| 129 |
+
|
| 130 |
+
with gr.Column():
|
| 131 |
+
images_output_denoise = gr.Gallery(label="Images", selected_index=0)
|
| 132 |
+
gif_denoise = gr.Image(label="GIF")
|
| 133 |
+
zip_output_denoise = gr.File(label="Download ZIP")
|
| 134 |
+
|
| 135 |
+
@generate_images_button_denoise.click(
|
| 136 |
+
inputs=[prompt_denoise, seed_denoise, num_inference_steps_denoise],
|
| 137 |
+
outputs=[images_output_denoise, gif_denoise, zip_output_denoise],
|
| 138 |
+
)
|
| 139 |
+
def generate_images_wrapper(
|
| 140 |
+
prompt, seed, num_inference_steps, progress=gr.Progress()
|
| 141 |
+
):
|
| 142 |
+
images, _ = display_poke_images(
|
| 143 |
+
prompt, seed, num_inference_steps, poke=False, intermediate=True
|
| 144 |
+
)
|
| 145 |
+
fname = "denoising"
|
| 146 |
+
tab_config = {
|
| 147 |
+
"Tab": "Denoising",
|
| 148 |
+
"Prompt": prompt,
|
| 149 |
+
"Number of Inference Steps": num_inference_steps,
|
| 150 |
+
"Seed": seed,
|
| 151 |
+
}
|
| 152 |
+
export_as_zip(images, fname, tab_config)
|
| 153 |
+
progress(1, desc="Exporting as gif")
|
| 154 |
+
export_as_gif(images, filename="denoising.gif")
|
| 155 |
+
return images, "outputs/denoising.gif", f"outputs/{fname}.zip"
|
| 156 |
+
|
| 157 |
+
seed_denoise.change(
|
| 158 |
+
fn=generate_seed_vis, inputs=[seed_denoise], outputs=[seed_vis_denoise]
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
with gr.TabItem("Seeds"):
|
| 162 |
+
gr.Markdown(
|
| 163 |
+
"Understand how different starting points in latent space can lead to different images."
|
| 164 |
+
)
|
| 165 |
+
gr.HTML(read_html("DiffusionDemo/html/seeds.html"))
|
| 166 |
+
|
| 167 |
+
with gr.Row():
|
| 168 |
+
with gr.Column():
|
| 169 |
+
prompt_seed = gr.Textbox(
|
| 170 |
+
lines=1,
|
| 171 |
+
label="Prompt",
|
| 172 |
+
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 173 |
+
)
|
| 174 |
+
num_images_seed = gr.Slider(
|
| 175 |
+
minimum=1, maximum=100, step=1, value=5, label="Number of Seeds"
|
| 176 |
+
)
|
| 177 |
+
num_inference_steps_seed = gr.Slider(
|
| 178 |
+
minimum=2,
|
| 179 |
+
maximum=100,
|
| 180 |
+
step=1,
|
| 181 |
+
value=8,
|
| 182 |
+
label="Number of Inference Steps per Image",
|
| 183 |
+
)
|
| 184 |
+
generate_images_button_seed = gr.Button("Generate Images")
|
| 185 |
+
|
| 186 |
+
with gr.Column():
|
| 187 |
+
images_output_seed = gr.Gallery(label="Images", selected_index=0)
|
| 188 |
+
zip_output_seed = gr.File(label="Download ZIP")
|
| 189 |
+
|
| 190 |
+
generate_images_button_seed.click(
|
| 191 |
+
fn=display_seed_images,
|
| 192 |
+
inputs=[prompt_seed, num_inference_steps_seed, num_images_seed],
|
| 193 |
+
outputs=[images_output_seed, zip_output_seed],
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
with gr.TabItem("Perturbations"):
|
| 197 |
+
gr.Markdown("Explore different perturbations from a point in latent space.")
|
| 198 |
+
gr.HTML(read_html("DiffusionDemo/html/perturbations.html"))
|
| 199 |
+
|
| 200 |
+
with gr.Row():
|
| 201 |
+
with gr.Column():
|
| 202 |
+
prompt_perturb = gr.Textbox(
|
| 203 |
+
lines=1,
|
| 204 |
+
label="Prompt",
|
| 205 |
+
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 206 |
+
)
|
| 207 |
+
num_images_perturb = gr.Slider(
|
| 208 |
+
minimum=0,
|
| 209 |
+
maximum=100,
|
| 210 |
+
step=1,
|
| 211 |
+
value=5,
|
| 212 |
+
label="Number of Perturbations",
|
| 213 |
+
)
|
| 214 |
+
perturbation_size_perturb = gr.Slider(
|
| 215 |
+
minimum=0,
|
| 216 |
+
maximum=1,
|
| 217 |
+
step=0.1,
|
| 218 |
+
value=0.1,
|
| 219 |
+
label="Perturbation Size",
|
| 220 |
+
)
|
| 221 |
+
num_inference_steps_perturb = gr.Slider(
|
| 222 |
+
minimum=2,
|
| 223 |
+
maximum=100,
|
| 224 |
+
step=1,
|
| 225 |
+
value=8,
|
| 226 |
+
label="Number of Inference Steps per Image",
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
with gr.Row():
|
| 230 |
+
seed_perturb = gr.Slider(
|
| 231 |
+
minimum=0, maximum=100, step=1, value=14, label="Seed"
|
| 232 |
+
)
|
| 233 |
+
seed_vis_perturb = gr.Plot(
|
| 234 |
+
value=generate_seed_vis(14), label="Seed"
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
generate_images_button_perturb = gr.Button("Generate Images")
|
| 238 |
+
|
| 239 |
+
with gr.Column():
|
| 240 |
+
images_output_perturb = gr.Gallery(label="Image", selected_index=0)
|
| 241 |
+
zip_output_perturb = gr.File(label="Download ZIP")
|
| 242 |
+
|
| 243 |
+
generate_images_button_perturb.click(
|
| 244 |
+
fn=display_perturb_images,
|
| 245 |
+
inputs=[
|
| 246 |
+
prompt_perturb,
|
| 247 |
+
seed_perturb,
|
| 248 |
+
num_inference_steps_perturb,
|
| 249 |
+
num_images_perturb,
|
| 250 |
+
perturbation_size_perturb,
|
| 251 |
+
],
|
| 252 |
+
outputs=[images_output_perturb, zip_output_perturb],
|
| 253 |
+
)
|
| 254 |
+
seed_perturb.change(
|
| 255 |
+
fn=generate_seed_vis, inputs=[seed_perturb], outputs=[seed_vis_perturb]
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
with gr.TabItem("Circular"):
|
| 259 |
+
gr.Markdown(
|
| 260 |
+
"Generate a circular path in latent space and observe how the images vary along the path."
|
| 261 |
+
)
|
| 262 |
+
gr.HTML(read_html("DiffusionDemo/html/circular.html"))
|
| 263 |
+
|
| 264 |
+
with gr.Row():
|
| 265 |
+
with gr.Column():
|
| 266 |
+
prompt_circular = gr.Textbox(
|
| 267 |
+
lines=1,
|
| 268 |
+
label="Prompt",
|
| 269 |
+
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 270 |
+
)
|
| 271 |
+
num_images_circular = gr.Slider(
|
| 272 |
+
minimum=2,
|
| 273 |
+
maximum=100,
|
| 274 |
+
step=1,
|
| 275 |
+
value=5,
|
| 276 |
+
label="Number of Steps around the Circle",
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
with gr.Row():
|
| 280 |
+
degree_circular = gr.Slider(
|
| 281 |
+
minimum=0,
|
| 282 |
+
maximum=360,
|
| 283 |
+
step=1,
|
| 284 |
+
value=360,
|
| 285 |
+
label="Proportion of Circle",
|
| 286 |
+
info="Enter the value in degrees",
|
| 287 |
+
)
|
| 288 |
+
step_size_circular = gr.Textbox(
|
| 289 |
+
label="Step Size", value=360 / 5
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
num_inference_steps_circular = gr.Slider(
|
| 293 |
+
minimum=2,
|
| 294 |
+
maximum=100,
|
| 295 |
+
step=1,
|
| 296 |
+
value=8,
|
| 297 |
+
label="Number of Inference Steps per Image",
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
with gr.Row():
|
| 301 |
+
seed_circular = gr.Slider(
|
| 302 |
+
minimum=0, maximum=100, step=1, value=14, label="Seed"
|
| 303 |
+
)
|
| 304 |
+
seed_vis_circular = gr.Plot(
|
| 305 |
+
value=generate_seed_vis(14), label="Seed"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
generate_images_button_circular = gr.Button("Generate Images")
|
| 309 |
+
|
| 310 |
+
with gr.Column():
|
| 311 |
+
images_output_circular = gr.Gallery(label="Image", selected_index=0)
|
| 312 |
+
gif_circular = gr.Image(label="GIF")
|
| 313 |
+
zip_output_circular = gr.File(label="Download ZIP")
|
| 314 |
+
|
| 315 |
+
num_images_circular.change(
|
| 316 |
+
fn=calculate_step_size,
|
| 317 |
+
inputs=[num_images_circular, degree_circular],
|
| 318 |
+
outputs=[step_size_circular],
|
| 319 |
+
)
|
| 320 |
+
degree_circular.change(
|
| 321 |
+
fn=calculate_step_size,
|
| 322 |
+
inputs=[num_images_circular, degree_circular],
|
| 323 |
+
outputs=[step_size_circular],
|
| 324 |
+
)
|
| 325 |
+
generate_images_button_circular.click(
|
| 326 |
+
fn=display_circular_images,
|
| 327 |
+
inputs=[
|
| 328 |
+
prompt_circular,
|
| 329 |
+
seed_circular,
|
| 330 |
+
num_inference_steps_circular,
|
| 331 |
+
num_images_circular,
|
| 332 |
+
degree_circular,
|
| 333 |
+
],
|
| 334 |
+
outputs=[images_output_circular, gif_circular, zip_output_circular],
|
| 335 |
+
)
|
| 336 |
+
seed_circular.change(
|
| 337 |
+
fn=generate_seed_vis, inputs=[seed_circular], outputs=[seed_vis_circular]
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
with gr.TabItem("Poke"):
|
| 341 |
+
gr.Markdown("Perturb a region in the image and observe the effect.")
|
| 342 |
+
gr.HTML(read_html("DiffusionDemo/html/poke.html"))
|
| 343 |
+
|
| 344 |
+
with gr.Row():
|
| 345 |
+
with gr.Column():
|
| 346 |
+
prompt_poke = gr.Textbox(
|
| 347 |
+
lines=1,
|
| 348 |
+
label="Prompt",
|
| 349 |
+
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 350 |
+
)
|
| 351 |
+
num_inference_steps_poke = gr.Slider(
|
| 352 |
+
minimum=2,
|
| 353 |
+
maximum=100,
|
| 354 |
+
step=1,
|
| 355 |
+
value=8,
|
| 356 |
+
label="Number of Inference Steps per Image",
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
with gr.Row():
|
| 360 |
+
seed_poke = gr.Slider(
|
| 361 |
+
minimum=0, maximum=100, step=1, value=14, label="Seed"
|
| 362 |
+
)
|
| 363 |
+
seed_vis_poke = gr.Plot(
|
| 364 |
+
value=generate_seed_vis(14), label="Seed"
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
pokeX = gr.Slider(
|
| 368 |
+
label="pokeX",
|
| 369 |
+
minimum=0,
|
| 370 |
+
maximum=64,
|
| 371 |
+
step=1,
|
| 372 |
+
value=32,
|
| 373 |
+
info="X coordinate of poke center",
|
| 374 |
+
)
|
| 375 |
+
pokeY = gr.Slider(
|
| 376 |
+
label="pokeY",
|
| 377 |
+
minimum=0,
|
| 378 |
+
maximum=64,
|
| 379 |
+
step=1,
|
| 380 |
+
value=32,
|
| 381 |
+
info="Y coordinate of poke center",
|
| 382 |
+
)
|
| 383 |
+
pokeHeight = gr.Slider(
|
| 384 |
+
label="pokeHeight",
|
| 385 |
+
minimum=0,
|
| 386 |
+
maximum=64,
|
| 387 |
+
step=1,
|
| 388 |
+
value=8,
|
| 389 |
+
info="Height of the poke",
|
| 390 |
+
)
|
| 391 |
+
pokeWidth = gr.Slider(
|
| 392 |
+
label="pokeWidth",
|
| 393 |
+
minimum=0,
|
| 394 |
+
maximum=64,
|
| 395 |
+
step=1,
|
| 396 |
+
value=8,
|
| 397 |
+
info="Width of the poke",
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
generate_images_button_poke = gr.Button("Generate Images")
|
| 401 |
+
|
| 402 |
+
with gr.Column():
|
| 403 |
+
original_images_output_poke = gr.Image(
|
| 404 |
+
value=visualize_poke(32, 32, 8, 8)[0], label="Original Image"
|
| 405 |
+
)
|
| 406 |
+
poked_images_output_poke = gr.Image(
|
| 407 |
+
value=visualize_poke(32, 32, 8, 8)[1], label="Poked Image"
|
| 408 |
+
)
|
| 409 |
+
zip_output_poke = gr.File(label="Download ZIP")
|
| 410 |
+
|
| 411 |
+
pokeX.change(
|
| 412 |
+
visualize_poke,
|
| 413 |
+
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
|
| 414 |
+
outputs=[original_images_output_poke, poked_images_output_poke],
|
| 415 |
+
)
|
| 416 |
+
pokeY.change(
|
| 417 |
+
visualize_poke,
|
| 418 |
+
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
|
| 419 |
+
outputs=[original_images_output_poke, poked_images_output_poke],
|
| 420 |
+
)
|
| 421 |
+
pokeHeight.change(
|
| 422 |
+
visualize_poke,
|
| 423 |
+
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
|
| 424 |
+
outputs=[original_images_output_poke, poked_images_output_poke],
|
| 425 |
+
)
|
| 426 |
+
pokeWidth.change(
|
| 427 |
+
visualize_poke,
|
| 428 |
+
inputs=[pokeX, pokeY, pokeHeight, pokeWidth],
|
| 429 |
+
outputs=[original_images_output_poke, poked_images_output_poke],
|
| 430 |
+
)
|
| 431 |
+
seed_poke.change(
|
| 432 |
+
fn=generate_seed_vis, inputs=[seed_poke], outputs=[seed_vis_poke]
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
@generate_images_button_poke.click(
|
| 436 |
+
inputs=[
|
| 437 |
+
prompt_poke,
|
| 438 |
+
seed_poke,
|
| 439 |
+
num_inference_steps_poke,
|
| 440 |
+
pokeX,
|
| 441 |
+
pokeY,
|
| 442 |
+
pokeHeight,
|
| 443 |
+
pokeWidth,
|
| 444 |
+
],
|
| 445 |
+
outputs=[
|
| 446 |
+
original_images_output_poke,
|
| 447 |
+
poked_images_output_poke,
|
| 448 |
+
zip_output_poke,
|
| 449 |
+
],
|
| 450 |
+
)
|
| 451 |
+
def generate_images_wrapper(
|
| 452 |
+
prompt,
|
| 453 |
+
seed,
|
| 454 |
+
num_inference_steps,
|
| 455 |
+
pokeX=pokeX,
|
| 456 |
+
pokeY=pokeY,
|
| 457 |
+
pokeHeight=pokeHeight,
|
| 458 |
+
pokeWidth=pokeWidth,
|
| 459 |
+
):
|
| 460 |
+
_, _ = display_poke_images(
|
| 461 |
+
prompt,
|
| 462 |
+
seed,
|
| 463 |
+
num_inference_steps,
|
| 464 |
+
poke=True,
|
| 465 |
+
pokeX=pokeX,
|
| 466 |
+
pokeY=pokeY,
|
| 467 |
+
pokeHeight=pokeHeight,
|
| 468 |
+
pokeWidth=pokeWidth,
|
| 469 |
+
intermediate=False,
|
| 470 |
+
)
|
| 471 |
+
images, modImages = visualize_poke(pokeX, pokeY, pokeHeight, pokeWidth)
|
| 472 |
+
fname = "poke"
|
| 473 |
+
tab_config = {
|
| 474 |
+
"Tab": "Poke",
|
| 475 |
+
"Prompt": prompt,
|
| 476 |
+
"Number of Inference Steps per Image": num_inference_steps,
|
| 477 |
+
"Seed": seed,
|
| 478 |
+
"PokeX": pokeX,
|
| 479 |
+
"PokeY": pokeY,
|
| 480 |
+
"PokeHeight": pokeHeight,
|
| 481 |
+
"PokeWidth": pokeWidth,
|
| 482 |
+
}
|
| 483 |
+
imgs_list = []
|
| 484 |
+
imgs_list.append((images, "Original Image"))
|
| 485 |
+
imgs_list.append((modImages, "Poked Image"))
|
| 486 |
+
|
| 487 |
+
export_as_zip(imgs_list, fname, tab_config)
|
| 488 |
+
return images, modImages, f"outputs/{fname}.zip"
|
| 489 |
+
|
| 490 |
+
with gr.TabItem("Guidance"):
|
| 491 |
+
gr.Markdown("Observe the effect of different guidance scales.")
|
| 492 |
+
gr.HTML(read_html("DiffusionDemo/html/guidance.html"))
|
| 493 |
+
|
| 494 |
+
with gr.Row():
|
| 495 |
+
with gr.Column():
|
| 496 |
+
prompt_guidance = gr.Textbox(
|
| 497 |
+
lines=1,
|
| 498 |
+
label="Prompt",
|
| 499 |
+
value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
|
| 500 |
+
)
|
| 501 |
+
num_inference_steps_guidance = gr.Slider(
|
| 502 |
+
minimum=2,
|
| 503 |
+
maximum=100,
|
| 504 |
+
step=1,
|
| 505 |
+
value=8,
|
| 506 |
+
label="Number of Inference Steps per Image",
|
| 507 |
+
)
|
| 508 |
+
guidance_scale_values = gr.Textbox(
|
| 509 |
+
lines=1, value="1, 8, 20, 30", label="Guidance Scale Values"
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
with gr.Row():
|
| 513 |
+
seed_guidance = gr.Slider(
|
| 514 |
+
minimum=0, maximum=100, step=1, value=14, label="Seed"
|
| 515 |
+
)
|
| 516 |
+
seed_vis_guidance = gr.Plot(
|
| 517 |
+
value=generate_seed_vis(14), label="Seed"
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
generate_images_button_guidance = gr.Button("Generate Images")
|
| 521 |
+
|
| 522 |
+
with gr.Column():
|
| 523 |
+
images_output_guidance = gr.Gallery(
|
| 524 |
+
label="Images", selected_index=0
|
| 525 |
+
)
|
| 526 |
+
zip_output_guidance = gr.File(label="Download ZIP")
|
| 527 |
+
|
| 528 |
+
generate_images_button_guidance.click(
|
| 529 |
+
fn=display_guidance_images,
|
| 530 |
+
inputs=[
|
| 531 |
+
prompt_guidance,
|
| 532 |
+
seed_guidance,
|
| 533 |
+
num_inference_steps_guidance,
|
| 534 |
+
guidance_scale_values,
|
| 535 |
+
],
|
| 536 |
+
outputs=[images_output_guidance, zip_output_guidance],
|
| 537 |
+
)
|
| 538 |
+
seed_guidance.change(
|
| 539 |
+
fn=generate_seed_vis, inputs=[seed_guidance], outputs=[seed_vis_guidance]
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
with gr.TabItem("Inpainting"):
|
| 543 |
+
gr.Markdown("Inpaint the image based on the prompt.")
|
| 544 |
+
gr.HTML(read_html("DiffusionDemo/html/inpainting.html"))
|
| 545 |
+
|
| 546 |
+
with gr.Row():
|
| 547 |
+
with gr.Column():
|
| 548 |
+
uploaded_img_inpaint = gr.Image(
|
| 549 |
+
source="upload", tool="sketch", type="pil", label="Upload"
|
| 550 |
+
)
|
| 551 |
+
prompt_inpaint = gr.Textbox(
|
| 552 |
+
lines=1, label="Prompt", value="sunglasses"
|
| 553 |
+
)
|
| 554 |
+
num_inference_steps_inpaint = gr.Slider(
|
| 555 |
+
minimum=2,
|
| 556 |
+
maximum=100,
|
| 557 |
+
step=1,
|
| 558 |
+
value=8,
|
| 559 |
+
label="Number of Inference Steps per Image",
|
| 560 |
+
)
|
| 561 |
+
|
| 562 |
+
with gr.Row():
|
| 563 |
+
seed_inpaint = gr.Slider(
|
| 564 |
+
minimum=0, maximum=100, step=1, value=14, label="Seed"
|
| 565 |
+
)
|
| 566 |
+
seed_vis_inpaint = gr.Plot(
|
| 567 |
+
value=generate_seed_vis(14), label="Seed"
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
inpaint_button = gr.Button("Inpaint")
|
| 571 |
+
|
| 572 |
+
with gr.Column():
|
| 573 |
+
images_output_inpaint = gr.Image(label="Output")
|
| 574 |
+
zip_output_inpaint = gr.File(label="Download ZIP")
|
| 575 |
+
|
| 576 |
+
inpaint_button.click(
|
| 577 |
+
fn=inpaint,
|
| 578 |
+
inputs=[
|
| 579 |
+
uploaded_img_inpaint,
|
| 580 |
+
num_inference_steps_inpaint,
|
| 581 |
+
seed_inpaint,
|
| 582 |
+
prompt_inpaint,
|
| 583 |
+
],
|
| 584 |
+
outputs=[images_output_inpaint, zip_output_inpaint],
|
| 585 |
+
)
|
| 586 |
+
seed_inpaint.change(
|
| 587 |
+
fn=generate_seed_vis, inputs=[seed_inpaint], outputs=[seed_vis_inpaint]
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
with gr.Tab("CLIP Space"):
|
| 591 |
+
|
| 592 |
+
with gr.TabItem("Embeddings"):
|
| 593 |
+
gr.Markdown(
|
| 594 |
+
"Visualize text embedding space in 3D with input texts and output images based on the chosen axis."
|
| 595 |
+
)
|
| 596 |
+
gr.HTML(read_html("DiffusionDemo/html/embeddings.html"))
|
| 597 |
+
|
| 598 |
+
with gr.Row():
|
| 599 |
+
output = gr.HTML(
|
| 600 |
+
f"""
|
| 601 |
+
<iframe id="html" src="{dash_tunnel}" style="width:100%; height:700px;"></iframe>
|
| 602 |
+
"""
|
| 603 |
+
)
|
| 604 |
+
with gr.Row():
|
| 605 |
+
word2add_rem = gr.Textbox(lines=1, label="Add/Remove word")
|
| 606 |
+
word2change = gr.Textbox(lines=1, label="Change image for word")
|
| 607 |
+
clear_words_button = gr.Button(value="Clear words")
|
| 608 |
+
|
| 609 |
+
with gr.Accordion("Custom Semantic Dimensions", open=False):
|
| 610 |
+
with gr.Row():
|
| 611 |
+
axis_name_1 = gr.Textbox(label="Axis name", value="gender")
|
| 612 |
+
which_axis_1 = gr.Dropdown(
|
| 613 |
+
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
|
| 614 |
+
value=whichAxisMap["which_axis_1"],
|
| 615 |
+
label="Axis direction",
|
| 616 |
+
)
|
| 617 |
+
from_words_1 = gr.Textbox(
|
| 618 |
+
lines=1,
|
| 619 |
+
label="Positive",
|
| 620 |
+
value="prince husband father son uncle",
|
| 621 |
+
)
|
| 622 |
+
to_words_1 = gr.Textbox(
|
| 623 |
+
lines=1,
|
| 624 |
+
label="Negative",
|
| 625 |
+
value="princess wife mother daughter aunt",
|
| 626 |
+
)
|
| 627 |
+
submit_1 = gr.Button("Submit")
|
| 628 |
+
|
| 629 |
+
with gr.Row():
|
| 630 |
+
axis_name_2 = gr.Textbox(label="Axis name", value="age")
|
| 631 |
+
which_axis_2 = gr.Dropdown(
|
| 632 |
+
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
|
| 633 |
+
value=whichAxisMap["which_axis_2"],
|
| 634 |
+
label="Axis direction",
|
| 635 |
+
)
|
| 636 |
+
from_words_2 = gr.Textbox(
|
| 637 |
+
lines=1, label="Positive", value="man woman king queen father"
|
| 638 |
+
)
|
| 639 |
+
to_words_2 = gr.Textbox(
|
| 640 |
+
lines=1, label="Negative", value="boy girl prince princess son"
|
| 641 |
+
)
|
| 642 |
+
submit_2 = gr.Button("Submit")
|
| 643 |
+
|
| 644 |
+
with gr.Row():
|
| 645 |
+
axis_name_3 = gr.Textbox(label="Axis name", value="residual")
|
| 646 |
+
which_axis_3 = gr.Dropdown(
|
| 647 |
+
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
|
| 648 |
+
value=whichAxisMap["which_axis_3"],
|
| 649 |
+
label="Axis direction",
|
| 650 |
+
)
|
| 651 |
+
from_words_3 = gr.Textbox(lines=1, label="Positive")
|
| 652 |
+
to_words_3 = gr.Textbox(lines=1, label="Negative")
|
| 653 |
+
submit_3 = gr.Button("Submit")
|
| 654 |
+
|
| 655 |
+
with gr.Row():
|
| 656 |
+
axis_name_4 = gr.Textbox(label="Axis name", value="number")
|
| 657 |
+
which_axis_4 = gr.Dropdown(
|
| 658 |
+
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
|
| 659 |
+
value=whichAxisMap["which_axis_4"],
|
| 660 |
+
label="Axis direction",
|
| 661 |
+
)
|
| 662 |
+
from_words_4 = gr.Textbox(
|
| 663 |
+
lines=1,
|
| 664 |
+
label="Positive",
|
| 665 |
+
value="boys girls cats puppies computers",
|
| 666 |
+
)
|
| 667 |
+
to_words_4 = gr.Textbox(
|
| 668 |
+
lines=1, label="Negative", value="boy girl cat puppy computer"
|
| 669 |
+
)
|
| 670 |
+
submit_4 = gr.Button("Submit")
|
| 671 |
+
|
| 672 |
+
with gr.Row():
|
| 673 |
+
axis_name_5 = gr.Textbox(label="Axis name", value="royalty")
|
| 674 |
+
which_axis_5 = gr.Dropdown(
|
| 675 |
+
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
|
| 676 |
+
value=whichAxisMap["which_axis_5"],
|
| 677 |
+
label="Axis direction",
|
| 678 |
+
)
|
| 679 |
+
from_words_5 = gr.Textbox(
|
| 680 |
+
lines=1,
|
| 681 |
+
label="Positive",
|
| 682 |
+
value="king queen prince princess duchess",
|
| 683 |
+
)
|
| 684 |
+
to_words_5 = gr.Textbox(
|
| 685 |
+
lines=1, label="Negative", value="man woman boy girl woman"
|
| 686 |
+
)
|
| 687 |
+
submit_5 = gr.Button("Submit")
|
| 688 |
+
|
| 689 |
+
with gr.Row():
|
| 690 |
+
axis_name_6 = gr.Textbox(label="Axis name")
|
| 691 |
+
which_axis_6 = gr.Dropdown(
|
| 692 |
+
choices=["X - Axis", "Y - Axis", "Z - Axis", "---"],
|
| 693 |
+
value=whichAxisMap["which_axis_6"],
|
| 694 |
+
label="Axis direction",
|
| 695 |
+
)
|
| 696 |
+
from_words_6 = gr.Textbox(lines=1, label="Positive")
|
| 697 |
+
to_words_6 = gr.Textbox(lines=1, label="Negative")
|
| 698 |
+
submit_6 = gr.Button("Submit")
|
| 699 |
+
|
| 700 |
+
@word2add_rem.submit(inputs=[word2add_rem], outputs=[output, word2add_rem])
|
| 701 |
+
def add_rem_word_and_clear(words):
|
| 702 |
+
return add_rem_word(words), ""
|
| 703 |
+
|
| 704 |
+
@word2change.submit(inputs=[word2change], outputs=[output, word2change])
|
| 705 |
+
def change_word_and_clear(word):
|
| 706 |
+
return change_word(word), ""
|
| 707 |
+
|
| 708 |
+
clear_words_button.click(fn=clear_words, outputs=[output])
|
| 709 |
+
|
| 710 |
+
@submit_1.click(
|
| 711 |
+
inputs=[axis_name_1, which_axis_1, from_words_1, to_words_1],
|
| 712 |
+
outputs=[
|
| 713 |
+
output,
|
| 714 |
+
which_axis_2,
|
| 715 |
+
which_axis_3,
|
| 716 |
+
which_axis_4,
|
| 717 |
+
which_axis_5,
|
| 718 |
+
which_axis_6,
|
| 719 |
+
],
|
| 720 |
+
)
|
| 721 |
+
def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
|
| 722 |
+
|
| 723 |
+
for ax in whichAxisMap:
|
| 724 |
+
if whichAxisMap[ax] == which_axis:
|
| 725 |
+
whichAxisMap[ax] = "---"
|
| 726 |
+
|
| 727 |
+
whichAxisMap["which_axis_1"] = which_axis
|
| 728 |
+
return (
|
| 729 |
+
set_axis(axis_name, which_axis, from_words, to_words),
|
| 730 |
+
whichAxisMap["which_axis_2"],
|
| 731 |
+
whichAxisMap["which_axis_3"],
|
| 732 |
+
whichAxisMap["which_axis_4"],
|
| 733 |
+
whichAxisMap["which_axis_5"],
|
| 734 |
+
whichAxisMap["which_axis_6"],
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
@submit_2.click(
|
| 738 |
+
inputs=[axis_name_2, which_axis_2, from_words_2, to_words_2],
|
| 739 |
+
outputs=[
|
| 740 |
+
output,
|
| 741 |
+
which_axis_1,
|
| 742 |
+
which_axis_3,
|
| 743 |
+
which_axis_4,
|
| 744 |
+
which_axis_5,
|
| 745 |
+
which_axis_6,
|
| 746 |
+
],
|
| 747 |
+
)
|
| 748 |
+
def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
|
| 749 |
+
|
| 750 |
+
for ax in whichAxisMap:
|
| 751 |
+
if whichAxisMap[ax] == which_axis:
|
| 752 |
+
whichAxisMap[ax] = "---"
|
| 753 |
+
|
| 754 |
+
whichAxisMap["which_axis_2"] = which_axis
|
| 755 |
+
return (
|
| 756 |
+
set_axis(axis_name, which_axis, from_words, to_words),
|
| 757 |
+
whichAxisMap["which_axis_1"],
|
| 758 |
+
whichAxisMap["which_axis_3"],
|
| 759 |
+
whichAxisMap["which_axis_4"],
|
| 760 |
+
whichAxisMap["which_axis_5"],
|
| 761 |
+
whichAxisMap["which_axis_6"],
|
| 762 |
+
)
|
| 763 |
+
|
| 764 |
+
@submit_3.click(
|
| 765 |
+
inputs=[axis_name_3, which_axis_3, from_words_3, to_words_3],
|
| 766 |
+
outputs=[
|
| 767 |
+
output,
|
| 768 |
+
which_axis_1,
|
| 769 |
+
which_axis_2,
|
| 770 |
+
which_axis_4,
|
| 771 |
+
which_axis_5,
|
| 772 |
+
which_axis_6,
|
| 773 |
+
],
|
| 774 |
+
)
|
| 775 |
+
def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
|
| 776 |
+
|
| 777 |
+
for ax in whichAxisMap:
|
| 778 |
+
if whichAxisMap[ax] == which_axis:
|
| 779 |
+
whichAxisMap[ax] = "---"
|
| 780 |
+
|
| 781 |
+
whichAxisMap["which_axis_3"] = which_axis
|
| 782 |
+
return (
|
| 783 |
+
set_axis(axis_name, which_axis, from_words, to_words),
|
| 784 |
+
whichAxisMap["which_axis_1"],
|
| 785 |
+
whichAxisMap["which_axis_2"],
|
| 786 |
+
whichAxisMap["which_axis_4"],
|
| 787 |
+
whichAxisMap["which_axis_5"],
|
| 788 |
+
whichAxisMap["which_axis_6"],
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
@submit_4.click(
|
| 792 |
+
inputs=[axis_name_4, which_axis_4, from_words_4, to_words_4],
|
| 793 |
+
outputs=[
|
| 794 |
+
output,
|
| 795 |
+
which_axis_1,
|
| 796 |
+
which_axis_2,
|
| 797 |
+
which_axis_3,
|
| 798 |
+
which_axis_5,
|
| 799 |
+
which_axis_6,
|
| 800 |
+
],
|
| 801 |
+
)
|
| 802 |
+
def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
|
| 803 |
+
|
| 804 |
+
for ax in whichAxisMap:
|
| 805 |
+
if whichAxisMap[ax] == which_axis:
|
| 806 |
+
whichAxisMap[ax] = "---"
|
| 807 |
+
|
| 808 |
+
whichAxisMap["which_axis_4"] = which_axis
|
| 809 |
+
return (
|
| 810 |
+
set_axis(axis_name, which_axis, from_words, to_words),
|
| 811 |
+
whichAxisMap["which_axis_1"],
|
| 812 |
+
whichAxisMap["which_axis_2"],
|
| 813 |
+
whichAxisMap["which_axis_3"],
|
| 814 |
+
whichAxisMap["which_axis_5"],
|
| 815 |
+
whichAxisMap["which_axis_6"],
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
@submit_5.click(
|
| 819 |
+
inputs=[axis_name_5, which_axis_5, from_words_5, to_words_5],
|
| 820 |
+
outputs=[
|
| 821 |
+
output,
|
| 822 |
+
which_axis_1,
|
| 823 |
+
                which_axis_2,
                which_axis_3,
                which_axis_4,
                which_axis_6,
            ],
        )
        def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
            for ax in whichAxisMap:
                if whichAxisMap[ax] == which_axis:
                    whichAxisMap[ax] = "---"

            whichAxisMap["which_axis_5"] = which_axis
            return (
                set_axis(axis_name, which_axis, from_words, to_words),
                whichAxisMap["which_axis_1"],
                whichAxisMap["which_axis_2"],
                whichAxisMap["which_axis_3"],
                whichAxisMap["which_axis_4"],
                whichAxisMap["which_axis_6"],
            )

        @submit_6.click(
            inputs=[axis_name_6, which_axis_6, from_words_6, to_words_6],
            outputs=[
                output,
                which_axis_1,
                which_axis_2,
                which_axis_3,
                which_axis_4,
                which_axis_5,
            ],
        )
        def set_axis_wrapper(axis_name, which_axis, from_words, to_words):
            for ax in whichAxisMap:
                if whichAxisMap[ax] == which_axis:
                    whichAxisMap[ax] = "---"

            whichAxisMap["which_axis_6"] = which_axis
            return (
                set_axis(axis_name, which_axis, from_words, to_words),
                whichAxisMap["which_axis_1"],
                whichAxisMap["which_axis_2"],
                whichAxisMap["which_axis_3"],
                whichAxisMap["which_axis_4"],
                whichAxisMap["which_axis_5"],
            )

        with gr.TabItem("Interpolate"):
            gr.Markdown(
                "Interpolate between the first and the second prompt, and observe how the output changes."
            )
            gr.HTML(read_html("DiffusionDemo/html/interpolate.html"))

            with gr.Row():
                with gr.Column():
                    promptA = gr.Textbox(
                        lines=1,
                        label="First Prompt",
                        value="Self-portrait oil painting, a beautiful man with golden hair, 8k",
                    )
                    promptB = gr.Textbox(
                        lines=1,
                        label="Second Prompt",
                        value="Self-portrait oil painting, a beautiful woman with golden hair, 8k",
                    )
                    num_images_interpolate = gr.Slider(
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=5,
                        label="Number of Interpolation Steps",
                    )
                    num_inference_steps_interpolate = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_interpolate = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_interpolate = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_interpolate = gr.Button("Generate Images")

                with gr.Column():
                    images_output_interpolate = gr.Gallery(
                        label="Interpolated Images", selected_index=0
                    )
                    gif_interpolate = gr.Image(label="GIF")
                    zip_output_interpolate = gr.File(label="Download ZIP")

            generate_images_button_interpolate.click(
                fn=display_interpolate_images,
                inputs=[
                    seed_interpolate,
                    promptA,
                    promptB,
                    num_inference_steps_interpolate,
                    num_images_interpolate,
                ],
                outputs=[
                    images_output_interpolate,
                    gif_interpolate,
                    zip_output_interpolate,
                ],
            )
            seed_interpolate.change(
                fn=generate_seed_vis,
                inputs=[seed_interpolate],
                outputs=[seed_vis_interpolate],
            )

        with gr.TabItem("Negative"):
            gr.Markdown("Observe the effect of negative prompts.")
            gr.HTML(read_html("DiffusionDemo/html/negative.html"))

            with gr.Row():
                with gr.Column():
                    prompt_negative = gr.Textbox(
                        lines=1,
                        label="Prompt",
                        value="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
                    )
                    neg_prompt = gr.Textbox(
                        lines=1, label="Negative Prompt", value="Yellow"
                    )
                    num_inference_steps_negative = gr.Slider(
                        minimum=2,
                        maximum=100,
                        step=1,
                        value=8,
                        label="Number of Inference Steps per Image",
                    )

                    with gr.Row():
                        seed_negative = gr.Slider(
                            minimum=0, maximum=100, step=1, value=14, label="Seed"
                        )
                        seed_vis_negative = gr.Plot(
                            value=generate_seed_vis(14), label="Seed"
                        )

                    generate_images_button_negative = gr.Button("Generate Images")

                with gr.Column():
                    images_output_negative = gr.Image(
                        label="Image without Negative Prompt"
                    )
                    images_neg_output_negative = gr.Image(
                        label="Image with Negative Prompt"
                    )
                    zip_output_negative = gr.File(label="Download ZIP")

            seed_negative.change(
                fn=generate_seed_vis, inputs=[seed_negative], outputs=[seed_vis_negative]
            )
            generate_images_button_negative.click(
                fn=display_negative_images,
                inputs=[
                    prompt_negative,
                    seed_negative,
                    num_inference_steps_negative,
                    neg_prompt,
                ],
                outputs=[
                    images_output_negative,
                    images_neg_output_negative,
                    zip_output_negative,
                ],
            )

    with gr.Tab("Credits"):
        gr.Markdown("""
        Author: Adithya Kameswara Rao, Carnegie Mellon University.

        Advisor: David S. Touretzky, Carnegie Mellon University.

        This work was funded by a grant from NEOM Company, and by National Science Foundation award IIS-2112633.
        """)


def run_dash():
    app.run(host="127.0.0.1", port="8000")


def run_gradio():
    demo.queue()
    _, _, public_url = demo.launch(share=True)
    return public_url


# if __name__ == "__main__":
#     thread = Thread(target=run_dash)
#     thread.daemon = True
#     thread.start()
#     try:
#         run_gradio()
#     except KeyboardInterrupt:
#         print("Server closed")
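The commented-out block above is the intended entry point for run.py: Dash serves the embedding plot on port 8000 in a background daemon thread while Gradio hosts the UI. A minimal launch sketch under that assumption (it presumes `Thread` is imported from `threading` and simply uncomments the logic already shown):

from threading import Thread

if __name__ == "__main__":
    thread = Thread(target=run_dash, daemon=True)  # Dash embedding plot on port 8000
    thread.start()
    try:
        run_gradio()                               # Gradio UI blocks until interrupted
    except KeyboardInterrupt:
        print("Server closed")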
src/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import util
from . import pipelines
src/pipelines/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .circular import *
from .embeddings import *
from .interpolate import *
from .poke import *
from .seed import *
from .perturbations import *
from .negative import *
from .guidance import *
from .inpainting import *
src/pipelines/circular.py
ADDED
@@ -0,0 +1,52 @@
import torch
import numpy as np
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_circular_images(
    prompt, seed, num_inference_steps, num_images, degree, progress=gr.Progress()
):
    np.random.seed(seed)
    text_embeddings = get_text_embeddings(prompt)

    latents_x = generate_latents(seed)
    latents_y = generate_latents(seed * np.random.randint(0, 100000))

    scale_x = torch.cos(
        torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)
    ).to(torch_device)
    scale_y = torch.sin(
        torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)
    ).to(torch_device)

    noise_x = torch.tensordot(scale_x, latents_x, dims=0)
    noise_y = torch.tensordot(scale_y, latents_y, dims=0)

    noise = noise_x + noise_y

    progress(0)
    images = []
    for i in range(num_images):
        progress(i / num_images)
        image = generate_images(noise[i], text_embeddings, num_inference_steps)
        images.append((image, "{}".format(i)))

    progress(1, desc="Exporting as gif")
    export_as_gif(images, filename="circular.gif")

    fname = "circular"
    tab_config = {
        "Tab": "Circular",
        "Prompt": prompt,
        "Number of Steps around the Circle": num_images,
        "Proportion of Circle": degree,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)
    return images, "outputs/circular.gif", f"outputs/{fname}.zip"


__all__ = ["display_circular_images"]
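The trajectory in circular.py closes on itself because step i uses the latent cos(theta_i) * latents_x + sin(theta_i) * latents_y, with theta sweeping from 0 to 2*pi*(degree/360). A quick standalone check (not part of the pipeline) that the endpoints coincide when degree is 360:

import torch

num_images, degree = 5, 360
theta = torch.linspace(0, 2, num_images) * torch.pi * (degree / 360)
# cos starts and ends at 1, sin starts and ends at 0, so the last latent equals the first.
print(torch.cos(theta))  # approximately [1, 0, -1, 0, 1]
print(torch.sin(theta))  # approximately [0, 1, 0, -1, 0]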
src/pipelines/embeddings.py
ADDED
@@ -0,0 +1,196 @@
import random
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline

import base64
from io import BytesIO
import plotly.express as px

from src.util.base import *
from src.util.params import *
from src.util.clip_config import *

age = get_axis_embeddings(young, old)
gender = get_axis_embeddings(masculine, feminine)
royalty = get_axis_embeddings(common, elite)

images = []
for example in examples:
    image = pipe(
        prompt=example,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    images.append("data:image/jpeg;base64, " + encoded_image)

axis = np.vstack([gender, royalty, age])
axis[1] = calculate_residual(axis, axis_names)

coords = get_concat_embeddings(examples) @ axis.T
coords[:, 1] = 5 * (1.0 - coords[:, 1])


def update_fig():
    global coords, examples, fig
    fig.data[0].x = coords[:, 0]
    fig.data[0].y = coords[:, 1]
    fig.data[0].z = coords[:, 2]
    fig.data[0].text = examples

    return f"""
    <script>
        document.getElementById("html").src += "?rand={random.random()}"
    </script>
    <iframe id="html" src={dash_tunnel} style="width:100%; height:725px;"></iframe>
    """


def add_word(new_example):
    global coords, images, examples
    new_coord = get_concat_embeddings([new_example]) @ axis.T
    new_coord[:, 1] = 5 * (1.0 - new_coord[:, 1])
    coords = np.vstack([coords, new_coord])

    image = pipe(
        prompt=new_example,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
    images.append("data:image/jpeg;base64, " + encoded_image)
    examples.append(new_example)
    return update_fig()


def remove_word(new_example):
    global coords, images, examples
    examplesMap = {example: index for index, example in enumerate(examples)}
    index = examplesMap[new_example]

    coords = np.delete(coords, index, 0)
    images.pop(index)
    examples.pop(index)
    return update_fig()


def add_rem_word(new_examples):
    global examples
    new_examples = new_examples.replace(",", " ").split()

    for new_example in new_examples:
        if new_example in examples:
            remove_word(new_example)
            gr.Info("Removed {}".format(new_example))
        else:
            tokens = tokenizer.encode(new_example)
            if len(tokens) != 3:
                gr.Warning(f"{new_example} not found in embeddings")
            else:
                add_word(new_example)
                gr.Info("Added {}".format(new_example))

    return update_fig()


def set_axis(axis_name, which_axis, from_words, to_words):
    global coords, examples, fig, axis_names

    if axis_name != "residual":
        from_words, to_words = (
            from_words.replace(",", " ").split(),
            to_words.replace(",", " ").split(),
        )
        axis_emb = get_axis_embeddings(from_words, to_words)
        axis[axisMap[which_axis]] = axis_emb
        axis_names[axisMap[which_axis]] = axis_name

        for i, name in enumerate(axis_names):
            if name == "residual":
                axis[i] = calculate_residual(axis, axis_names, from_words, to_words, i)
                axis_names[i] = "residual"
    else:
        residual = calculate_residual(
            axis, axis_names, residual_axis=axisMap[which_axis]
        )
        axis[axisMap[which_axis]] = residual
        axis_names[axisMap[which_axis]] = axis_name

    coords = get_concat_embeddings(examples) @ axis.T
    coords[:, 1] = 5 * (1.0 - coords[:, 1])

    fig.update_layout(
        scene=dict(
            xaxis_title=axis_names[0],
            yaxis_title=axis_names[1],
            zaxis_title=axis_names[2],
        )
    )
    return update_fig()


def change_word(examples):
    examples = examples.replace(",", " ").split()

    for example in examples:
        remove_word(example)
        add_word(example)
        gr.Info("Changed image for {}".format(example))

    return update_fig()


def clear_words():
    while examples:
        remove_word(examples[-1])
    return update_fig()


def generate_word_emb_vis(prompt):
    buf = BytesIO()
    emb = get_word_embeddings(prompt).reshape(77, 768)[1]
    plt.imsave(buf, [emb], cmap="inferno")
    img = "data:image/jpeg;base64, " + base64.b64encode(buf.getvalue()).decode("utf-8")
    return img


fig = px.scatter_3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    labels={
        "x": axis_names[0],
        "y": axis_names[1],
        "z": axis_names[2],
    },
    text=examples,
    height=750,
)

fig.update_layout(
    margin=dict(l=0, r=0, b=0, t=0), scene_camera=dict(eye=dict(x=2, y=2, z=0.1))
)

fig.update_traces(hoverinfo="none", hovertemplate=None)

__all__ = [
    "fig",
    "update_fig",
    "coords",
    "images",
    "examples",
    "add_word",
    "remove_word",
    "add_rem_word",
    "change_word",
    "clear_words",
    "generate_word_emb_vis",
    "set_axis",
    "axis",
]
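The projection math in embeddings.py is compact, so here is a small self-contained sketch of the same idea with made-up vectors (the real code uses CLIP embeddings from get_word_embeddings): an axis is the average difference of paired word embeddings, and a word's coordinate along that axis is a dot product.

import numpy as np

np.random.seed(0)
young_vecs = np.random.rand(3, 8)  # stand-ins for embeddings of three "young" words
old_vecs = np.random.rand(3, 8)    # stand-ins for the paired "old" words
age_axis = np.average(young_vecs - old_vecs, axis=0).reshape(1, -1)  # same recipe as get_axis_embeddings

word = np.random.rand(1, 8)        # stand-in for one example word embedding
coord = word @ age_axis.T          # its coordinate along the "age" axis, as in coords = embeddings @ axis.T
print(coord.item())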
src/pipelines/guidance.py
ADDED
@@ -0,0 +1,39 @@
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_guidance_images(
    prompt, seed, num_inference_steps, guidance_values, progress=gr.Progress()
):
    text_embeddings = get_text_embeddings(prompt)
    latents = generate_latents(seed)

    progress(0)
    images = []
    guidance_values = guidance_values.replace(",", " ").split()
    num_images = len(guidance_values)

    for i in range(num_images):
        progress(i / num_images)
        image = generate_images(
            latents,
            text_embeddings,
            num_inference_steps,
            guidance_scale=int(guidance_values[i]),
        )
        images.append((image, "{}".format(int(guidance_values[i]))))

    fname = "guidance"
    tab_config = {
        "Tab": "Guidance",
        "Prompt": prompt,
        "Guidance Scale Values": guidance_values,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)
    return images, f"outputs/{fname}.zip"


__all__ = ["display_guidance_images"]
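For reference, a hypothetical call outside the Gradio UI, reusing the defaults exported by src.util.params; the guidance values are passed as a comma-separated string, exactly as the textbox supplies them:

from src.util.params import prompt, seed, num_inference_steps

gallery, zip_path = display_guidance_images(prompt, seed, num_inference_steps, "1, 8, 20")
# gallery is a list of (PIL.Image, label) pairs, one per guidance value; zip_path is "outputs/guidance.zip"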
src/pipelines/inpainting.py
ADDED
@@ -0,0 +1,41 @@
import torch
import gradio as gr
from src.util.base import *
from src.util.params import *
from diffusers import AutoPipelineForInpainting

# inpaint_pipe = AutoPipelineForInpainting.from_pretrained(inpaint_model_path).to(torch_device)
inpaint_pipe = AutoPipelineForInpainting.from_pipe(pipe).to(torch_device)


def inpaint(dict, num_inference_steps, seed, prompt="", progress=gr.Progress()):
    progress(0)
    mask = dict["mask"].convert("RGB").resize((imageHeight, imageWidth))
    init_image = dict["image"].convert("RGB").resize((imageHeight, imageWidth))
    output = inpaint_pipe(
        prompt=prompt,
        image=init_image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator().manual_seed(seed),
    )
    progress(1)

    fname = "inpainting"
    tab_config = {
        "Tab": "Inpainting",
        "Prompt": prompt,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }

    imgs_list = []
    imgs_list.append((output.images[0], "Inpainted Image"))
    imgs_list.append((mask, "Mask"))

    export_as_zip(imgs_list, fname, tab_config)
    return output.images[0], f"outputs/{fname}.zip"


__all__ = ["inpaint"]
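A hypothetical call outside the UI; the first argument mimics the payload Gradio's sketch tool produces, a dict holding the base image and the painted mask (the file names below are made up for illustration):

from PIL import Image

payload = {
    "image": Image.open("photo.png"),  # base picture (hypothetical file)
    "mask": Image.open("mask.png"),    # white where the region should be repainted
}
result, zip_path = inpaint(payload, num_inference_steps=8, seed=14, prompt="a red hat")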
src/pipelines/interpolate.py
ADDED
@@ -0,0 +1,51 @@
import torch
import gradio as gr
from src.util.base import *
from src.util.params import *


def interpolate_prompts(promptA, promptB, num_interpolation_steps):
    text_embeddingsA = get_text_embeddings(promptA)
    text_embeddingsB = get_text_embeddings(promptB)

    interpolated_embeddings = []

    for i in range(num_interpolation_steps):
        alpha = i / num_interpolation_steps
        interpolated_embedding = torch.lerp(text_embeddingsA, text_embeddingsB, alpha)
        interpolated_embeddings.append(interpolated_embedding)

    return interpolated_embeddings


def display_interpolate_images(
    seed, promptA, promptB, num_inference_steps, num_images, progress=gr.Progress()
):
    latents = generate_latents(seed)
    num_images = num_images + 2  # add 2 for first and last image
    text_embeddings = interpolate_prompts(promptA, promptB, num_images)
    images = []
    progress(0)

    for i in range(num_images):
        progress(i / num_images)
        image = generate_images(latents, text_embeddings[i], num_inference_steps)
        images.append((image, "{}".format(i + 1)))

    progress(1, desc="Exporting as gif")
    export_as_gif(images, filename="interpolate.gif", reverse=True)

    fname = "interpolate"
    tab_config = {
        "Tab": "Interpolate",
        "First Prompt": promptA,
        "Second Prompt": promptB,
        "Number of Interpolation Steps": num_images,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)
    return images, "outputs/interpolate.gif", f"outputs/{fname}.zip"


__all__ = ["display_interpolate_images"]
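torch.lerp(a, b, w) returns a + w * (b - a), so the loop in interpolate_prompts sweeps the weight through 0, 1/n, ..., (n-1)/n. A tiny standalone check of that behaviour:

import torch

a, b = torch.tensor(0.0), torch.tensor(10.0)
weights = [i / 5 for i in range(5)]
print([torch.lerp(a, b, w).item() for w in weights])  # [0.0, 2.0, 4.0, 6.0, 8.0]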
src/pipelines/negative.py
ADDED
@@ -0,0 +1,37 @@
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_negative_images(
    prompt, seed, num_inference_steps, negative_prompt="", progress=gr.Progress()
):
    text_embeddings = get_text_embeddings(prompt)
    text_embeddings_neg = get_text_embeddings(prompt, negative_prompt=negative_prompt)

    latents = generate_latents(seed)

    progress(0)
    images = generate_images(latents, text_embeddings, num_inference_steps)

    progress(0.5)
    images_neg = generate_images(latents, text_embeddings_neg, num_inference_steps)

    fname = "negative"
    tab_config = {
        "Tab": "Negative",
        "Prompt": prompt,
        "Negative Prompt": negative_prompt,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }

    imgs_list = []
    imgs_list.append((images, "Without Negative Prompt"))
    imgs_list.append((images_neg, "With Negative Prompt"))
    export_as_zip(imgs_list, fname, tab_config)

    return images, images_neg, f"outputs/{fname}.zip"


__all__ = ["display_negative_images"]
src/pipelines/perturbations.py
ADDED
@@ -0,0 +1,62 @@
import torch
import numpy as np
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_perturb_images(
    prompt,
    seed,
    num_inference_steps,
    num_images,
    perturbation_size,
    progress=gr.Progress(),
):
    text_embeddings = get_text_embeddings(prompt)

    latents_x = generate_latents(seed)
    scale_x = torch.cos(
        torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
    ).to(torch_device)
    noise_x = torch.tensordot(scale_x, latents_x, dims=0)

    progress(0)
    images = []
    images.append(
        (
            generate_images(latents_x, text_embeddings, num_inference_steps),
            "{}".format(1),
        )
    )

    for i in range(num_images):
        np.random.seed(i)
        progress(i / (num_images))
        latents_y = generate_latents(np.random.randint(0, 100000))
        scale_y = torch.sin(
            torch.linspace(0, 2, num_images) * torch.pi * perturbation_size / 4
        ).to(torch_device)
        noise_y = torch.tensordot(scale_y, latents_y, dims=0)

        noise = noise_x + noise_y
        image = generate_images(
            noise[num_images - 1], text_embeddings, num_inference_steps
        )
        images.append((image, "{}".format(i + 2)))

    fname = "perturbations"
    tab_config = {
        "Tab": "Perturbations",
        "Prompt": prompt,
        "Number of Perturbations": num_images,
        "Perturbation Size": perturbation_size,
        "Number of Inference Steps per Image": num_inference_steps,
        "Seed": seed,
    }
    export_as_zip(images, fname, tab_config)

    return images, f"outputs/{fname}.zip"


__all__ = ["display_perturb_images"]
src/pipelines/poke.py
ADDED
@@ -0,0 +1,83 @@
import os
import gradio as gr
from src.util.base import *
from src.util.params import *
from PIL import Image, ImageDraw


def visualize_poke(
    pokeX, pokeY, pokeHeight, pokeWidth, imageHeight=imageHeight, imageWidth=imageWidth
):
    if (
        (pokeX - pokeWidth // 2 < 0)
        or (pokeX + pokeWidth // 2 > imageWidth // 8)
        or (pokeY - pokeHeight // 2 < 0)
        or (pokeY + pokeHeight // 2 > imageHeight // 8)
    ):
        gr.Warning("Modification outside image")
    shape = [
        (pokeX * 8 - pokeWidth * 8 // 2, pokeY * 8 - pokeHeight * 8 // 2),
        (pokeX * 8 + pokeWidth * 8 // 2, pokeY * 8 + pokeHeight * 8 // 2),
    ]

    blank = Image.new("RGB", (imageWidth, imageHeight))

    if os.path.exists("outputs/original.png"):
        oImg = Image.open("outputs/original.png")
        pImg = Image.open("outputs/poked.png")
    else:
        oImg = blank
        pImg = blank

    oRec = ImageDraw.Draw(oImg)
    pRec = ImageDraw.Draw(pImg)

    oRec.rectangle(shape, outline="white")
    pRec.rectangle(shape, outline="white")

    return oImg, pImg


def display_poke_images(
    prompt,
    seed,
    num_inference_steps,
    poke=False,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    intermediate=False,
    progress=gr.Progress(),
):
    text_embeddings = get_text_embeddings(prompt)
    latents, modified_latents = generate_modified_latents(
        poke, seed, pokeX, pokeY, pokeHeight, pokeWidth
    )

    progress(0)
    images = generate_images(
        latents, text_embeddings, num_inference_steps, intermediate=intermediate
    )

    if not intermediate:
        images.save("outputs/original.png")

    if poke:
        progress(0.5)
        modImages = generate_images(
            modified_latents,
            text_embeddings,
            num_inference_steps,
            intermediate=intermediate,
        )

        if not intermediate:
            modImages.save("outputs/poked.png")
    else:
        modImages = None

    return images, modImages


__all__ = ["display_poke_images", "visualize_poke"]
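Note that the poke coordinates are interpreted in latent-space units (a 512x512 image maps to a 64x64 latent grid), which is why visualize_poke checks against imageWidth // 8 and scales everything by 8 before drawing. An illustrative calculation with hypothetical latent-space values:

# Hypothetical poke: centred at latent cell (32, 32), 16x16 latent cells in size.
pokeX, pokeY, pokeHeight, pokeWidth = 32, 32, 16, 16
shape = [
    (pokeX * 8 - pokeWidth * 8 // 2, pokeY * 8 - pokeHeight * 8 // 2),
    (pokeX * 8 + pokeWidth * 8 // 2, pokeY * 8 + pokeHeight * 8 // 2),
]
print(shape)  # [(192, 192), (320, 320)] -- a 128x128-pixel box centred in a 512x512 image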
src/pipelines/seed.py
ADDED
@@ -0,0 +1,32 @@
import gradio as gr
from src.util.base import *
from src.util.params import *


def display_seed_images(
    prompt, num_inference_steps, num_images, progress=gr.Progress()
):
    text_embeddings = get_text_embeddings(prompt)

    images = []
    progress(0)

    for i in range(num_images):
        progress(i / num_images)
        latents = generate_latents(i)
        image = generate_images(latents, text_embeddings, num_inference_steps)
        images.append((image, "{}".format(i + 1)))

    fname = "seeds"
    tab_config = {
        "Tab": "Seeds",
        "Prompt": prompt,
        "Number of Seeds": num_images,
        "Number of Inference Steps per Image": num_inference_steps,
    }
    export_as_zip(images, fname, tab_config)

    return images, f"outputs/{fname}.zip"


__all__ = ["display_seed_images"]
src/util/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .base import *
from .params import *
from .clip_config import *
src/util/base.py
ADDED
@@ -0,0 +1,304 @@
import io
import os
import torch
import zipfile
import numpy as np
import gradio as gr
from PIL import Image
from tqdm.auto import tqdm
from src.util.params import *
from src.util.clip_config import *
import matplotlib.pyplot as plt


def get_text_embeddings(
    prompt,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    torch_device=torch_device,
    batch_size=1,
    negative_prompt="",
):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    return text_embeddings


def generate_latents(
    seed,
    height=imageHeight,
    width=imageWidth,
    torch_device=torch_device,
    unet=unet,
    batch_size=1,
):
    generator = torch.Generator().manual_seed(int(seed))

    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)

    return latents


def generate_modified_latents(
    poke,
    seed,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    imageHeight=imageHeight,
    imageWidth=imageWidth,
):
    original_latents = generate_latents(seed, height=imageHeight, width=imageWidth)
    if poke:
        np.random.seed(seed)
        poke_latents = generate_latents(
            np.random.randint(0, 100000), height=pokeHeight * 8, width=pokeWidth * 8
        )

        x_origin = pokeX - pokeWidth // 2
        y_origin = pokeY - pokeHeight // 2

        modified_latents = original_latents.clone()
        modified_latents[
            :, :, y_origin : y_origin + pokeHeight, x_origin : x_origin + pokeWidth
        ] = poke_latents
    else:
        modified_latents = None

    return original_latents, modified_latents


def convert_to_pil_image(image):
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images[0]


def generate_images(
    latents,
    text_embeddings,
    num_inference_steps,
    unet=unet,
    guidance_scale=guidance_scale,
    vae=vae,
    scheduler=scheduler,
    intermediate=False,
    progress=gr.Progress(),
):
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    images = []
    i = 1

    for t in tqdm(scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        if intermediate:
            progress(((1000 - t) / 1000))
            Latents = 1 / 0.18215 * latents
            with torch.no_grad():
                image = vae.decode(Latents).sample
            images.append((convert_to_pil_image(image), "{}".format(i)))

        latents = scheduler.step(noise_pred, t, latents).prev_sample
        i += 1

    if not intermediate:
        Latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = vae.decode(Latents).sample
        images = convert_to_pil_image(image)

    return images


def get_word_embeddings(
    prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
):
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0].reshape(1, -1)

    text_embeddings = text_embeddings.cpu().numpy()
    return text_embeddings / np.linalg.norm(text_embeddings)


def get_concat_embeddings(names, merge=False):
    embeddings = []

    for name in names:
        embedding = get_word_embeddings(name)
        embeddings.append(embedding)

    embeddings = np.vstack(embeddings)

    if merge:
        embeddings = np.average(embeddings, axis=0).reshape(1, -1)

    return embeddings


def get_axis_embeddings(A, B):
    emb = []

    for a, b in zip(A, B):
        e = get_word_embeddings(a) - get_word_embeddings(b)
        emb.append(e)

    emb = np.vstack(emb)
    ax = np.average(emb, axis=0).reshape(1, -1)

    return ax


def calculate_residual(
    axis, axis_names, from_words=None, to_words=None, residual_axis=1
):
    axis_indices = [0, 1, 2]
    axis_indices.remove(residual_axis)

    if axis_names[axis_indices[0]] in axis_combinations:
        fembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[0]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[0]]] = from_words + to_words
        fembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    if axis_names[axis_indices[1]] in axis_combinations:
        sembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[1]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[1]]] = from_words + to_words
        sembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    fprojections = fembeddings @ axis[axis_indices[0]].T
    sprojections = sembeddings @ axis[axis_indices[1]].T

    partial_residual = fembeddings - (fprojections.reshape(-1, 1) * fembeddings)
    residual = partial_residual - (sprojections.reshape(-1, 1) * sembeddings)

    return residual


def calculate_step_size(num_images, differentiation):
    return differentiation / (num_images - 1)


def generate_seed_vis(seed):
    np.random.seed(seed)
    emb = np.random.rand(15)
    plt.close()
    plt.switch_backend("agg")
    plt.figure(figsize=(10, 0.5))
    plt.imshow([emb], cmap="viridis")
    plt.axis("off")
    return plt


def export_as_gif(images, filename, frames_per_second=2, reverse=False):
    imgs = [img[0] for img in images]

    if reverse:
        imgs += imgs[2:-1][::-1]

    imgs[0].save(
        f"outputs/{filename}",
        format="GIF",
        save_all=True,
        append_images=imgs[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )


def export_as_zip(images, fname, tab_config=None):
    if not os.path.exists(f"outputs/{fname}.zip"):
        os.makedirs("outputs", exist_ok=True)

    with zipfile.ZipFile(f"outputs/{fname}.zip", "w") as img_zip:
        if tab_config:
            with open("outputs/config.txt", "w") as f:
                for key, value in tab_config.items():
                    f.write(f"{key}: {value}\n")
                f.close()

            img_zip.write("outputs/config.txt", "config.txt")

        for idx, img in enumerate(images):
            buff = io.BytesIO()
            img[0].save(buff, format="PNG")
            buff = buff.getvalue()
            max_num = len(images)
            num_leading_zeros = len(str(max_num))
            img_name = f"{{:0{num_leading_zeros}}}.png"
            img_zip.writestr(img_name.format(idx + 1), buff)


def read_html(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    return content


__all__ = [
    "get_text_embeddings",
    "generate_latents",
    "generate_modified_latents",
    "generate_images",
    "get_word_embeddings",
    "get_concat_embeddings",
    "get_axis_embeddings",
    "calculate_residual",
    "calculate_step_size",
    "generate_seed_vis",
    "export_as_gif",
    "export_as_zip",
    "read_html",
]
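The denoising loop in generate_images uses classifier-free guidance: the UNet runs on a doubled batch (unconditional and text-conditioned embeddings), and the two predictions are combined as eps = eps_uncond + guidance_scale * (eps_text - eps_uncond), so guidance_scale = 1 reproduces the purely conditional prediction and larger values push the sample harder toward the prompt. A toy illustration of that combination step only:

import torch

eps_uncond = torch.zeros(4)  # stand-in unconditional noise prediction
eps_text = torch.ones(4)     # stand-in text-conditioned noise prediction
for g in (1, 8, 20):
    print(g, eps_uncond + g * (eps_text - eps_uncond))  # moves linearly away from eps_uncond as g grows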
src/util/clip_config.py
ADDED
@@ -0,0 +1,114 @@
masculine = [
    "man",
    "king",
    "prince",
    "husband",
    "father",
]

feminine = [
    "woman",
    "queen",
    "princess",
    "wife",
    "mother",
]

young = [
    "man",
    "woman",
    "king",
    "queen",
    "father",
]

old = [
    "boy",
    "girl",
    "prince",
    "princess",
    "son",
]

common = [
    "man",
    "woman",
    "boy",
    "girl",
    "woman",
]

elite = [
    "king",
    "queen",
    "prince",
    "princess",
    "duchess",
]

singular = [
    "boy",
    "girl",
    "cat",
    "puppy",
    "computer",
]

plural = [
    "boys",
    "girls",
    "cats",
    "puppies",
    "computers",
]

examples = [
    "king",
    "queen",
    "man",
    "woman",
    "boys",
    "girls",
    "apple",
    "orange",
]

axis_names = ["gender", "residual", "age"]

axis_combinations = {
    "age": young + old,
    "gender": masculine + feminine,
    "royalty": common + elite,
    "number": singular + plural,
}

axisMap = {
    "X - Axis": 0,
    "Y - Axis": 1,
    "Z - Axis": 2,
}

whichAxisMap = {
    "which_axis_1": "X - Axis",
    "which_axis_2": "Z - Axis",
    "which_axis_3": "Y - Axis",
    "which_axis_4": "---",
    "which_axis_5": "---",
    "which_axis_6": "---",
}

__all__ = [
    "axisMap",
    "whichAxisMap",
    "axis_names",
    "axis_combinations",
    "examples",
    "masculine",
    "feminine",
    "young",
    "old",
    "common",
    "elite",
    "singular",
    "plural",
]
src/util/params.py
ADDED
@@ -0,0 +1,96 @@
import torch
import secrets
from gradio.networking import setup_tunnel
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    LCMScheduler,
    DDIMScheduler,
    StableDiffusionPipeline,
)

torch_device = "cuda" if torch.cuda.is_available() else "cpu"

isLCM = False
HF_ACCESS_TOKEN = ""

model_path = "segmind/small-sd"
inpaint_model_path = "Lykon/dreamshaper-8-inpainting"
prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
promptA = "Self-portrait oil painting, a beautiful man with golden hair, 8k"
promptB = "Self-portrait oil painting, a beautiful woman with golden hair, 8k"
negative_prompt = "a photo frame"

num_images = 5
degree = 360
perturbation_size = 0.1
num_inference_steps = 8
seed = 69420

guidance_scale = 8
guidance_values = "1, 8, 20"
intermediate = True
pokeX, pokeY = 256, 256
pokeHeight, pokeWidth = 128, 128
imageHeight, imageWidth = 512, 512

tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder").to(
    torch_device
)

if isLCM:
    scheduler = LCMScheduler.from_pretrained(model_path, subfolder="scheduler")
else:
    scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")

unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(
    torch_device
)
vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(torch_device)

pipe = StableDiffusionPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    unet=unet,
    scheduler=scheduler,
    vae=vae,
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
).to(torch_device)

dash_tunnel = setup_tunnel("0.0.0.0", 8000, secrets.token_urlsafe(32))

__all__ = [
    "prompt",
    "negative_prompt",
    "num_images",
    "degree",
    "perturbation_size",
    "num_inference_steps",
    "seed",
    "intermediate",
    "pokeX",
    "pokeY",
    "pokeHeight",
    "pokeWidth",
    "promptA",
    "promptB",
    "tokenizer",
    "text_encoder",
    "scheduler",
    "unet",
    "vae",
    "torch_device",
    "imageHeight",
    "imageWidth",
    "guidance_scale",
    "guidance_values",
    "HF_ACCESS_TOKEN",
    "model_path",
    "inpaint_model_path",
    "dash_tunnel",
    "pipe",
]