r3gm lnyan committed on
Commit 164b1a9 · verified · 0 Parent(s):

Super-squash branch 'main' using huggingface_hub

Co-authored-by: lnyan <lnyan@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,31 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
PyPatchMatch/.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ /build/
2
+ /*.so
3
+ __pycache__
4
+ *.py[cod]
PyPatchMatch/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jiayuan Mao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
PyPatchMatch/Makefile ADDED
@@ -0,0 +1,54 @@
1
+ #
2
+ # Makefile
3
+ # Jiayuan Mao, 2019-01-09 13:59
4
+ #
5
+
6
+ SRC_DIR = csrc
7
+ INC_DIR = csrc
8
+ OBJ_DIR = build/obj
9
+ TARGET = libpatchmatch.so
10
+
11
+ LIB_TARGET = $(TARGET)
12
+ INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
13
+
14
+ CXX = $(ENVIRONMENT_OPTIONS) g++
15
+ CXXFLAGS = -std=c++14
16
+ CXXFLAGS += -Ofast -ffast-math -w
17
+ # CXXFLAGS += -g
18
+ CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
19
+ CXXFLAGS += $(INCLUDE_DIR)
20
+ LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
21
+
22
+
23
+ CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
24
+ OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
25
+ DEPFILES = $(OBJS:.o=.d)
26
+
27
+ .PHONY: all clean rebuild test
28
+
29
+ all: $(LIB_TARGET)
30
+
31
+ $(OBJ_DIR)/%.o: %.cpp
32
+ @echo "[CC] $< ..."
33
+ @$(CXX) -c $< $(CXXFLAGS) -o $@
34
+
35
+ $(OBJ_DIR)/%.d: %.cpp
36
+ @mkdir -pv $(dir $@)
37
+ @echo "[dep] $< ..."
38
+ @$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
39
+
40
+ sinclude $(DEPFILES)
41
+
42
+ $(LIB_TARGET): $(OBJS)
43
+ @echo "[link] $(LIB_TARGET) ..."
44
+ @$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
45
+
46
+ clean:
47
+ rm -rf $(OBJ_DIR) $(LIB_TARGET)
48
+
49
+ rebuild:
50
+ +@make clean
51
+ +@make
52
+
53
+ # vim:ft=make
54
+ #
PyPatchMatch/README.md ADDED
@@ -0,0 +1,64 @@
1
+ PatchMatch based Inpainting
2
+ =====================================
3
+ This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
4
+ This implementation is heavily based on the implementation by Younesse ANDAM:
5
+ [younesse-cv/PatchMatch](https://github.com/younesse-cv/PatchMatch), with some bug fixes.
6
+
7
+ Usage
8
+ -------------------------------------
9
+
10
+ You first need to install OpenCV to compile the C++ library. Then, run `make` to build the
11
+ shared library `libpatchmatch.so`.
12
+
13
+ For Python users (example available at `examples/py_example.py`)
14
+
15
+ ```python
16
+ import patch_match
17
+
18
+ image = ... # either a numpy ndarray or a PIL Image object.
19
+ mask = ... # either a numpy ndarray or a PIL Image object.
20
+ result = patch_match.inpaint(image, mask, patch_size=5)
21
+ ```
22
+
23
+ For C++ users (examples available at `examples/cpp_example.cpp`)
24
+
25
+ ```cpp
26
+ #include "inpaint.h"
27
+
28
+ int main() {
29
+ cv::Mat image = ...
30
+ cv::Mat mask = ...
31
+
32
+ cv::Mat result = Inpainting(image, mask, 5).run();
33
+
34
+ return 0;
35
+ }
36
+ ```
37
+
38
+
39
+ README and COPYRIGHT by Younesse ANDAM
40
+ -------------------------------------
41
+ @Author: Younesse ANDAM
42
+
43
+ @Contact: younesse.andam@gmail.com
44
+
45
+ Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
46
+ The algorithm is presented in the following paper:
47
+ PatchMatch: A Randomized Correspondence Algorithm
48
+ for Structural Image Editing
49
+ by C. Barnes, E. Shechtman, A. Finkelstein and Dan B. Goldman
50
+ ACM Transactions on Graphics (Proc. SIGGRAPH), vol. 28, Aug. 2009
51
+
52
+ For more information please refer to
53
+ http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
54
+
55
+ Copyright (c) 2010-2011
56
+
57
+
58
+ Requirements
59
+ -------------------------------------
60
+
61
+ To run the project you need to install the OpenCV library and link it to your project.
62
+ OpenCV can be downloaded here:
63
+ http://opencv.org/downloads.html
64
+
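For a runnable end-to-end sketch of the Python interface described above (essentially what `examples/py_example.py` in this commit does), assuming `libpatchmatch.so` has already been built with `make` and the example image under `examples/images/` is available; the paths are illustrative:

```python
#! /usr/bin/env python3
# Minimal usage sketch patterned after examples/py_example.py in this commit.
from PIL import Image

import patch_match

if __name__ == '__main__':
    # The example image marks the missing region with pure white pixels.
    source = Image.open('./images/forest_pruned.bmp')
    # With no explicit mask, inpaint() treats all purely white (255, 255, 255)
    # pixels as the holes to be filled.
    result = patch_match.inpaint(source, patch_size=3)
    Image.fromarray(result).save('./images/forest_recovered.bmp')
```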
PyPatchMatch/csrc/inpaint.cpp ADDED
@@ -0,0 +1,234 @@
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <opencv2/imgcodecs.hpp>
4
+ #include <opencv2/imgproc.hpp>
5
+ #include <opencv2/highgui.hpp>
6
+
7
+ #include "inpaint.h"
8
+
9
+ namespace {
10
+ static std::vector<double> kDistance2Similarity;
11
+
12
+ void init_kDistance2Similarity() {
13
+ double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
14
+ int length = (PatchDistanceMetric::kDistanceScale + 1);
15
+ kDistance2Similarity.resize(length);
16
+ for (int i = 0; i < length; ++i) {
17
+ double t = (double) i / length;
18
+ int j = (int) (100 * t);
19
+ int k = j + 1;
20
+ double vj = (j < 11) ? base[j] : 0;
21
+ double vk = (k < 11) ? base[k] : 0;
22
+ kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
23
+ }
24
+ }
25
+
26
+
27
+ inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
28
+ if (source.is_masked(ys, xs)) return;
29
+ if (source.is_globally_masked(ys, xs)) return;
30
+
31
+ auto source_ptr = source.get_image(ys, xs);
32
+ auto target_ptr = target.ptr<double>(yt, xt);
33
+
34
+ #pragma unroll
35
+ for (int c = 0; c < 3; ++c)
36
+ target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
37
+ target_ptr[3] += weight;
38
+ }
39
+ }
40
+
41
+ /**
42
+ * This algorithm uses a version proposed by Xavier Philippeau.
43
+ */
44
+
45
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
46
+ : m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
47
+ _initialize_pyramid();
48
+ }
49
+
50
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
51
+ : m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
52
+ _initialize_pyramid();
53
+ }
54
+
55
+ void Inpainting::_initialize_pyramid() {
56
+ auto source = m_initial;
57
+ m_pyramid.push_back(source);
58
+ while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
59
+ source = source.downsample();
60
+ m_pyramid.push_back(source);
61
+ }
62
+
63
+ if (kDistance2Similarity.size() == 0) {
64
+ init_kDistance2Similarity();
65
+ }
66
+ }
67
+
68
+ cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
69
+ srand(random_seed);
70
+ const int nr_levels = m_pyramid.size();
71
+
72
+ MaskedImage source, target;
73
+ for (int level = nr_levels - 1; level >= 0; --level) {
74
+ if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
75
+
76
+ source = m_pyramid[level];
77
+
78
+ if (level == nr_levels - 1) {
79
+ target = source.clone();
80
+ target.clear_mask();
81
+ m_source2target = NearestNeighborField(source, target, m_distance_metric);
82
+ m_target2source = NearestNeighborField(target, source, m_distance_metric);
83
+ } else {
84
+ m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
85
+ m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
86
+ }
87
+
88
+ if (verbose) std::cerr << "Initialization done." << std::endl;
89
+
90
+ if (verbose_visualize) {
91
+ auto visualize_size = m_initial.size();
92
+ cv::Mat source_visualize(visualize_size, m_initial.image().type());
93
+ cv::resize(source.image(), source_visualize, visualize_size);
94
+ cv::imshow("Source", source_visualize);
95
+ cv::Mat target_visualize(visualize_size, m_initial.image().type());
96
+ cv::resize(target.image(), target_visualize, visualize_size);
97
+ cv::imshow("Target", target_visualize);
98
+ cv::waitKey(0);
99
+ }
100
+
101
+ target = _expectation_maximization(source, target, level, verbose);
102
+ }
103
+
104
+ return target.image();
105
+ }
106
+
107
+ // EM-Like algorithm (see "PatchMatch" - page 6).
108
+ // Returns a double sized target image (unless level = 0).
109
+ MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
110
+ const int nr_iters_em = 1 + 2 * level;
111
+ const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
112
+ const int patch_size = m_distance_metric->patch_size();
113
+
114
+ MaskedImage new_source, new_target;
115
+
116
+ for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
117
+ if (iter_em != 0) {
118
+ m_source2target.set_target(new_target);
119
+ m_target2source.set_source(new_target);
120
+ target = new_target;
121
+ }
122
+
123
+ if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
124
+
125
+ auto size = source.size();
126
+ for (int i = 0; i < size.height; ++i) {
127
+ for (int j = 0; j < size.width; ++j) {
128
+ if (!source.contains_mask(i, j, patch_size)) {
129
+ m_source2target.set_identity(i, j);
130
+ m_target2source.set_identity(i, j);
131
+ }
132
+ }
133
+ }
134
+ if (verbose) std::cerr << " NNF minimization started." << std::endl;
135
+ m_source2target.minimize(nr_iters_nnf);
136
+ m_target2source.minimize(nr_iters_nnf);
137
+ if (verbose) std::cerr << " NNF minimization finished." << std::endl;
138
+
139
+ // Instead of upsizing the final target, we build the last target from the next level source image.
140
+ // Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
141
+ bool upscaled = false;
142
+ if (level >= 1 && iter_em == nr_iters_em - 1) {
143
+ new_source = m_pyramid[level - 1];
144
+ new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
145
+ upscaled = true;
146
+ } else {
147
+ new_source = m_pyramid[level];
148
+ new_target = target.clone();
149
+ }
150
+
151
+ auto vote = cv::Mat(new_target.size(), CV_64FC4);
152
+ vote.setTo(cv::Scalar::all(0));
153
+
154
+ // Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
155
+ _expectation_step(m_source2target, 1, vote, new_source, upscaled);
156
+ if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
157
+ _expectation_step(m_target2source, 0, vote, new_source, upscaled);
158
+ if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
159
+
160
+ // Compile votes and update pixel values.
161
+ _maximization_step(new_target, vote);
162
+ if (verbose) std::cerr << " Minimization step finished." << std::endl;
163
+ }
164
+
165
+ return new_target;
166
+ }
167
+
168
+ // Expectation step: vote for best estimations of each pixel.
169
+ void Inpainting::_expectation_step(
170
+ const NearestNeighborField &nnf, bool source2target,
171
+ cv::Mat &vote, const MaskedImage &source, bool upscaled
172
+ ) {
173
+ auto source_size = nnf.source_size();
174
+ auto target_size = nnf.target_size();
175
+ const int patch_size = m_distance_metric->patch_size();
176
+
177
+ for (int i = 0; i < source_size.height; ++i) {
178
+ for (int j = 0; j < source_size.width; ++j) {
179
+ if (nnf.source().is_globally_masked(i, j)) continue;
180
+ int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
181
+ double w = kDistance2Similarity[dp];
182
+
183
+ for (int di = -patch_size; di <= patch_size; ++di) {
184
+ for (int dj = -patch_size; dj <= patch_size; ++dj) {
185
+ int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
186
+ if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
187
+ if (nnf.source().is_globally_masked(ys, xs)) continue;
188
+ if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
189
+ if (nnf.target().is_globally_masked(yt, xt)) continue;
190
+
191
+ if (!source2target) {
192
+ std::swap(ys, yt);
193
+ std::swap(xs, xt);
194
+ }
195
+
196
+ if (upscaled) {
197
+ for (int uy = 0; uy < 2; ++uy) {
198
+ for (int ux = 0; ux < 2; ++ux) {
199
+ _weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
200
+ }
201
+ }
202
+ } else {
203
+ _weighted_copy(source, ys, xs, vote, yt, xt, w);
204
+ }
205
+ }
206
+ }
207
+ }
208
+ }
209
+ }
210
+
211
+ // Maximization Step: maximum likelihood of target pixel.
212
+ void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
213
+ auto target_size = target.size();
214
+ for (int i = 0; i < target_size.height; ++i) {
215
+ for (int j = 0; j < target_size.width; ++j) {
216
+ const double *source_ptr = vote.ptr<double>(i, j);
217
+ unsigned char *target_ptr = target.get_mutable_image(i, j);
218
+
219
+ if (target.is_globally_masked(i, j)) {
220
+ continue;
221
+ }
222
+
223
+ if (source_ptr[3] > 0) {
224
+ unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
225
+ unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
226
+ unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
227
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
228
+ } else {
229
+ target.set_mask(i, j, 0);
230
+ }
231
+ }
232
+ }
233
+ }
234
+
PyPatchMatch/csrc/inpaint.h ADDED
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ class Inpainting {
9
+ public:
10
+ Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
11
+ Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
12
+ cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
13
+
14
+ private:
15
+ void _initialize_pyramid(void);
16
+ MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
17
+ void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
18
+ void _maximization_step(MaskedImage &target, const cv::Mat &vote);
19
+
20
+ MaskedImage m_initial;
21
+ std::vector<MaskedImage> m_pyramid;
22
+
23
+ NearestNeighborField m_source2target;
24
+ NearestNeighborField m_target2source;
25
+ const PatchDistanceMetric *m_distance_metric;
26
+ };
27
+
PyPatchMatch/csrc/masked_image.cpp ADDED
@@ -0,0 +1,138 @@
1
+ #include "masked_image.h"
2
+ #include <algorithm>
3
+ #include <iostream>
4
+
5
+ const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
6
+ const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
7
+
8
+ bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
9
+ auto mask_size = size();
10
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
11
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
12
+ int yy = y + dy, xx = x + dx;
13
+ if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
14
+ if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
15
+ }
16
+ }
17
+ }
18
+ return false;
19
+ }
20
+
21
+ MaskedImage MaskedImage::downsample() const {
22
+ const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
23
+ const auto &kernel = MaskedImage::kDownsampleKernel;
24
+
25
+ const auto size = this->size();
26
+ const auto new_size = cv::Size(size.width / 2, size.height / 2);
27
+
28
+ auto ret = MaskedImage(new_size.width, new_size.height);
29
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
30
+ for (int y = 0; y < size.height - 1; y += 2) {
31
+ for (int x = 0; x < size.width - 1; x += 2) {
32
+ int r = 0, g = 0, b = 0, ksum = 0;
33
+ bool is_gmasked = true;
34
+
35
+ for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
36
+ for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
37
+ int yy = y + dy, xx = x + dx;
38
+ if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
39
+ if (!is_globally_masked(yy, xx)) {
40
+ is_gmasked = false;
41
+ }
42
+ if (!is_masked(yy, xx)) {
43
+ auto source_ptr = get_image(yy, xx);
44
+ int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
45
+ r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
46
+ ksum += k;
47
+ }
48
+ }
49
+ }
50
+ }
51
+
52
+ if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
53
+
54
+ if (!m_global_mask.empty()) {
55
+ ret.set_global_mask(y / 2, x / 2, is_gmasked);
56
+ }
57
+ if (ksum > 0) {
58
+ auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
59
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
60
+ ret.set_mask(y / 2, x / 2, 0);
61
+ } else {
62
+ ret.set_mask(y / 2, x / 2, 1);
63
+ }
64
+ }
65
+ }
66
+
67
+ return ret;
68
+ }
69
+
70
+ MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
71
+ const auto size = this->size();
72
+ auto ret = MaskedImage(new_w, new_h);
73
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
74
+ for (int y = 0; y < new_h; ++y) {
75
+ for (int x = 0; x < new_w; ++x) {
76
+ int yy = y * size.height / new_h;
77
+ int xx = x * size.width / new_w;
78
+
79
+ if (is_globally_masked(yy, xx)) {
80
+ ret.set_global_mask(y, x, 1);
81
+ ret.set_mask(y, x, 1);
82
+ } else {
83
+ if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
84
+
85
+ if (is_masked(yy, xx)) {
86
+ ret.set_mask(y, x, 1);
87
+ } else {
88
+ auto source_ptr = get_image(yy, xx);
89
+ auto target_ptr = ret.get_mutable_image(y, x);
90
+ for (int c = 0; c < 3; ++c)
91
+ target_ptr[c] = source_ptr[c];
92
+ ret.set_mask(y, x, 0);
93
+ }
94
+ }
95
+ }
96
+ }
97
+
98
+ return ret;
99
+ }
100
+
101
+ MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
102
+ auto ret = upsample(new_w, new_h);
103
+ ret.set_global_mask_mat(new_global_mask);
104
+ return ret;
105
+ }
106
+
107
+ void MaskedImage::compute_image_gradients() {
108
+ if (m_image_grad_computed) {
109
+ return;
110
+ }
111
+
112
+ const auto size = m_image.size();
113
+ m_image_grady = cv::Mat(size, CV_8UC3);
114
+ m_image_gradx = cv::Mat(size, CV_8UC3);
115
+ m_image_grady = cv::Scalar::all(0);
116
+ m_image_gradx = cv::Scalar::all(0);
117
+
118
+ for (int i = 1; i < size.height - 1; ++i) {
119
+ const auto *ptr = m_image.ptr<unsigned char>(i, 0);
120
+ const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
121
+ const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
122
+ const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
123
+ const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
124
+ auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
125
+ auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
126
+ for (int j = 3; j < size.width * 3 - 3; ++j) {
127
+ mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
128
+ mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
129
+ }
130
+ }
131
+
132
+ m_image_grad_computed = true;
133
+ }
134
+
135
+ void MaskedImage::compute_image_gradients() const {
136
+ const_cast<MaskedImage *>(this)->compute_image_gradients();
137
+ }
138
+
PyPatchMatch/csrc/masked_image.h ADDED
@@ -0,0 +1,112 @@
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+
5
+ class MaskedImage {
6
+ public:
7
+ MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
8
+ // pass
9
+ }
10
+ MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
11
+ // pass
12
+ }
13
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
14
+ // pass
15
+ }
16
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
17
+ m_image(image), m_mask(mask), m_global_mask(global_mask),
18
+ m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
19
+ // pass
20
+ }
21
+ MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
22
+ m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
23
+ m_image = cv::Scalar::all(0);
24
+
25
+ m_mask = cv::Mat(cv::Size(width, height), CV_8U);
26
+ m_mask = cv::Scalar::all(0);
27
+ }
28
+ inline MaskedImage clone() {
29
+ return MaskedImage(
30
+ m_image.clone(), m_mask.clone(), m_global_mask.clone(),
31
+ m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
32
+ );
33
+ }
34
+
35
+ inline cv::Size size() const {
36
+ return m_image.size();
37
+ }
38
+ inline const cv::Mat &image() const {
39
+ return m_image;
40
+ }
41
+ inline const cv::Mat &mask() const {
42
+ return m_mask;
43
+ }
44
+ inline const cv::Mat &global_mask() const {
45
+ return m_global_mask;
46
+ }
47
+ inline const cv::Mat &grady() const {
48
+ assert(m_image_grad_computed);
49
+ return m_image_grady;
50
+ }
51
+ inline const cv::Mat &gradx() const {
52
+ assert(m_image_grad_computed);
53
+ return m_image_gradx;
54
+ }
55
+
56
+ inline void init_global_mask_mat() {
57
+ m_global_mask = cv::Mat(m_mask.size(), CV_8U);
58
+ m_global_mask.setTo(cv::Scalar(0));
59
+ }
60
+ inline void set_global_mask_mat(const cv::Mat &other) {
61
+ m_global_mask = other;
62
+ }
63
+
64
+ inline bool is_masked(int y, int x) const {
65
+ return static_cast<bool>(m_mask.at<unsigned char>(y, x));
66
+ }
67
+ inline bool is_globally_masked(int y, int x) const {
68
+ return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
69
+ }
70
+ inline void set_mask(int y, int x, bool value) {
71
+ m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
72
+ }
73
+ inline void set_global_mask(int y, int x, bool value) {
74
+ m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
75
+ }
76
+ inline void clear_mask() {
77
+ m_mask.setTo(cv::Scalar(0));
78
+ }
79
+
80
+ inline const unsigned char *get_image(int y, int x) const {
81
+ return m_image.ptr<unsigned char>(y, x);
82
+ }
83
+ inline unsigned char *get_mutable_image(int y, int x) {
84
+ return m_image.ptr<unsigned char>(y, x);
85
+ }
86
+
87
+ inline unsigned char get_image(int y, int x, int c) const {
88
+ return m_image.ptr<unsigned char>(y, x)[c];
89
+ }
90
+ inline int get_image_int(int y, int x, int c) const {
91
+ return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
92
+ }
93
+
94
+ bool contains_mask(int y, int x, int patch_size) const;
95
+ MaskedImage downsample() const;
96
+ MaskedImage upsample(int new_w, int new_h) const;
97
+ MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
98
+ void compute_image_gradients();
99
+ void compute_image_gradients() const;
100
+
101
+ static const cv::Size kDownsampleKernelSize;
102
+ static const int kDownsampleKernel[6];
103
+
104
+ private:
105
+ cv::Mat m_image;
106
+ cv::Mat m_mask;
107
+ cv::Mat m_global_mask;
108
+ cv::Mat m_image_grady;
109
+ cv::Mat m_image_gradx;
110
+ bool m_image_grad_computed = false;
111
+ };
112
+
PyPatchMatch/csrc/nnf.cpp ADDED
@@ -0,0 +1,268 @@
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <cmath>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ /**
9
+ * Nearest-Neighbor Field (see PatchMatch algorithm).
10
+ * This algorithm uses a version proposed by Xavier Philippeau.
11
+ *
12
+ */
13
+
14
+ template <typename T>
15
+ T clamp(T value, T min_value, T max_value) {
16
+ return std::min(std::max(value, min_value), max_value);
17
+ }
18
+
19
+ void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
20
+ auto this_size = source_size();
21
+ for (int i = 0; i < this_size.height; ++i) {
22
+ for (int j = 0; j < this_size.width; ++j) {
23
+ if (m_source.is_globally_masked(i, j)) continue;
24
+
25
+ auto this_ptr = mutable_ptr(i, j);
26
+ int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
27
+ if (distance < PatchDistanceMetric::kDistanceScale) {
28
+ continue;
29
+ }
30
+
31
+ int i_target = 0, j_target = 0;
32
+ for (int t = 0; t < max_retry; ++t) {
33
+ i_target = rand() % this_size.height;
34
+ j_target = rand() % this_size.width;
35
+ if (m_target.is_globally_masked(i_target, j_target)) continue;
36
+
37
+ distance = _distance(i, j, i_target, j_target);
38
+ if (distance < PatchDistanceMetric::kDistanceScale)
39
+ break;
40
+ }
41
+
42
+ this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
43
+ }
44
+ }
45
+ }
46
+
47
+ void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
48
+ const auto &this_size = source_size();
49
+ const auto &other_size = other.source_size();
50
+ double fi = static_cast<double>(this_size.height) / other_size.height;
51
+ double fj = static_cast<double>(this_size.width) / other_size.width;
52
+
53
+ for (int i = 0; i < this_size.height; ++i) {
54
+ for (int j = 0; j < this_size.width; ++j) {
55
+ if (m_source.is_globally_masked(i, j)) continue;
56
+
57
+ int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
58
+ int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
59
+ auto this_value = mutable_ptr(i, j);
60
+ auto other_value = other.ptr(ilow, jlow);
61
+
62
+ this_value[0] = static_cast<int>(other_value[0] * fi);
63
+ this_value[1] = static_cast<int>(other_value[1] * fj);
64
+ this_value[2] = _distance(i, j, this_value[0], this_value[1]);
65
+ }
66
+ }
67
+
68
+ _randomize_field(max_retry, false);
69
+ }
70
+
71
+ void NearestNeighborField::minimize(int nr_pass) {
72
+ const auto &this_size = source_size();
73
+ while (nr_pass--) {
74
+ for (int i = 0; i < this_size.height; ++i)
75
+ for (int j = 0; j < this_size.width; ++j) {
76
+ if (m_source.is_globally_masked(i, j)) continue;
77
+ if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
78
+ }
79
+ for (int i = this_size.height - 1; i >= 0; --i)
80
+ for (int j = this_size.width - 1; j >= 0; --j) {
81
+ if (m_source.is_globally_masked(i, j)) continue;
82
+ if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
83
+ }
84
+ }
85
+ }
86
+
87
+ void NearestNeighborField::_minimize_link(int y, int x, int direction) {
88
+ const auto &this_size = source_size();
89
+ const auto &this_target_size = target_size();
90
+ auto this_ptr = mutable_ptr(y, x);
91
+
92
+ // propagation along the y direction.
93
+ if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
94
+ int yp = at(y - direction, x, 0) + direction;
95
+ int xp = at(y - direction, x, 1);
96
+ int dp = _distance(y, x, yp, xp);
97
+ if (dp < at(y, x, 2)) {
98
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
99
+ }
100
+ }
101
+
102
+ // propagation along the x direction.
103
+ if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
104
+ int yp = at(y, x - direction, 0);
105
+ int xp = at(y, x - direction, 1) + direction;
106
+ int dp = _distance(y, x, yp, xp);
107
+ if (dp < at(y, x, 2)) {
108
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
109
+ }
110
+ }
111
+
112
+ // random search with a progressive step size.
113
+ int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
114
+ while (random_scale > 0) {
115
+ int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
116
+ int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
117
+ yp = clamp(yp, 0, target_size().height - 1);
118
+ xp = clamp(xp, 0, target_size().width - 1);
119
+
120
+ if (m_target.is_globally_masked(yp, xp)) {
121
+ random_scale /= 2;
122
+ }
123
+
124
+ int dp = _distance(y, x, yp, xp);
125
+ if (dp < at(y, x, 2)) {
126
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
127
+ }
128
+ random_scale /= 2;
129
+ }
130
+ }
131
+
132
+ const int PatchDistanceMetric::kDistanceScale = 65535;
133
+ const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
134
+
135
+ namespace {
136
+
137
+ inline int pow2(int i) {
138
+ return i * i;
139
+ }
140
+
141
+ int distance_masked_images(
142
+ const MaskedImage &source, int ys, int xs,
143
+ const MaskedImage &target, int yt, int xt,
144
+ int patch_size
145
+ ) {
146
+ long double distance = 0;
147
+ long double wsum = 0;
148
+
149
+ source.compute_image_gradients();
150
+ target.compute_image_gradients();
151
+
152
+ auto source_size = source.size();
153
+ auto target_size = target.size();
154
+
155
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
156
+ const int yys = ys + dy, yyt = yt + dy;
157
+
158
+ if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
159
+ distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
160
+ wsum += 2 * patch_size + 1;
161
+ continue;
162
+ }
163
+
164
+ const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
165
+ const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
166
+ const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
167
+ const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
168
+
169
+ const unsigned char *p_sgm = nullptr;
170
+ const unsigned char *p_tgm = nullptr;
171
+ if (!source.global_mask().empty()) {
172
+ p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
173
+ p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
174
+ }
175
+
176
+ const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
177
+ const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
178
+ const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
179
+ const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
180
+
181
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
182
+ int xxs = xs + dx, xxt = xt + dx;
183
+ wsum += 1;
184
+
185
+ if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
186
+ distance += PatchSSDDistanceMetric::kSSDScale;
187
+ continue;
188
+ }
189
+
190
+ if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt]) ) {
191
+ distance += PatchSSDDistanceMetric::kSSDScale;
192
+ continue;
193
+ }
194
+
195
+ int ssd = 0;
196
+ for (int c = 0; c < 3; ++c) {
197
+ int s_value = p_si[xxs * 3 + c];
198
+ int t_value = p_ti[xxt * 3 + c];
199
+ int s_gy = p_sgy[xxs * 3 + c];
200
+ int t_gy = p_tgy[xxt * 3 + c];
201
+ int s_gx = p_sgx[xxs * 3 + c];
202
+ int t_gx = p_tgx[xxt * 3 + c];
203
+
204
+ ssd += pow2(static_cast<int>(s_value) - t_value);
205
+ ssd += pow2(static_cast<int>(s_gx) - t_gx);
206
+ ssd += pow2(static_cast<int>(s_gy) - t_gy);
207
+ }
208
+ distance += ssd;
209
+ }
210
+ }
211
+
212
+ distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
213
+
214
+ int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
215
+ if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
216
+ return res;
217
+ }
218
+
219
+ }
220
+
221
+ int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
222
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
223
+ }
224
+
225
+ int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
226
+ fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
227
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
228
+ }
229
+
230
+ int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
231
+ double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
232
+ double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
233
+
234
+ double score1 = sqrt(dx * dx + dy *dy) / m_scale;
235
+ if (score1 < 0 || score1 > 1) score1 = 1;
236
+ score1 *= PatchDistanceMetric::kDistanceScale;
237
+
238
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
239
+ double score = score1 * m_weight + score2 / (1 + m_weight);
240
+ return static_cast<int>(score / (1 + m_weight));
241
+ }
242
+
243
+ int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
244
+ if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
245
+ return PatchDistanceMetric::kDistanceScale;
246
+
247
+ int source_scale = m_ijmap.size().height / source.size().height;
248
+ int target_scale = m_ijmap.size().height / target.size().height;
249
+
250
+ // fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
251
+
252
+ double score1 = PatchDistanceMetric::kDistanceScale;
253
+ if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
254
+ auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
255
+ auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
256
+
257
+ float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
258
+ float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
259
+ score1 = sqrt(di * di + dj *dj) / 0.707;
260
+ if (score1 < 0 || score1 > 1) score1 = 1;
261
+ score1 *= PatchDistanceMetric::kDistanceScale;
262
+ }
263
+
264
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
265
+ double score = score1 * m_weight + score2;
266
+ return int(score / (1 + m_weight));
267
+ }
268
+
PyPatchMatch/csrc/nnf.h ADDED
@@ -0,0 +1,133 @@
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+ #include "masked_image.h"
5
+
6
+ class PatchDistanceMetric {
7
+ public:
8
+ PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
9
+ virtual ~PatchDistanceMetric() = default;
10
+
11
+ inline int patch_size() const { return m_patch_size; }
12
+ virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
13
+ static const int kDistanceScale;
14
+
15
+ protected:
16
+ int m_patch_size;
17
+ };
18
+
19
+ class NearestNeighborField {
20
+ public:
21
+ NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
22
+ // pass
23
+ }
24
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
25
+ : m_source(source), m_target(target), m_distance_metric(metric) {
26
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
27
+ _randomize_field(max_retry);
28
+ }
29
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
30
+ : m_source(source), m_target(target), m_distance_metric(metric) {
31
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
32
+ _initialize_field_from(other, max_retry);
33
+ }
34
+
35
+ const MaskedImage &source() const {
36
+ return m_source;
37
+ }
38
+ const MaskedImage &target() const {
39
+ return m_target;
40
+ }
41
+ inline cv::Size source_size() const {
42
+ return m_source.size();
43
+ }
44
+ inline cv::Size target_size() const {
45
+ return m_target.size();
46
+ }
47
+ inline void set_source(const MaskedImage &source) {
48
+ m_source = source;
49
+ }
50
+ inline void set_target(const MaskedImage &target) {
51
+ m_target = target;
52
+ }
53
+
54
+ inline int *mutable_ptr(int y, int x) {
55
+ return m_field.ptr<int>(y, x);
56
+ }
57
+ inline const int *ptr(int y, int x) const {
58
+ return m_field.ptr<int>(y, x);
59
+ }
60
+
61
+ inline int at(int y, int x, int c) const {
62
+ return m_field.ptr<int>(y, x)[c];
63
+ }
64
+ inline int &at(int y, int x, int c) {
65
+ return m_field.ptr<int>(y, x)[c];
66
+ }
67
+ inline void set_identity(int y, int x) {
68
+ auto ptr = mutable_ptr(y, x);
69
+ ptr[0] = y, ptr[1] = x, ptr[2] = 0;
70
+ }
71
+
72
+ void minimize(int nr_pass);
73
+
74
+ private:
75
+ inline int _distance(int source_y, int source_x, int target_y, int target_x) {
76
+ return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
77
+ }
78
+
79
+ void _randomize_field(int max_retry = 20, bool reset = true);
80
+ void _initialize_field_from(const NearestNeighborField &other, int max_retry);
81
+ void _minimize_link(int y, int x, int direction);
82
+
83
+ MaskedImage m_source;
84
+ MaskedImage m_target;
85
+ cv::Mat m_field; // { y_target, x_target, distance_scaled }
86
+ const PatchDistanceMetric *m_distance_metric;
87
+ };
88
+
89
+
90
+ class PatchSSDDistanceMetric : public PatchDistanceMetric {
91
+ public:
92
+ using PatchDistanceMetric::PatchDistanceMetric;
93
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
94
+ static const int kSSDScale;
95
+ };
96
+
97
+ class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
98
+ public:
99
+ DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
100
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
101
+ protected:
102
+ int m_width, m_height;
103
+ };
104
+
105
+ class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
106
+ public:
107
+ RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
108
+ : PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
109
+
110
+ assert(m_dy1 == 0);
111
+ assert(m_dx2 == 0);
112
+ m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
113
+ }
114
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
115
+
116
+ protected:
117
+ double m_dx1, m_dy1, m_dx2, m_dy2;
118
+ double m_scale, m_weight;
119
+ };
120
+
121
+ class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
122
+ public:
123
+ RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
124
+ : PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
125
+
126
+ }
127
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
128
+
129
+ protected:
130
+ cv::Mat m_ijmap;
131
+ double m_width, m_height, m_weight;
132
+ };
133
+
PyPatchMatch/csrc/pyinterface.cpp ADDED
@@ -0,0 +1,107 @@
1
+ #include "pyinterface.h"
2
+ #include "inpaint.h"
3
+
4
+ static unsigned int PM_seed = 1212;
5
+ static bool PM_verbose = false;
6
+
7
+ int _dtype_py_to_cv(int dtype_py);
8
+ int _dtype_cv_to_py(int dtype_cv);
9
+ cv::Mat _py_to_cv2(PM_mat_t pymat);
10
+ PM_mat_t _cv2_to_py(cv::Mat cvmat);
11
+
12
+ void PM_set_random_seed(unsigned int seed) {
13
+ PM_seed = seed;
14
+ }
15
+
16
+ void PM_set_verbose(int value) {
17
+ PM_verbose = static_cast<bool>(value);
18
+ }
19
+
20
+ void PM_free_pymat(PM_mat_t pymat) {
21
+ free(pymat.data_ptr);
22
+ }
23
+
24
+ PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
25
+ cv::Mat source = _py_to_cv2(source_py);
26
+ cv::Mat mask = _py_to_cv2(mask_py);
27
+ auto metric = PatchSSDDistanceMetric(patch_size);
28
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
29
+ return _cv2_to_py(result);
30
+ }
31
+
32
+ PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
33
+ cv::Mat source = _py_to_cv2(source_py);
34
+ cv::Mat mask = _py_to_cv2(mask_py);
35
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
36
+
37
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
38
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
39
+ return _cv2_to_py(result);
40
+ }
41
+
42
+ PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
43
+ cv::Mat source = _py_to_cv2(source_py);
44
+ cv::Mat mask = _py_to_cv2(mask_py);
45
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
46
+
47
+ auto metric = PatchSSDDistanceMetric(patch_size);
48
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
49
+ return _cv2_to_py(result);
50
+ }
51
+
52
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
53
+ cv::Mat source = _py_to_cv2(source_py);
54
+ cv::Mat mask = _py_to_cv2(mask_py);
55
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
56
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
57
+
58
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
59
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
60
+ return _cv2_to_py(result);
61
+ }
62
+
63
+ int _dtype_py_to_cv(int dtype_py) {
64
+ switch (dtype_py) {
65
+ case PM_UINT8: return CV_8U;
66
+ case PM_INT8: return CV_8S;
67
+ case PM_UINT16: return CV_16U;
68
+ case PM_INT16: return CV_16S;
69
+ case PM_INT32: return CV_32S;
70
+ case PM_FLOAT32: return CV_32F;
71
+ case PM_FLOAT64: return CV_64F;
72
+ }
73
+
74
+ return CV_8U;
75
+ }
76
+
77
+ int _dtype_cv_to_py(int dtype_cv) {
78
+ switch (dtype_cv) {
79
+ case CV_8U: return PM_UINT8;
80
+ case CV_8S: return PM_INT8;
81
+ case CV_16U: return PM_UINT16;
82
+ case CV_16S: return PM_INT16;
83
+ case CV_32S: return PM_INT32;
84
+ case CV_32F: return PM_FLOAT32;
85
+ case CV_64F: return PM_FLOAT64;
86
+ }
87
+
88
+ return PM_UINT8;
89
+ }
90
+
91
+ cv::Mat _py_to_cv2(PM_mat_t pymat) {
92
+ int dtype = _dtype_py_to_cv(pymat.dtype);
93
+ dtype = CV_MAKETYPE(pymat.dtype, pymat.shape.channels);
94
+ return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
95
+ }
96
+
97
+ PM_mat_t _cv2_to_py(cv::Mat cvmat) {
98
+ PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
99
+ int dtype = _dtype_cv_to_py(cvmat.depth());
100
+ size_t dsize = cvmat.total() * cvmat.elemSize();
101
+
102
+ void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
103
+ memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);
104
+
105
+ return PM_mat_t {data_ptr, shape, dtype};
106
+ }
107
+
PyPatchMatch/csrc/pyinterface.h ADDED
@@ -0,0 +1,38 @@
1
+ #include <opencv2/core.hpp>
2
+ #include <cstdlib>
3
+ #include <cstdio>
4
+ #include <cstring>
5
+
6
+ extern "C" {
7
+
8
+ struct PM_shape_t {
9
+ int width, height, channels;
10
+ };
11
+
12
+ enum PM_dtype_e {
13
+ PM_UINT8,
14
+ PM_INT8,
15
+ PM_UINT16,
16
+ PM_INT16,
17
+ PM_INT32,
18
+ PM_FLOAT32,
19
+ PM_FLOAT64,
20
+ };
21
+
22
+ struct PM_mat_t {
23
+ void *data_ptr;
24
+ PM_shape_t shape;
25
+ int dtype;
26
+ };
27
+
28
+ void PM_set_random_seed(unsigned int seed);
29
+ void PM_set_verbose(int value);
30
+
31
+ void PM_free_pymat(PM_mat_t pymat);
32
+ PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
33
+ PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
34
+ PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
35
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);
36
+
37
+ } /* extern "C" */
38
+
PyPatchMatch/examples/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ /cpp_example.exe
2
+ /images/*recovered.bmp
PyPatchMatch/examples/cpp_example.cpp ADDED
@@ -0,0 +1,31 @@
1
+ #include <iostream>
2
+ #include <opencv2/imgcodecs.hpp>
3
+ #include <opencv2/highgui.hpp>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+ #include "inpaint.h"
8
+
9
+ int main() {
10
+ auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);
11
+
12
+ auto mask = cv::Mat(source.size(), CV_8UC1);
13
+ mask = cv::Scalar::all(0);
14
+ for (int i = 0; i < source.size().height; ++i) {
15
+ for (int j = 0; j < source.size().width; ++j) {
16
+ auto source_ptr = source.ptr<unsigned char>(i, j);
17
+ if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
18
+ mask.at<unsigned char>(i, j) = 1;
19
+ }
20
+ }
21
+ }
22
+
23
+ auto metric = PatchSSDDistanceMetric(3);
24
+ auto result = Inpainting(source, mask, &metric).run(true, true);
25
+ // cv::imwrite("./images/forest_recovered.bmp", result);
26
+ // cv::imshow("Result", result);
27
+ // cv::waitKey();
28
+
29
+ return 0;
30
+ }
31
+
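A rough Python counterpart of the explicit white-pixel mask construction in the C++ example above, written against the `patch_match.inpaint` wrapper added in this commit; the mask expression mirrors what the wrapper builds internally when no mask is passed, and the image path assumes the example asset under `examples/images/`:

```python
# Sketch: build the white-pixel hole mask explicitly, then inpaint via the Python wrapper.
import numpy as np
from PIL import Image

import patch_match

source = np.array(Image.open('./images/forest_pruned.bmp'))
# 1-channel uint8 mask: 1 where all three channels are 255 (the pruned region).
mask = (source == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
result = patch_match.inpaint(source, mask, patch_size=3)
Image.fromarray(result).save('./images/forest_recovered.bmp')
```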
PyPatchMatch/examples/cpp_example_run.sh ADDED
@@ -0,0 +1,18 @@
1
+ #! /bin/bash
2
+ #
3
+ # cpp_example_run.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <maojiayuan@gmail.com>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ set -x
10
+
11
+ CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
12
+ LDFLAGS="$(pkg-config --libs opencv)"
13
+ g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe
14
+
15
+ export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH # For macOS
16
+ export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH # For Linux
17
+ time ./cpp_example.exe
18
+
PyPatchMatch/examples/images/forest.bmp ADDED
PyPatchMatch/examples/images/forest_pruned.bmp ADDED
PyPatchMatch/examples/py_example.py ADDED
@@ -0,0 +1,21 @@
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : maojiayuan@gmail.com
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ from PIL import Image
11
+
12
+ import sys
13
+ sys.path.insert(0, '../')
14
+ import patch_match
15
+
16
+
17
+ if __name__ == '__main__':
18
+ source = Image.open('./images/forest_pruned.bmp')
19
+ result = patch_match.inpaint(source, patch_size=3)
20
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
21
+
PyPatchMatch/examples/py_example_global_mask.py ADDED
@@ -0,0 +1,27 @@
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : maojiayuan@gmail.com
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ import sys
14
+ sys.path.insert(0, '../')
15
+ import patch_match
16
+
17
+
18
+ if __name__ == '__main__':
19
+ patch_match.set_verbose(True)
20
+ source = Image.open('./images/forest_pruned.bmp')
21
+ source = np.array(source)
22
+ source[:100, :100] = 255
23
+ global_mask = np.zeros_like(source[..., 0])
24
+ global_mask[:100, :100] = 1
25
+ result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
26
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
27
+
PyPatchMatch/patch_match.py ADDED
@@ -0,0 +1,263 @@
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : patch_match.py
4
+ # Author : Jiayuan Mao
5
+ # Email : maojiayuan@gmail.com
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import ctypes
11
+ import os.path as osp
12
+ from typing import Optional, Union
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+
18
+ import os
19
+ if os.name != "nt":
20
+ # On non-Windows platforms, build the shared library by invoking make.
21
+ import subprocess
22
+ print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
23
+ # subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
24
+ subprocess.check_call("make clean && make", cwd=osp.dirname(__file__), shell=True)
25
+
26
+
27
+ __all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
28
+
29
+
30
+ class CShapeT(ctypes.Structure):
31
+ _fields_ = [
32
+ ('width', ctypes.c_int),
33
+ ('height', ctypes.c_int),
34
+ ('channels', ctypes.c_int),
35
+ ]
36
+
37
+
38
+ class CMatT(ctypes.Structure):
39
+ _fields_ = [
40
+ ('data_ptr', ctypes.c_void_p),
41
+ ('shape', CShapeT),
42
+ ('dtype', ctypes.c_int)
43
+ ]
44
+
45
+ import tempfile
46
+ from urllib.request import urlopen, Request
47
+ import shutil
48
+ from pathlib import Path
49
+ from tqdm import tqdm
50
+
51
+ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
52
+ r"""Download object at the given URL to a local path.
53
+
54
+ Args:
55
+ url (string): URL of the object to download
56
+ dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
57
+ hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
58
+ Default: None
59
+ progress (bool, optional): whether or not to display a progress bar to stderr
60
+ Default: True
61
+ https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url
62
+ """
63
+ file_size = None
64
+ req = Request(url)
65
+ u = urlopen(req)
66
+ meta = u.info()
67
+ if hasattr(meta, 'getheaders'):
68
+ content_length = meta.getheaders("Content-Length")
69
+ else:
70
+ content_length = meta.get_all("Content-Length")
71
+ if content_length is not None and len(content_length) > 0:
72
+ file_size = int(content_length[0])
73
+
74
+ # We deliberately save it in a temp file and move it after
75
+ # download is complete. This prevents a local working checkpoint
76
+ # being overridden by a broken download.
77
+ dst = os.path.expanduser(dst)
78
+ dst_dir = os.path.dirname(dst)
79
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
80
+
81
+ try:
82
+ with tqdm(total=file_size, disable=not progress,
83
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
84
+ while True:
85
+ buffer = u.read(8192)
86
+ if len(buffer) == 0:
87
+ break
88
+ f.write(buffer)
89
+ pbar.update(len(buffer))
90
+
91
+ f.close()
92
+ shutil.move(f.name, dst)
93
+ finally:
94
+ f.close()
95
+ if os.path.exists(f.name):
96
+ os.remove(f.name)
97
+
98
+ if os.name!="nt":
99
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.so'))
100
+ else:
101
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
102
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll",dst=osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
103
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
104
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll",dst=osp.join(osp.dirname(__file__), 'opencv_world460.dll'))
105
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
106
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll and put it into the PyPatchMatch folder")
107
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
108
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll and put it into the PyPatchMatch folder")
109
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
110
+
111
+ PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
112
+ PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
113
+ PMLIB.PM_free_pymat.argtypes = [CMatT]
114
+ PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
115
+ PMLIB.PM_inpaint.restype = CMatT
116
+ PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
117
+ PMLIB.PM_inpaint_regularity.restype = CMatT
118
+ PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
119
+ PMLIB.PM_inpaint2.restype = CMatT
120
+ PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
121
+ PMLIB.PM_inpaint2_regularity.restype = CMatT
122
+
123
+
124
+ def set_random_seed(seed: int):
125
+ PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
126
+
127
+
128
+ def set_verbose(verbose: bool):
129
+ PMLIB.PM_set_verbose(ctypes.c_int(verbose))
130
+
131
+
132
+ def inpaint(
133
+ image: Union[np.ndarray, Image.Image],
134
+ mask: Optional[Union[np.ndarray, Image.Image]] = None,
135
+ *,
136
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
137
+ patch_size: int = 15
138
+ ) -> np.ndarray:
139
+ """
140
+ PatchMatch based inpainting proposed in:
141
+
142
+ PatchMatch: A Randomized Correspondence Algorithm for Structural Image Editing
143
+ C. Barnes, E. Shechtman, A. Finkelstein, and Dan B. Goldman
144
+ SIGGRAPH 2009
145
+
146
+ Args:
147
+ image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
148
+ mask (Union[np.ndarray, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
149
+ If not provided (None), all purely white pixels (255, 255, 255) are treated as holes.
150
+ global_mask (Union[np.ndarray, Image.Image], optional): the target mask of the output image.
151
+ patch_size (int): the patch size for the inpainting algorithm.
152
+
153
+ Return:
154
+ result (np.ndarray): the repaired image, of the same size as the input image.
155
+ """
156
+
157
+ if isinstance(image, Image.Image):
158
+ image = np.array(image)
159
+ image = np.ascontiguousarray(image)
160
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
161
+
162
+ if mask is None:
163
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
164
+ mask = np.ascontiguousarray(mask)
165
+ else:
166
+ mask = _canonize_mask_array(mask)
167
+
168
+ if global_mask is None:
169
+ ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
170
+ else:
171
+ global_mask = _canonize_mask_array(global_mask)
172
+ ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
173
+
174
+ ret_npmat = pymat_to_np(ret_pymat)
175
+ PMLIB.PM_free_pymat(ret_pymat)
176
+
177
+ return ret_npmat
178
+
179
+
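+ # Example usage (a minimal sketch; assumes this module is importable as
+ # `patch_match` and the input path is a placeholder):
+ #
+ #     from PIL import Image
+ #     import patch_match
+ #
+ #     source = Image.open("image_with_white_holes.png").convert("RGB")
+ #     result = patch_match.inpaint(source, patch_size=15)   # pure-white pixels are filled
+ #     Image.fromarray(result).save("inpainted.png")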
180
+ def inpaint_regularity(
181
+ image: Union[np.ndarray, Image.Image],
182
+ mask: Optional[Union[np.ndarray, Image.Image]],
183
+ ijmap: np.ndarray,
184
+ *,
185
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
186
+ patch_size: int = 15, guide_weight: float = 0.25
187
+ ) -> np.ndarray:
188
+ if isinstance(image, Image.Image):
189
+ image = np.array(image)
190
+ image = np.ascontiguousarray(image)
191
+
192
+ assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
193
+ ijmap = np.ascontiguousarray(ijmap)
194
+
195
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
196
+ if mask is None:
197
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
198
+ mask = np.ascontiguousarray(mask)
199
+ else:
200
+ mask = _canonize_mask_array(mask)
201
+
202
+
203
+ if global_mask is None:
204
+ ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
205
+ else:
206
+ global_mask = _canonize_mask_array(global_mask)
207
+ ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
208
+
209
+ ret_npmat = pymat_to_np(ret_pymat)
210
+ PMLIB.PM_free_pymat(ret_pymat)
211
+
212
+ return ret_npmat
213
+
214
+
215
+ def _canonize_mask_array(mask):
216
+ if isinstance(mask, Image.Image):
217
+ mask = np.array(mask)
218
+ if mask.ndim == 2 and mask.dtype == 'uint8':
219
+ mask = mask[..., np.newaxis]
220
+ assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
221
+ return np.ascontiguousarray(mask)
222
+
223
+
224
+ dtype_pymat_to_ctypes = [
225
+ ctypes.c_uint8,
226
+ ctypes.c_int8,
227
+ ctypes.c_uint16,
228
+ ctypes.c_int16,
229
+ ctypes.c_int32,
230
+ ctypes.c_float,
231
+ ctypes.c_double,
232
+ ]
233
+
234
+
235
+ dtype_np_to_pymat = {
236
+ 'uint8': 0,
237
+ 'int8': 1,
238
+ 'uint16': 2,
239
+ 'int16': 3,
240
+ 'int32': 4,
241
+ 'float32': 5,
242
+ 'float64': 6,
243
+ }
244
+
245
+
246
+ def np_to_pymat(npmat):
247
+ assert npmat.ndim == 3
248
+ return CMatT(
249
+ ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
250
+ CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
251
+ dtype_np_to_pymat[str(npmat.dtype)]
252
+ )
253
+
254
+
255
+ def pymat_to_np(pymat):
256
+ npmat = np.ctypeslib.as_array(
257
+ ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
258
+ (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
259
+ )
260
+ ret = np.empty(npmat.shape, npmat.dtype)
261
+ ret[:] = npmat
262
+ return ret
263
+
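+ # Round-trip sketch: np_to_pymat only wraps the ndarray's buffer for the C side,
+ # while pymat_to_np copies a returned CMatT into a fresh ndarray, so for any
+ # contiguous 3-D array of a supported dtype the pair is lossless:
+ #
+ #     arr = np.zeros((4, 5, 3), dtype='float32')
+ #     assert (pymat_to_np(np_to_pymat(arr)) == arr).all()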
PyPatchMatch/travis.sh ADDED
@@ -0,0 +1,9 @@
1
+ #! /bin/bash
2
+ #
3
+ # travis.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <maojiayuan@gmail.com>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ make clean && make
README.md ADDED
@@ -0,0 +1,13 @@
1
+ ---
2
+ title: Stablediffusion Infinity
3
+ emoji: ♾️
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: true
9
+ license: apache-2.0
10
+ sdk_version: 4.31.3
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,1043 @@
1
+ import os
2
+ os.system("cp opencv.pc /usr/local/lib/pkgconfig/")
3
+ os.system("pip install 'numpy<2'")
4
+ os.system("pip uninstall triton -y")
5
+ import spaces
6
+ import io
7
+ import base64
8
+ import sys
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image, ImageOps
12
+ import gradio as gr
13
+ import skimage
14
+ import skimage.measure
15
+ import yaml
16
+ import json
17
+ from enum import Enum
18
+ from utils import *
19
+ from collections import Counter
20
+ import argparse
21
+ from stablepy import Model_Diffusers, scheduler_names, ALL_PROMPT_WEIGHT_OPTIONS, SCHEDULE_TYPE_OPTIONS
22
+ from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
23
+ from datetime import datetime
24
+
25
+ parser = argparse.ArgumentParser(description="stablediffusion-infinity")
26
+ parser.add_argument("--port", type=int, help="listen port", dest="server_port")
27
+ parser.add_argument("--host", type=str, help="host", dest="server_name")
28
+ parser.add_argument("--share", action="store_true", help="share this app?")
29
+ parser.add_argument("--debug", action="store_true", help="debug mode")
30
+ parser.add_argument("--fp32", action="store_true", help="using full precision")
31
+ parser.add_argument("--lowvram", action="store_true", help="using lowvram mode")
32
+ parser.add_argument("--encrypt", action="store_true", help="using https?")
33
+ parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
34
+ parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
35
+ parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
36
+ parser.add_argument(
37
+ "--auth", nargs=2, metavar=("username", "password"), help="use username password"
38
+ )
39
+ parser.add_argument(
40
+ "--remote_model",
41
+ type=str,
42
+ help="use a model (e.g. dreambooth fined) from huggingface hub",
43
+ default="",
44
+ )
45
+ parser.add_argument(
46
+ "--local_model", type=str, help="use a model stored on your PC", default=""
47
+ )
48
+ parser.add_argument(
49
+ "--stablepy_model",
50
+ type=str,
51
+ help="Model source: can be a Hugging Face Diffusers repo or a local .safetensors file path",
52
+ default="SG161222/RealVisXL_V5.0_Lightning"
53
+ )
54
+
55
+ try:
56
+ abspath = os.path.abspath(__file__)
57
+ dirname = os.path.dirname(abspath)
58
+ os.chdir(dirname)
59
+ except:
60
+ pass
61
+
62
+ Interrogator = DummyInterrogator
63
+
64
+ START_DEVICE_STABLEPY = "cpu" if os.getenv("SPACES_ZERO_GPU") else None
65
+ DEBUG_MODE = False
66
+ AUTO_SETUP = True
67
+
68
+
69
+ with open("config.yaml", "r") as yaml_in:
70
+ yaml_object = yaml.safe_load(yaml_in)
71
+ config_json = json.dumps(yaml_object)
72
+
73
+
74
+ def parse_color(color):
75
+ """
76
+ Convert color to Pillow-friendly (R, G, B, A) tuple in 0–255 range.
77
+ Supports:
78
+ - tuple/list of floats or ints
79
+ - 'rgba(r, g, b, a)' string
80
+ - 'rgb(r, g, b)' string
81
+ - hex colors: '#RRGGBB' or '#RRGGBBAA'
82
+ """
83
+ if isinstance(color, (tuple, list)):
84
+ parts = [float(c) for c in color]
85
+
86
+ elif isinstance(color, str):
87
+ c = color.strip().lower()
88
+
89
+ # Hex color
90
+ if c.startswith("#"):
91
+ c = c.lstrip("#")
92
+ if len(c) == 6: # RRGGBB
93
+ r, g, b = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16)
94
+ return (r, g, b, 255)
95
+ elif len(c) == 8: # RRGGBBAA
96
+ r, g, b, a = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16), int(c[6:8], 16)
97
+ return (r, g, b, a)
98
+ else:
99
+ raise ValueError(f"Invalid hex color: {color}")
100
+
101
+ # RGB / RGBA string
102
+ c = c.replace("rgba", "").replace("rgb", "").replace("(", "").replace(")", "")
103
+ parts = [float(x.strip()) for x in c.split(",")]
104
+
105
+ else:
106
+ raise ValueError(f"Unsupported color format: {color}")
107
+
108
+ # Ensure alpha
109
+ if len(parts) == 3:
110
+ parts.append(1.0) # default alpha = 1.0
111
+
112
+ return (
113
+ int(round(parts[0])),
114
+ int(round(parts[1])),
115
+ int(round(parts[2])),
116
+ int(round(parts[3] * 255 if parts[3] <= 1 else parts[3]))
117
+ )
118
+
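+ # Worked examples for the conversions above:
+ #     parse_color("#ff000080")             -> (255, 0, 0, 128)
+ #     parse_color("rgba(10, 20, 30, 0.5)") -> (10, 20, 30, 128)
+ #     parse_color((255, 128, 0))           -> (255, 128, 0, 255)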
119
+
120
+ def is_not_dark(color, threshold=30):
121
+ return not all(c <= threshold for c in color)
122
+
123
+
124
+ def get_dominant_color_exclude_dark(image_pil):
125
+ img_small = image_pil.convert("RGB").resize((50, 50))
126
+ pixels = list(img_small.getdata())
127
+ filtered_pixels = [p for p in pixels if is_not_dark(p)]
128
+ if not filtered_pixels:
129
+ filtered_pixels = pixels
130
+ most_common = Counter(filtered_pixels).most_common(1)[0][0]
131
+ return most_common
132
+
133
+
134
+ def replace_color_in_mask(image_pil, mask_pil, target_color=None):
135
+ img = np.array(image_pil.convert("RGB"))
136
+ mask = np.array(mask_pil.convert("L"))
137
+
138
+ mask_white = mask == 255
139
+ mask_nonwhite = ~mask_white
140
+
141
+ if target_color in [None, ""]:
142
+ nonwhite_pixels = img[mask_nonwhite]
143
+ nonwhite_img = Image.fromarray(nonwhite_pixels.reshape((-1, 1, 3))) # , mode="RGB"
144
+ target_color = get_dominant_color_exclude_dark(nonwhite_img)
145
+ else:
146
+ parsed = parse_color(target_color) # (R, G, B, A)
147
+ target_color = parsed[:3] # ignore alpha for replacement
148
+
149
+ img[mask_white] = target_color
150
+ return Image.fromarray(img)
151
+
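+ # Usage sketch: fill the white (masked) region with an explicit color, or with the
+ # image's dominant non-dark color when no color is given (as the fill modes below do):
+ #     filled = replace_color_in_mask(init_image, mask_image, "#FFFFFF")
+ #     filled = replace_color_in_mask(init_image, mask_image, None)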
152
+
153
+ def expand_white_around_black(image: Image.Image, expand_ratio=0.1) -> Image.Image:
154
+ """
155
+ Expand the white areas around the black region by a percentage of the black region size.
156
+
157
+ Args:
158
+ image: PIL grayscale image (mode "L").
159
+ expand_ratio: fraction of the black region's size by which to expand the white sides (default 0.1 = 10%).
160
+
161
+ Returns:
162
+ PIL Image with white expanded around black.
163
+ """
164
+ arr = np.array(image)
165
+
166
+ black_mask = arr == 0
167
+
168
+ height, width = arr.shape
169
+ coords = np.argwhere(black_mask)
170
+
171
+ if coords.size == 0:
172
+ # No black pixels, return original image
173
+ return image.copy()
174
+
175
+ y_min, x_min = coords.min(axis=0)
176
+ y_max, x_max = coords.max(axis=0)
177
+
178
+ expand_x = int((x_max - x_min + 1) * expand_ratio)
179
+ expand_y = int((y_max - y_min + 1) * expand_ratio)
180
+
181
+ # Shrink black bounding box to expand white sides
182
+ if y_min > 0 and np.all(arr[:y_min, :] == 255):
183
+ y_min = min(height - 1, y_min + expand_y)
184
+
185
+ if y_max < height - 1 and np.all(arr[y_max + 1:, :] == 255):
186
+ y_max = max(0, y_max - expand_y)
187
+
188
+ if x_min > 0 and np.all(arr[:, :x_min] == 255):
189
+ x_min = min(width - 1, x_min + expand_x)
190
+
191
+ if x_max < width - 1 and np.all(arr[:, x_max + 1:] == 255):
192
+ x_max = max(0, x_max - expand_x)
193
+
194
+ # Create new white canvas
195
+ expanded_arr = np.full_like(arr, 255)
196
+
197
+ # Paint black inside adjusted bounding box
198
+ expanded_arr[y_min:y_max+1, x_min:x_max+1] = 0
199
+
200
+ return Image.fromarray(expanded_arr)
201
+
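+ # Usage sketch: widen the white (outpaint) border of a mask by 20% of the black
+ # (keep) region's size before handing it to the pipeline:
+ #     wider_mask = expand_white_around_black(mask_image.convert("L"), expand_ratio=0.2)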
202
+
203
+ def load_html():
204
+ body, canvaspy = "", ""
205
+ with open("index.html", encoding="utf8") as f:
206
+ body = f.read()
207
+ with open("canvas.py", encoding="utf8") as f:
208
+ canvaspy = f.read()
209
+ body = body.replace("- paths:\n", "")
210
+ body = body.replace(" - ./canvas.py\n", "")
211
+ body = body.replace("from canvas import InfCanvas", canvaspy)
212
+ return body
213
+
214
+
215
+ def test(x):
216
+ x = load_html()
217
+ return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
218
+ display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
219
+ allow-scripts allow-same-origin allow-popups
220
+ allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
221
+ allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""
222
+
223
+
224
+ try:
225
+ SAMPLING_MODE = Image.Resampling.LANCZOS
226
+ except Exception as e:
227
+ SAMPLING_MODE = Image.LANCZOS
228
+
229
+ try:
230
+ contain_func = ImageOps.contain
231
+ except Exception as e:
232
+ def contain_func(image, size, method=SAMPLING_MODE):
233
+ # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
234
+ im_ratio = image.width / image.height
235
+ dest_ratio = size[0] / size[1]
236
+ if im_ratio != dest_ratio:
237
+ if im_ratio > dest_ratio:
238
+ new_height = int(image.height / image.width * size[0])
239
+ if new_height != size[1]:
240
+ size = (size[0], new_height)
241
+ else:
242
+ new_width = int(image.width / image.height * size[1])
243
+ if new_width != size[0]:
244
+ size = (new_width, size[1])
245
+ return image.resize(size, resample=method)
246
+
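+ # e.g. contain_func(Image.new("RGB", (1000, 500)), (512, 512)) returns a 512x256 image,
+ # matching ImageOps.contain on Pillow versions that provide it.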
247
+
248
+ if __name__ == "__main__":
249
+ args = parser.parse_args()
250
+ else:
251
+ args = parser.parse_args(["--debug"])
252
+ # args = parser.parse_args(["--debug"])
253
+ if args.auth is not None:
254
+ args.auth = tuple(args.auth)
255
+
256
+ model = {}
257
+
258
+
259
+ def get_token():
260
+ token = ""
261
+ if os.path.exists(".token"):
262
+ with open(".token", "r") as f:
263
+ token = f.read()
264
+ token = os.environ.get("hftoken", token)
265
+ return token
266
+
267
+
268
+ def save_token(token):
269
+ with open(".token", "w") as f:
270
+ f.write(token)
271
+
272
+
273
+ def my_resize(width, height):
274
+ if width >= 512 and height >= 512:
275
+ return width, height
276
+ if width == height:
277
+ return 512, 512
278
+ smaller = min(width, height)
279
+ larger = max(width, height)
280
+ if larger >= 608:
281
+ return width, height
282
+ factor = 1
283
+ if smaller < 290:
284
+ factor = 2
285
+ elif smaller < 330:
286
+ factor = 1.75
287
+ elif smaller < 384:
288
+ factor = 1.375
289
+ elif smaller < 400:
290
+ factor = 1.25
291
+ elif smaller < 450:
292
+ factor = 1.125
293
+ return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
294
+
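+ # Worked example: my_resize(300, 400) picks factor 1.75 (smaller side < 330) and
+ # returns (int(1.75*300)//8*8, int(1.75*400)//8*8) == (520, 696),
+ # i.e. both sides upscaled and snapped down to multiples of 8.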
295
+
296
+ def load_learned_embed_in_clip(
297
+ learned_embeds_path, text_encoder, tokenizer, token=None
298
+ ):
299
+ # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
300
+ loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")
301
+
302
+ # separate token and the embeds
303
+ trained_token = list(loaded_learned_embeds.keys())[0]
304
+ embeds = loaded_learned_embeds[trained_token]
305
+
306
+ # cast to dtype of text_encoder
307
+ dtype = text_encoder.get_input_embeddings().weight.dtype
308
+ embeds.to(dtype)
309
+
310
+ # add the token in tokenizer
311
+ token = token if token is not None else trained_token
312
+ num_added_tokens = tokenizer.add_tokens(token)
313
+ if num_added_tokens == 0:
314
+ raise ValueError(
315
+ f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
316
+ )
317
+
318
+ # resize the token embeddings
319
+ text_encoder.resize_token_embeddings(len(tokenizer))
320
+
321
+ # get the id for the token and assign the embeds
322
+ token_id = tokenizer.convert_tokens_to_ids(token)
323
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
324
+
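+ # Usage sketch (assumes `pipe` is a loaded diffusers StableDiffusionPipeline; the
+ # embedding file name and placeholder token are just examples, not files in this repo):
+ #     load_learned_embed_in_clip(
+ #         "learned_embeds.bin", pipe.text_encoder, pipe.tokenizer, token="<my-concept>"
+ #     )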
325
+
326
+ MODEL_NAME = args.stablepy_model
327
+ print(f"Loading model {MODEL_NAME}. This may take some time if it is a Diffusers-format model.")
328
+
329
+ LOAD_PIPE_ARGS = dict(
330
+ vae_model=None,
331
+ retain_task_model_in_cache=True,
332
+ controlnet_model="Automatic",
333
+ type_model_precision=torch.float16,
334
+ )
335
+
336
+ disable_progress_bars()
337
+ base_model = Model_Diffusers(
338
+ base_model_id=MODEL_NAME,
339
+ task_name="repaint",
340
+ device=START_DEVICE_STABLEPY,
341
+ **LOAD_PIPE_ARGS,
342
+ )
343
+ enable_progress_bars()
344
+ if START_DEVICE_STABLEPY:
345
+ base_model.device = torch.device("cuda:0")
346
+ base_model.pipe.to(torch.device("cuda:0"), torch.float16)
347
+
348
+ # maybe a second base_model for anime images
349
+
350
+ class StableDiffusion:
351
+ def __init__(
352
+ self,
353
+ token: str = "",
354
+ model_name: str = "stable-diffusion-v1-5/stable-diffusion-v1-5",
355
+ model_path: str = None,
356
+ inpainting_model: bool = False,
357
+ **kwargs,
358
+ ):
359
+ if DEBUG_MODE:
360
+ print("sd task selection")
361
+
362
+ def run(
363
+ self,
364
+ image_pil,
365
+ prompt="",
366
+ negative_prompt="",
367
+ guidance_scale=7.5,
368
+ resize_check=True,
369
+ enable_safety=True,
370
+ fill_mode="patchmatch",
371
+ strength=0.75,
372
+ step=50,
373
+ enable_img2img=False,
374
+ use_seed=False,
375
+ seed_val=-1,
376
+ generate_num=1,
377
+ scheduler="",
378
+ scheduler_eta=0.0,
379
+ controlnet_union=True,
380
+ expand_mask_percent=0.1,
381
+ color_selector_=None,
382
+ scheduler_type="Automatic",
383
+ prompt_weight="Classic",
384
+ image_resolution=1024,
385
+ img_height=1024,
386
+ img_width=1024,
387
+ loraA=None,
388
+ loraAscale=1.,
389
+ **kwargs,
390
+ ):
391
+ global base_model
392
+
393
+ width, height = image_pil.size
394
+
395
+ if DEBUG_MODE:
396
+ image_pil.save(
397
+ f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
398
+ )
399
+ print(image_pil.size)
400
+
401
+ sel_buffer = np.array(image_pil)
402
+ img = sel_buffer[:, :, 0:3]
403
+ mask = sel_buffer[:, :, -1]
404
+ nmask = 255 - mask
405
+ process_width = width
406
+ process_height = height
407
+
408
+ extra_kwargs = {
409
+ "num_steps": step,
410
+ "guidance_scale": guidance_scale,
411
+ "sampler": scheduler,
412
+ "num_images": generate_num,
413
+ "negative_prompt": negative_prompt,
414
+ "seed": (seed_val if use_seed else -1),
415
+ "strength": strength,
416
+ "schedule_type": scheduler_type,
417
+ "syntax_weights": prompt_weight,
418
+ "lora_A": (loraA if loraA != "None" else None),
419
+ "lora_scale_A": loraAscale,
420
+ }
421
+
422
+ if resize_check:
423
+ process_width, process_height = my_resize(width, height)
424
+ extra_kwargs["image_resolution"] = 1024
425
+ else:
426
+ extra_kwargs["image_resolution"] = image_resolution
427
+
428
+ if nmask.sum() < 1 and enable_img2img:
429
+ # Img2img
430
+ init_image = Image.fromarray(img)
431
+ base_model.load_pipe(
432
+ base_model_id=MODEL_NAME,
433
+ task_name="img2img",
434
+ **LOAD_PIPE_ARGS,
435
+ )
436
+ images = base_model(
437
+ prompt=prompt,
438
+ image=init_image.resize(
439
+ (process_width, process_height), resample=SAMPLING_MODE
440
+ ),
441
+ strength=strength,
442
+ **extra_kwargs,
443
+ )[0]
444
+ elif mask.sum() > 0:
445
+ if fill_mode == "g_diffuser" or "_color" in fill_mode:
446
+ mask = 255 - mask
447
+ mask = mask[:, :, np.newaxis].repeat(3, axis=2)
448
+ if "_color" not in fill_mode:
449
+ img, mask = functbl[fill_mode](img, mask)
450
+ # extra_kwargs["out_mask"] = Image.fromarray(mask)
451
+ # inpaint_func = unified
452
+ else:
453
+ img, mask = functbl[fill_mode](img, mask)
454
+ mask = 255 - mask
455
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
456
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
457
+ # inpaint_func = inpaint
458
+ init_image = Image.fromarray(img)
459
+ mask_image = Image.fromarray(mask)
460
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
461
+ input_image = init_image.resize(
462
+ (process_width, process_height), resample=SAMPLING_MODE
463
+ )
464
+
465
+ if DEBUG_MODE:
466
+ init_image.save(
467
+ f"init_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
468
+ )
469
+ print(init_image.size)
470
+ mask_image.save(
471
+ f"mask_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
472
+ )
473
+ print(mask_image.size)
474
+
475
+ if fill_mode == "pad_common_color":
476
+ init_image = replace_color_in_mask(init_image, mask_image, None)
477
+ elif fill_mode == "pad_selected_color":
478
+ init_image = replace_color_in_mask(init_image, mask_image, color_selector_)
479
+
480
+ if expand_mask_percent:
481
+ if mask_image.mode != "L":
482
+ if DEBUG_MODE:
483
+ print("convert to L")
484
+ mask_image = mask_image.convert("L")
485
+ mask_image = expand_white_around_black(mask_image, expand_ratio=expand_mask_percent)
486
+ mask_image.save(
487
+ f"mask_image_expanded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
488
+ )
489
+ if DEBUG_MODE:
490
+ print(mask_image.size)
491
+
492
+ if controlnet_union:
493
+ # Outpaint
494
+ base_model.load_pipe(
495
+ base_model_id=MODEL_NAME,
496
+ task_name="repaint",
497
+ **LOAD_PIPE_ARGS,
498
+ )
499
+ images = base_model(
500
+ prompt=prompt,
501
+ image=input_image,
502
+ img_width=process_width,
503
+ img_height=process_height,
504
+ image_mask=mask_image.resize((process_width, process_height)),
505
+ **extra_kwargs,
506
+ )[0]
507
+ else:
508
+ # Inpaint
509
+ base_model.load_pipe(
510
+ base_model_id=MODEL_NAME,
511
+ task_name="inpaint",
512
+ **LOAD_PIPE_ARGS,
513
+ )
514
+ images = base_model(
515
+ prompt=prompt,
516
+ image=input_image,
517
+ image_mask=mask_image.resize((process_width, process_height)),
518
+ **extra_kwargs,
519
+ )[0]
520
+ else:
521
+ # txt2img
522
+ base_model.load_pipe(
523
+ base_model_id=MODEL_NAME,
524
+ task_name="txt2img",
525
+ **LOAD_PIPE_ARGS,
526
+ )
527
+
528
+ images = base_model(
529
+ prompt=prompt,
530
+ img_height=img_height,
531
+ img_width=img_width,
532
+ **extra_kwargs,
533
+ )[0]
534
+
535
+ if DEBUG_MODE:
536
+ print(f"TASK NAME {base_model.task_name}")
537
+ return images
538
+
539
+
540
+ @spaces.GPU(duration=15)
541
+ def generate_images(
542
+ cur_model,
543
+ pil,
544
+ prompt_text,
545
+ negative_prompt_text,
546
+ guidance,
547
+ strength,
548
+ step,
549
+ resize_check,
550
+ fill_mode,
551
+ enable_safety,
552
+ use_seed,
553
+ seed_val,
554
+ generate_num,
555
+ scheduler,
556
+ scheduler_eta,
557
+ enable_img2img,
558
+ width,
559
+ height,
560
+ controlnet_union,
561
+ expand_mask,
562
+ color_selector_,
563
+ scheduler_type,
564
+ prompt_weight,
565
+ image_resolution,
566
+ img_height,
567
+ img_width,
568
+ loraA,
569
+ loraAscale,
570
+ ):
571
+
572
+ return cur_model.run(
573
+ image_pil=pil,
574
+ prompt=prompt_text,
575
+ negative_prompt=negative_prompt_text,
576
+ guidance_scale=guidance,
577
+ strength=strength,
578
+ step=step,
579
+ resize_check=resize_check,
580
+ fill_mode=fill_mode,
581
+ enable_safety=enable_safety,
582
+ use_seed=use_seed,
583
+ seed_val=seed_val,
584
+ generate_num=generate_num,
585
+ scheduler=scheduler,
586
+ scheduler_eta=scheduler_eta,
587
+ enable_img2img=enable_img2img,
588
+ width=width,
589
+ height=height,
590
+ controlnet_union=controlnet_union,
591
+ expand_mask_percent=expand_mask,
592
+ color_selector_=color_selector_,
593
+ scheduler_type=scheduler_type,
594
+ prompt_weight=prompt_weight,
595
+ image_resolution=image_resolution,
596
+ img_height=img_height,
597
+ img_width=img_width,
598
+ loraA=loraA,
599
+ loraAscale=loraAscale,
600
+ )
601
+
602
+
603
+ def run_outpaint(
604
+ sel_buffer_str,
605
+ prompt_text,
606
+ negative_prompt_text,
607
+ strength,
608
+ guidance,
609
+ step,
610
+ resize_check,
611
+ fill_mode,
612
+ enable_safety,
613
+ use_correction,
614
+ enable_img2img,
615
+ use_seed,
616
+ seed_val,
617
+ generate_num,
618
+ scheduler,
619
+ scheduler_eta,
620
+ controlnet_union,
621
+ expand_mask,
622
+ color_selector_,
623
+ scheduler_type,
624
+ prompt_weight,
625
+ image_resolution,
626
+ img_height,
627
+ img_width,
628
+ loraA,
629
+ loraAscale,
630
+ interrogate_mode,
631
+ state,
632
+ ):
633
+
634
+ if DEBUG_MODE:
635
+ print("start proceed")
636
+ data = base64.b64decode(str(sel_buffer_str))
637
+
638
+ pil = Image.open(io.BytesIO(data))
639
+ if interrogate_mode:
640
+ if "interrogator" not in model:
641
+ model["interrogator"] = Interrogator()
642
+ interrogator = model["interrogator"]
643
+ img = np.array(pil)[:, :, 0:3]
644
+ mask = np.array(pil)[:, :, -1]
645
+ x, y = np.nonzero(mask)
646
+ if len(x) > 0:
647
+ x0, x1 = x.min(), x.max() + 1
648
+ y0, y1 = y.min(), y.max() + 1
649
+ img = img[x0:x1, y0:y1, :]
650
+ pil = Image.fromarray(img)
651
+ interrogate_ret = interrogator.interrogate(pil)
652
+ return (
653
+ gr.update(value=",".join([sel_buffer_str]),),
654
+ gr.update(label="Prompt", value=interrogate_ret),
655
+ state,
656
+ )
657
+ width, height = pil.size
658
+ sel_buffer = np.array(pil)
659
+ cur_model = StableDiffusion()
660
+ if DEBUG_MODE:
661
+ print("start inference")
662
+
663
+ images = generate_images(
664
+ cur_model,
665
+ pil,
666
+ prompt_text,
667
+ negative_prompt_text,
668
+ guidance,
669
+ strength,
670
+ step,
671
+ resize_check,
672
+ fill_mode,
673
+ enable_safety,
674
+ use_seed,
675
+ seed_val,
676
+ generate_num,
677
+ scheduler,
678
+ scheduler_eta,
679
+ enable_img2img,
680
+ width,
681
+ height,
682
+ controlnet_union,
683
+ expand_mask,
684
+ color_selector_,
685
+ scheduler_type,
686
+ prompt_weight,
687
+ image_resolution,
688
+ img_height,
689
+ img_width,
690
+ loraA,
691
+ loraAscale,
692
+ )
693
+
694
+ if DEBUG_MODE:
695
+ print("return result")
696
+ base64_str_lst = []
697
+ if enable_img2img:
698
+ use_correction = "border_mode"
699
+ for image in images:
700
+ image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
701
+ resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
702
+ out = sel_buffer.copy()
703
+ out[:, :, 0:3] = np.array(resized_img)
704
+ out[:, :, -1] = 255
705
+ out_pil = Image.fromarray(out)
706
+ out_buffer = io.BytesIO()
707
+ out_pil.save(out_buffer, format="PNG")
708
+ out_buffer.seek(0)
709
+ base64_bytes = base64.b64encode(out_buffer.read())
710
+ base64_str = base64_bytes.decode("ascii")
711
+ base64_str_lst.append(base64_str)
712
+ return (
713
+ gr.update(label=str(state + 1), value=",".join(base64_str_lst),),
714
+ gr.update(label="Prompt"),
715
+ state + 1,
716
+ )
717
+
718
+
719
+ generate_images.zerogpu = True
720
+ run_outpaint.zerogpu = True
721
+
722
+
723
+ def load_js(name):
724
+ if name in ["export", "commit", "undo"]:
725
+ return f"""
726
+ function (x)
727
+ {{
728
+ let app=document.querySelector("gradio-app");
729
+ app=app.shadowRoot??app;
730
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
731
+ let button=frame.querySelector("#{name}");
732
+ button.click();
733
+ return x;
734
+ }}
735
+ """
736
+ ret = ""
737
+ with open(f"./js/{name}.js", "r") as f:
738
+ ret = f.read()
739
+ return ret
740
+
741
+
742
+ proceed_button_js = load_js("proceed")
743
+ setup_button_js = load_js("setup")
744
+
745
+
746
+ blocks = gr.Blocks(
747
+ title="StableDiffusion-Infinity",
748
+ css="""
749
+ .tabs {
750
+ margin-top: 0rem;
751
+ margin-bottom: 0rem;
752
+ }
753
+ #markdown {
754
+ min-height: 0rem;
755
+ }
756
+ """,
757
+ )
758
+ model_path_input_val = ""
759
+ with blocks as demo:
760
+ # title
761
+ title = gr.Markdown(
762
+ """
763
+ This is a modified demo of [stablediffusion-infinity](https://huggingface.co/spaces/lnyan/stablediffusion-infinity) with SDXL support.
764
+
765
+ **stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
766
+ """,
767
+ elem_id="markdown",
768
+ )
769
+ # frame
770
+ frame = gr.HTML(test(2), visible=True)
771
+ # setup
772
+ if not AUTO_SETUP:
773
+ model_choices_lst = [""]
774
+ if args.local_model:
775
+ model_path_input_val = args.local_model
776
+ # model_choices_lst.insert(0, "local_model")
777
+ elif args.remote_model:
778
+ model_path_input_val = args.remote_model
779
+ # model_choices_lst.insert(0, "remote_model")
780
+ with gr.Row(elem_id="setup_row"):
781
+ with gr.Column(scale=4, min_width=350):
782
+ token = gr.Textbox(
783
+ label="Huggingface token",
784
+ value=get_token(),
785
+ placeholder="Input your token here/Ignore this if using local model",
786
+ )
787
+ with gr.Column(scale=3, min_width=320):
788
+ model_selection = gr.Radio(
789
+ label="Choose a model type here",
790
+ choices=model_choices_lst,
791
+ value=model_choices_lst[0],
792
+ )
793
+ with gr.Column(scale=1, min_width=100):
794
+ canvas_width = gr.Number(
795
+ label="Canvas width",
796
+ value=1024,
797
+ precision=0,
798
+ elem_id="canvas_width",
799
+ )
800
+ with gr.Column(scale=1, min_width=100):
801
+ canvas_height = gr.Number(
802
+ label="Canvas height",
803
+ value=600,
804
+ precision=0,
805
+ elem_id="canvas_height",
806
+ )
807
+ with gr.Column(scale=1, min_width=100):
808
+ selection_size = gr.Number(
809
+ label="Selection box size",
810
+ value=256,
811
+ precision=0,
812
+ elem_id="selection_size",
813
+ )
814
+ model_path_input = gr.Textbox(
815
+ value=model_path_input_val,
816
+ label="Custom Model Path (You have to select a correct model type for your local model)",
817
+ placeholder="Ignore this if you are not using Docker",
818
+ elem_id="model_path_input",
819
+ )
820
+ setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
821
+ with gr.Row():
822
+ with gr.Column(scale=3, min_width=270):
823
+ init_mode = gr.Radio(
824
+ label="Padding fill method for image",
825
+ choices=[
826
+ "pad_common_color",
827
+ "pad_selected_color",
828
+ "g_diffuser",
829
+ "patchmatch",
830
+ "edge_pad",
831
+ "cv2_ns",
832
+ "cv2_telea",
833
+ "perlin",
834
+ "gaussian",
835
+ ],
836
+ value="edge_pad",
837
+ type="value",
838
+ )
839
+ postprocess_check = gr.Radio(
840
+ label="Lighting and color adjustment mode",
841
+ choices=["disabled", "mask_mode", "border_mode",],
842
+ value="disabled",
843
+ type="value",
844
+ )
845
+ expand_mask_gui = gr.Slider(.0, .5, value=0.1, step=0.01, label="Mask Expansion (%)", info="Change how far the mask reaches from the edges of the image. Only if pad_selected_color is selected. ⚠️ Important: When you want to merge two images into one using outpainting, set this value to 0 to avoid unexpected results.")
846
+ color_selector = gr.ColorPicker(value="#FFFFFF", label="Color for `pad_selected_color`", info="Choose the color used to fill the extended padding area. ")
847
+
848
+ with gr.Column(scale=3, min_width=270):
849
+ sd_prompt = gr.Textbox(
850
+ label="Prompt", placeholder="input your prompt here!", lines=4
851
+ )
852
+ sd_negative_prompt = gr.Textbox(
853
+ label="Negative Prompt",
854
+ placeholder="input your negative prompt here!",
855
+ lines=4,
856
+ )
857
+ with gr.Column(scale=2, min_width=150):
858
+ with gr.Group():
859
+ with gr.Row():
860
+ sd_strength = gr.Slider(
861
+ label="Strength",
862
+ minimum=0.0,
863
+ maximum=1.0,
864
+ value=1.0,
865
+ step=0.01,
866
+ )
867
+ with gr.Row():
868
+ sd_scheduler = gr.Dropdown(
869
+ scheduler_names,
870
+ value="TCD",
871
+ label="Sampler",
872
+ )
873
+ sd_scheduler_type = gr.Dropdown(
874
+ SCHEDULE_TYPE_OPTIONS,
875
+ value=SCHEDULE_TYPE_OPTIONS[0],
876
+ label="Schedule type",
877
+ )
878
+ sd_scheduler_eta = gr.Number(label="Eta", value=0.0, visible=False)
879
+ sd_controlnet_union = gr.Checkbox(label="Use ControlNetUnionProMax", value=True, visible=True)
880
+ sd_image_resolution = gr.Slider(512, 4096, value=1024, step=64, label="Image resolution", info="Size of the processing image")
881
+ sd_img_height = gr.Slider(512, 4096, value=1024, step=64, label="Height for txt2img", info="Used if no image is in the selected canvas area.", visible=False)
882
+ sd_img_width = gr.Slider(512, 4096, value=1024, step=64, label="Width for txt2img", info="Used if no image is in the selected canvas area.", visible=False)
883
+
884
+ with gr.Column(scale=1, min_width=80):
885
+ sd_generate_num = gr.Number(label="Sample number", minimum=1, maximum=10, value=1)
886
+ sd_step = gr.Number(label="Step", value=12, minimum=2)
887
+ sd_guidance = gr.Number(label="Guidance scale", value=1.5, step=0.5)
888
+ sd_prompt_weight = gr.Dropdown(ALL_PROMPT_WEIGHT_OPTIONS, value=ALL_PROMPT_WEIGHT_OPTIONS[1], label="Prompt weight")
889
+ lora_dir = "./loras"
890
+ os.makedirs(lora_dir, exist_ok=True)
891
+ lora_files = [
892
+ f for f in os.listdir(lora_dir)
893
+ if os.path.isfile(os.path.join(lora_dir, f))
894
+ ]
895
+ lora_files.insert(0, "None")
896
+ sd_loraA = gr.Dropdown(choices=lora_files, value=lora_files[0], label="Lora", allow_custom_value=True)
897
+ sd_loraAscale = gr.Slider(-2., 2., value=1., step=0.01, label="Lora scale")
898
+
899
+ proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
900
+ xss_js = load_js("xss").replace("\n", " ")
901
+ xss_html = gr.HTML(
902
+ value=f"""
903
+ <img src='hts://not.exist' onerror='{xss_js}'>""",
904
+ visible=False,
905
+ )
906
+ xss_keyboard_js = load_js("keyboard").replace("\n", " ")
907
+ run_in_space = "true" if AUTO_SETUP else "false"
908
+ xss_html_setup_shortcut = gr.HTML(
909
+ value=f"""
910
+ <img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
911
+ visible=False,
912
+ )
913
+ # sd pipeline parameters
914
+ sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
915
+ sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
916
+ safety_check = gr.Checkbox(label="Safety checker", value=True, visible=False)
917
+ interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False)
918
+ upload_button = gr.Button(
919
+ "Before uploading the image you need to setup the canvas first", visible=False
920
+ )
921
+ sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
922
+ sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
923
+ model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
924
+ model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
925
+ upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
926
+ model_output_state = gr.State(value=0)
927
+ upload_output_state = gr.State(value=0)
928
+ cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
929
+ if not AUTO_SETUP:
930
+
931
+ def setup_func(token_val, width, height, size, model_choice, model_path):
932
+ try:
933
+ StableDiffusion()
934
+ except Exception as e:
935
+ print(e)
936
+ return {token: gr.update(value=str(e))}
937
+
938
+ init_val = "patchmatch"
939
+ return {
940
+ token: gr.update(visible=False),
941
+ canvas_width: gr.update(visible=False),
942
+ canvas_height: gr.update(visible=False),
943
+ selection_size: gr.update(visible=False),
944
+ setup_button: gr.update(visible=False),
945
+ frame: gr.update(visible=True),
946
+ upload_button: gr.update(value="Upload Image"),
947
+ model_selection: gr.update(visible=False),
948
+ model_path_input: gr.update(visible=False),
949
+ init_mode: gr.update(value=init_val),
950
+ }
951
+
952
+ setup_button.click(
953
+ fn=setup_func,
954
+ inputs=[
955
+ token,
956
+ canvas_width,
957
+ canvas_height,
958
+ selection_size,
959
+ model_selection,
960
+ model_path_input,
961
+ ],
962
+ outputs=[
963
+ token,
964
+ canvas_width,
965
+ canvas_height,
966
+ selection_size,
967
+ setup_button,
968
+ frame,
969
+ upload_button,
970
+ model_selection,
971
+ model_path_input,
972
+ init_mode,
973
+ ],
974
+ js=setup_button_js,
975
+ )
976
+
977
+ proceed_event = proceed_button.click(
978
+ fn=run_outpaint,
979
+ inputs=[
980
+ model_input,
981
+ sd_prompt,
982
+ sd_negative_prompt,
983
+ sd_strength,
984
+ sd_guidance,
985
+ sd_step,
986
+ sd_resize,
987
+ init_mode,
988
+ safety_check,
989
+ postprocess_check,
990
+ sd_img2img,
991
+ sd_use_seed,
992
+ sd_seed_val,
993
+ sd_generate_num,
994
+ sd_scheduler,
995
+ sd_scheduler_eta,
996
+ sd_controlnet_union,
997
+ expand_mask_gui,
998
+ color_selector,
999
+ sd_scheduler_type,
1000
+ sd_prompt_weight,
1001
+ sd_image_resolution,
1002
+ sd_img_height,
1003
+ sd_img_width,
1004
+ sd_loraA,
1005
+ sd_loraAscale,
1006
+ interrogate_check,
1007
+ model_output_state,
1008
+ ],
1009
+ outputs=[model_output, sd_prompt, model_output_state],
1010
+ js=proceed_button_js,
1011
+ )
1012
+ # cancel button can also remove error overlay
1013
+ if tuple(map(int, gr.__version__.split("."))) >= (3, 6):
1014
+ cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])
1015
+
1016
+
1017
+ launch_extra_kwargs = {
1018
+ "show_error": True,
1019
+ # "favicon_path": ""
1020
+ }
1021
+ launch_kwargs = vars(args)
1022
+ launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
1023
+ launch_kwargs.pop("remote_model", None)
1024
+ launch_kwargs.pop("local_model", None)
1025
+ launch_kwargs.pop("fp32", None)
1026
+ launch_kwargs.pop("lowvram", None)
1027
+ launch_kwargs.pop("stablepy_model", None)
1028
+ launch_kwargs.update(launch_extra_kwargs)
1029
+ try:
1030
+ import google.colab
1031
+
1032
+ launch_kwargs["debug"] = True
1033
+ launch_kwargs["share"] = True
1034
+ launch_kwargs.pop("encrypt", None)
1035
+ except:
1036
+ launch_kwargs["share"] = False
1037
+ pass
1038
+
1039
+ if not launch_kwargs["share"]:
1040
+ demo.launch()
1041
+ else:
1042
+ launch_kwargs["server_name"] = "0.0.0.0"
1043
+ demo.queue().launch(**launch_kwargs)
canvas.py ADDED
@@ -0,0 +1,648 @@
1
+ import base64
2
+ import json
3
+ import io
4
+ import numpy as np
5
+ from PIL import Image
6
+ from pyodide import to_js, create_proxy
7
+ import gc
8
+ from js import (
9
+ console,
10
+ document,
11
+ devicePixelRatio,
12
+ ImageData,
13
+ Uint8ClampedArray,
14
+ CanvasRenderingContext2D as Context2d,
15
+ requestAnimationFrame,
16
+ update_overlay,
17
+ setup_overlay,
18
+ window
19
+ )
20
+
21
+ PAINT_SELECTION = "selection"
22
+ IMAGE_SELECTION = "canvas"
23
+ BRUSH_SELECTION = "eraser"
24
+ NOP_MODE = 0
25
+ PAINT_MODE = 1
26
+ IMAGE_MODE = 2
27
+ BRUSH_MODE = 3
28
+
29
+
30
+ def hold_canvas():
31
+ pass
32
+
33
+
34
+ def prepare_canvas(width, height, canvas) -> Context2d:
35
+ ctx = canvas.getContext("2d")
36
+
37
+ canvas.style.width = f"{width}px"
38
+ canvas.style.height = f"{height}px"
39
+
40
+ canvas.width = width
41
+ canvas.height = height
42
+
43
+ ctx.clearRect(0, 0, width, height)
44
+
45
+ return ctx
46
+
47
+
48
+ # class MultiCanvas:
49
+ # def __init__(self,layer,width=800, height=600) -> None:
50
+ # pass
51
+ def multi_canvas(layer, width=800, height=600):
52
+ lst = [
53
+ CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
54
+ for i in range(layer)
55
+ ]
56
+ return lst
57
+
58
+
59
+ class CanvasProxy:
60
+ def __init__(self, canvas, width=800, height=600) -> None:
61
+ self.canvas = canvas
62
+ self.ctx = prepare_canvas(width, height, canvas)
63
+ self.width = width
64
+ self.height = height
65
+
66
+ def clear_rect(self, x, y, w, h):
67
+ self.ctx.clearRect(x, y, w, h)
68
+
69
+ def clear(self,):
70
+ self.clear_rect(0, 0, self.canvas.width, self.canvas.height)
71
+
72
+ def stroke_rect(self, x, y, w, h):
73
+ self.ctx.strokeRect(x, y, w, h)
74
+
75
+ def fill_rect(self, x, y, w, h):
76
+ self.ctx.fillRect(x, y, w, h)
77
+
78
+ def put_image_data(self, image, x, y):
79
+ data = Uint8ClampedArray.new(to_js(image.tobytes()))
80
+ height, width, _ = image.shape
81
+ image_data = ImageData.new(data, width, height)
82
+ self.ctx.putImageData(image_data, x, y)
83
+ del image_data
84
+
85
+ # def draw_image(self,canvas, x, y, w, h):
86
+ # self.ctx.drawImage(canvas,x,y,w,h)
87
+ def draw_image(self,canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
88
+ self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)
89
+
90
+ @property
91
+ def stroke_style(self):
92
+ return self.ctx.strokeStyle
93
+
94
+ @stroke_style.setter
95
+ def stroke_style(self, value):
96
+ self.ctx.strokeStyle = value
97
+
98
+ @property
99
+ def fill_style(self):
100
+ return self.ctx.fillStyle
101
+
102
+ @fill_style.setter
103
+ def fill_style(self, value):
104
+ self.ctx.fillStyle = value
105
+
106
+
107
+ # RGBA for masking
108
+ class InfCanvas:
109
+ def __init__(
110
+ self,
111
+ width,
112
+ height,
113
+ selection_size=256,
114
+ grid_size=64,
115
+ patch_size=4096,
116
+ test_mode=False,
117
+ ) -> None:
118
+ assert selection_size < min(height, width)
119
+ self.width = width
120
+ self.height = height
121
+ self.display_width = width
122
+ self.display_height = height
123
+ self.canvas = multi_canvas(5, width=width, height=height)
124
+ setup_overlay(width,height)
125
+ # place at center
126
+ self.view_pos = [patch_size//2-width//2, patch_size//2-height//2]
127
+ self.cursor = [
128
+ width // 2 - selection_size // 2,
129
+ height // 2 - selection_size // 2,
130
+ ]
131
+ self.data = {}
132
+ self.grid_size = grid_size
133
+ self.selection_size_w = selection_size
134
+ self.selection_size_h = selection_size
135
+ self.patch_size = patch_size
136
+ # note that for image data, the height comes before width
137
+ self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
138
+ self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
139
+ self.sel_buffer_bak = np.zeros(
140
+ (selection_size, selection_size, 4), dtype=np.uint8
141
+ )
142
+ self.sel_dirty = False
143
+ self.buffer_dirty = False
144
+ self.mouse_pos = [-1, -1]
145
+ self.mouse_state = 0
146
+ # self.output = widgets.Output()
147
+ self.test_mode = test_mode
148
+ self.buffer_updated = False
149
+ self.image_move_freq = 1
150
+ self.show_brush = False
151
+ self.scale=1.0
152
+ self.eraser_size=32
153
+
154
+ def reset_large_buffer(self):
155
+ self.canvas[2].canvas.width=self.width
156
+ self.canvas[2].canvas.height=self.height
157
+ # self.canvas[2].canvas.style.width=f"{self.display_width}px"
158
+ # self.canvas[2].canvas.style.height=f"{self.display_height}px"
159
+ self.canvas[2].canvas.style.display="block"
160
+ self.canvas[2].clear()
161
+
162
+ def draw_eraser(self, x, y):
163
+ self.canvas[-2].clear()
164
+ self.canvas[-2].fill_style = "#ffffff"
165
+ self.canvas[-2].fill_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
166
+ self.canvas[-2].stroke_rect(x-self.eraser_size//2,y-self.eraser_size//2,self.eraser_size,self.eraser_size)
167
+
168
+ def use_eraser(self,x,y):
169
+ if self.sel_dirty:
170
+ self.write_selection_to_buffer()
171
+ self.draw_buffer()
172
+ self.canvas[2].clear()
173
+ self.buffer_dirty=True
174
+ bx0,by0=int(x)-self.eraser_size//2,int(y)-self.eraser_size//2
175
+ bx1,by1=bx0+self.eraser_size,by0+self.eraser_size
176
+ bx0,by0=max(0,bx0),max(0,by0)
177
+ bx1,by1=min(self.width,bx1),min(self.height,by1)
178
+ self.buffer[by0:by1,bx0:bx1,:]*=0
179
+ self.draw_buffer()
180
+ self.draw_selection_box()
181
+
182
+ def setup_mouse(self):
183
+ self.image_move_cnt = 0
184
+
185
+ def get_mouse_mode():
186
+ mode = document.querySelector("#mode").value
187
+ if mode == PAINT_SELECTION:
188
+ return PAINT_MODE
189
+ elif mode == IMAGE_SELECTION:
190
+ return IMAGE_MODE
191
+ return BRUSH_MODE
192
+
193
+ def get_event_pos(event):
194
+ canvas = self.canvas[-1].canvas
195
+ rect = canvas.getBoundingClientRect()
196
+ x = (canvas.width * (event.clientX - rect.left)) / rect.width
197
+ y = (canvas.height * (event.clientY - rect.top)) / rect.height
198
+ return x, y
199
+
200
+ def handle_mouse_down(event):
201
+ self.mouse_state = get_mouse_mode()
202
+ if self.mouse_state==BRUSH_MODE:
203
+ x,y=get_event_pos(event)
204
+ self.use_eraser(x,y)
205
+
206
+ def handle_mouse_out(event):
207
+ last_state = self.mouse_state
208
+ self.mouse_state = NOP_MODE
209
+ self.image_move_cnt = 0
210
+ if last_state == IMAGE_MODE:
211
+ self.update_view_pos(0, 0)
212
+ if True:
213
+ self.clear_background()
214
+ self.draw_buffer()
215
+ self.reset_large_buffer()
216
+ self.draw_selection_box()
217
+ gc.collect()
218
+ if self.show_brush:
219
+ self.canvas[-2].clear()
220
+ self.show_brush = False
221
+
222
+ def handle_mouse_up(event):
223
+ last_state = self.mouse_state
224
+ self.mouse_state = NOP_MODE
225
+ self.image_move_cnt = 0
226
+ if last_state == IMAGE_MODE:
227
+ self.update_view_pos(0, 0)
228
+ if True:
229
+ self.clear_background()
230
+ self.draw_buffer()
231
+ self.reset_large_buffer()
232
+ self.draw_selection_box()
233
+ gc.collect()
234
+
235
+ async def handle_mouse_move(event):
236
+ x, y = get_event_pos(event)
237
+ x0, y0 = self.mouse_pos
238
+ xo = x - x0
239
+ yo = y - y0
240
+ if self.mouse_state == PAINT_MODE:
241
+ self.update_cursor(int(xo), int(yo))
242
+ if True:
243
+ # self.clear_background()
244
+ # console.log(self.buffer_updated)
245
+ if self.buffer_updated:
246
+ self.draw_buffer()
247
+ self.buffer_updated = False
248
+ self.draw_selection_box()
249
+ elif self.mouse_state == IMAGE_MODE:
250
+ self.image_move_cnt += 1
251
+ if self.image_move_cnt == self.image_move_freq:
252
+ self.draw_buffer()
253
+ self.canvas[2].clear()
254
+ self.draw_selection_box()
255
+ self.update_view_pos(int(xo), int(yo))
256
+ self.cached_view_pos=tuple(self.view_pos)
257
+ self.canvas[2].canvas.style.display="none"
258
+ large_buffer=self.data2array(self.view_pos[0]-self.width//2,self.view_pos[1]-self.height//2,min(self.width*2,self.patch_size),min(self.height*2,self.patch_size))
259
+ self.canvas[2].canvas.width=large_buffer.shape[1]
260
+ self.canvas[2].canvas.height=large_buffer.shape[0]
261
+ # self.canvas[2].canvas.style.width=""
262
+ # self.canvas[2].canvas.style.height=""
263
+ self.canvas[2].put_image_data(large_buffer,0,0)
264
+ else:
265
+ self.update_view_pos(int(xo), int(yo), False)
266
+ self.canvas[1].clear()
267
+ self.canvas[1].draw_image(self.canvas[2].canvas,
268
+ self.width//2+(self.view_pos[0]-self.cached_view_pos[0]),self.height//2+(self.view_pos[1]-self.cached_view_pos[1]),
269
+ self.width,self.height,
270
+ 0,0,self.width,self.height
271
+ )
272
+ self.clear_background()
273
+ # self.image_move_cnt = 0
274
+ elif self.mouse_state == BRUSH_MODE:
275
+ self.use_eraser(x,y)
276
+
277
+ mode = document.querySelector("#mode").value
278
+ if mode == BRUSH_SELECTION:
279
+ self.draw_eraser(x,y)
280
+ self.show_brush = True
281
+ elif self.show_brush:
282
+ self.canvas[-2].clear()
283
+ self.show_brush = False
284
+ self.mouse_pos[0] = x
285
+ self.mouse_pos[1] = y
286
+
287
+ self.canvas[-1].canvas.addEventListener(
288
+ "mousedown", create_proxy(handle_mouse_down)
289
+ )
290
+ self.canvas[-1].canvas.addEventListener(
291
+ "mousemove", create_proxy(handle_mouse_move)
292
+ )
293
+ self.canvas[-1].canvas.addEventListener(
294
+ "mouseup", create_proxy(handle_mouse_up)
295
+ )
296
+ self.canvas[-1].canvas.addEventListener(
297
+ "mouseout", create_proxy(handle_mouse_out)
298
+ )
299
+ async def handle_mouse_wheel(event):
300
+ x, y = get_event_pos(event)
301
+ self.mouse_pos[0] = x
302
+ self.mouse_pos[1] = y
303
+ console.log(to_js(self.mouse_pos))
304
+ if event.deltaY>10:
305
+ window.postMessage(to_js(["click","zoom_out", self.mouse_pos[0], self.mouse_pos[1]]),"*")
306
+ elif event.deltaY<-10:
307
+ window.postMessage(to_js(["click","zoom_in", self.mouse_pos[0], self.mouse_pos[1]]),"*")
308
+ return False
309
+ self.canvas[-1].canvas.addEventListener(
310
+ "wheel", create_proxy(handle_mouse_wheel), False
311
+ )
312
+ def clear_background(self):
313
+ # fake transparent background
314
+ h, w, step = self.height, self.width, self.grid_size
315
+ stride = step * 2
316
+ x0, y0 = self.view_pos
317
+ x0 = (-x0) % stride
318
+ y0 = (-y0) % stride
319
+ if y0>=step:
320
+ val0,val1=stride,step
321
+ else:
322
+ val0,val1=step,stride
323
+ # self.canvas.clear()
324
+ self.canvas[0].fill_style = "#ffffff"
325
+ self.canvas[0].fill_rect(0, 0, w, h)
326
+ self.canvas[0].fill_style = "#aaaaaa"
327
+ for y in range(y0-stride, h + step, step):
328
+ start = (x0 - val0) if y // step % 2 == 0 else (x0 - val1)
329
+ for x in range(start, w + step, stride):
330
+ self.canvas[0].fill_rect(x, y, step, step)
331
+ self.canvas[0].stroke_rect(0, 0, w, h)
332
+
333
+ def refine_selection(self):
334
+ h,w=self.selection_size_h,self.selection_size_w
335
+ h=min(h,self.height)
336
+ w=min(w,self.width)
337
+ self.selection_size_h=h*8//8
338
+ self.selection_size_w=w*8//8
339
+ self.update_cursor(1,0)
340
+
341
+
342
+ def update_scale(self, scale, mx=-1, my=-1):
343
+ self.sync_to_data()
344
+ scaled_width=int(self.display_width*scale)
345
+ scaled_height=int(self.display_height*scale)
346
+ if max(scaled_height,scaled_width)>=self.patch_size*2-128:
347
+ return
348
+ if scaled_height<=self.selection_size_h or scaled_width<=self.selection_size_w:
349
+ return
350
+ if mx>=0 and my>=0:
351
+ scaled_mx=mx/self.scale*scale
352
+ scaled_my=my/self.scale*scale
353
+ self.view_pos[0]+=int(mx-scaled_mx)
354
+ self.view_pos[1]+=int(my-scaled_my)
355
+ self.scale=scale
356
+ for item in self.canvas:
357
+ item.canvas.width=scaled_width
358
+ item.canvas.height=scaled_height
359
+ item.clear()
360
+ update_overlay(scaled_width,scaled_height)
361
+ self.width=scaled_width
362
+ self.height=scaled_height
363
+ self.data2buffer()
364
+ self.clear_background()
365
+ self.draw_buffer()
366
+ self.update_cursor(1,0)
367
+ self.draw_selection_box()
368
+
369
+ def update_view_pos(self, xo, yo, update=True):
370
+ # if abs(xo) + abs(yo) == 0:
371
+ # return
372
+ if self.sel_dirty:
373
+ self.write_selection_to_buffer()
374
+ if self.buffer_dirty:
375
+ self.buffer2data()
376
+ self.view_pos[0] -= xo
377
+ self.view_pos[1] -= yo
378
+ if update:
379
+ self.data2buffer()
380
+ # self.read_selection_from_buffer()
381
+
382
+ def update_cursor(self, xo, yo):
383
+ if abs(xo) + abs(yo) == 0:
384
+ return
385
+ if self.sel_dirty:
386
+ self.write_selection_to_buffer()
387
+ self.cursor[0] += xo
388
+ self.cursor[1] += yo
389
+ self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
390
+ self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)
391
+ # self.read_selection_from_buffer()
392
+
393
+ def data2buffer(self):
394
+ x, y = self.view_pos
395
+ h, w = self.height, self.width
396
+ if h!=self.buffer.shape[0] or w!=self.buffer.shape[1]:
397
+ self.buffer=np.zeros((self.height, self.width, 4), dtype=np.uint8)
398
+ # fill four parts
399
+ for i in range(4):
400
+ pos_src, pos_dst, data = self.select(x, y, i)
401
+ xs0, xs1 = pos_src[0]
402
+ ys0, ys1 = pos_src[1]
403
+ xd0, xd1 = pos_dst[0]
404
+ yd0, yd1 = pos_dst[1]
405
+ self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
406
+
407
+ def data2array(self, x, y, w, h):
408
+ # x, y = self.view_pos
409
+ # h, w = self.height, self.width
410
+ ret=np.zeros((h, w, 4), dtype=np.uint8)
411
+ # fill four parts
412
+ for i in range(4):
413
+ pos_src, pos_dst, data = self.select(x, y, i, w, h)
414
+ xs0, xs1 = pos_src[0]
415
+ ys0, ys1 = pos_src[1]
416
+ xd0, xd1 = pos_dst[0]
417
+ yd0, yd1 = pos_dst[1]
418
+ ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
419
+ return ret
420
+
421
+ def buffer2data(self):
422
+ x, y = self.view_pos
423
+ h, w = self.height, self.width
424
+ # fill four parts
425
+ for i in range(4):
426
+ pos_src, pos_dst, data = self.select(x, y, i)
427
+ xs0, xs1 = pos_src[0]
428
+ ys0, ys1 = pos_src[1]
429
+ xd0, xd1 = pos_dst[0]
430
+ yd0, yd1 = pos_dst[1]
431
+ data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
432
+ self.buffer_dirty = False
433
+
434
+ def select(self, x, y, idx, width=0, height=0):
435
+ if width==0:
436
+ w, h = self.width, self.height
437
+ else:
438
+ w, h = width, height
439
+ lst = [(0, 0), (0, h), (w, 0), (w, h)]
440
+ if idx == 0:
441
+ x0, y0 = x % self.patch_size, y % self.patch_size
442
+ x1 = min(x0 + w, self.patch_size)
443
+ y1 = min(y0 + h, self.patch_size)
444
+ elif idx == 1:
445
+ y += h
446
+ x0, y0 = x % self.patch_size, y % self.patch_size
447
+ x1 = min(x0 + w, self.patch_size)
448
+ y1 = max(y0 - h, 0)
449
+ elif idx == 2:
450
+ x += w
451
+ x0, y0 = x % self.patch_size, y % self.patch_size
452
+ x1 = max(x0 - w, 0)
453
+ y1 = min(y0 + h, self.patch_size)
454
+ else:
455
+ x += w
456
+ y += h
457
+ x0, y0 = x % self.patch_size, y % self.patch_size
458
+ x1 = max(x0 - w, 0)
459
+ y1 = max(y0 - h, 0)
460
+ xi, yi = x // self.patch_size, y // self.patch_size
461
+ cur = self.data.setdefault(
462
+ (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
463
+ )
464
+ x0_img, y0_img = lst[idx]
465
+ x1_img = x0_img + x1 - x0
466
+ y1_img = y0_img + y1 - y0
467
+ sort = lambda a, b: ((a, b) if a < b else (b, a))
468
+ return (
469
+ (sort(x0, x1), sort(y0, y1)),
470
+ (sort(x0_img, x1_img), sort(y0_img, y1_img)),
471
+ cur,
472
+ )
473
+
474
+ def draw_buffer(self):
475
+ self.canvas[1].clear()
476
+ self.canvas[1].put_image_data(self.buffer, 0, 0)
477
+
478
+ def fill_selection(self, img):
479
+ self.sel_buffer = img
480
+ self.sel_dirty = True
481
+
482
+ def draw_selection_box(self):
483
+ x0, y0 = self.cursor
484
+ w, h = self.selection_size_w, self.selection_size_h
485
+ if self.sel_dirty:
486
+ self.canvas[2].clear()
487
+ self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
488
+ self.canvas[-1].clear()
489
+ self.canvas[-1].stroke_style = "#0a0a0a"
490
+ self.canvas[-1].stroke_rect(x0, y0, w, h)
491
+ self.canvas[-1].stroke_style = "#ffffff"
492
+ offset=round(self.scale) if self.scale>1.0 else 1
493
+ self.canvas[-1].stroke_rect(x0 - offset, y0 - offset, w + offset*2, h + offset*2)
494
+ self.canvas[-1].stroke_style = "#000000"
495
+ self.canvas[-1].stroke_rect(x0 - offset*2, y0 - offset*2, w + offset*4, h + offset*4)
496
+
497
+ def write_selection_to_buffer(self):
498
+ x0, y0 = self.cursor
499
+ x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
500
+ self.buffer[y0:y1, x0:x1] = self.sel_buffer
501
+ self.sel_dirty = False
502
+ self.sel_buffer = np.zeros(
503
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
504
+ )
505
+ self.buffer_dirty = True
506
+ self.buffer_updated = True
507
+ # self.canvas[2].clear()
508
+
509
+ def read_selection_from_buffer(self):
510
+ x0, y0 = self.cursor
511
+ x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
512
+ self.sel_buffer = self.buffer[y0:y1, x0:x1]
513
+ self.sel_dirty = False
514
+
515
+ def base64_to_numpy(self, base64_str):
516
+ try:
517
+ data = base64.b64decode(str(base64_str))
518
+ pil = Image.open(io.BytesIO(data))
519
+ arr = np.array(pil)
520
+ ret = arr
521
+ except:
522
+ ret = np.tile(
523
+ np.array([255, 0, 0, 255], dtype=np.uint8),
524
+ (self.selection_size_h, self.selection_size_w, 1),
525
+ )
526
+ return ret
527
+
528
+ def numpy_to_base64(self, arr):
529
+ out_pil = Image.fromarray(arr)
530
+ out_buffer = io.BytesIO()
531
+ out_pil.save(out_buffer, format="PNG")
532
+ out_buffer.seek(0)
533
+ base64_bytes = base64.b64encode(out_buffer.read())
534
+ base64_str = base64_bytes.decode("ascii")
535
+ return base64_str
536
+
537
+ def sync_to_data(self):
538
+ if self.sel_dirty:
539
+ self.write_selection_to_buffer()
540
+ self.canvas[2].clear()
541
+ self.draw_buffer()
542
+ if self.buffer_dirty:
543
+ self.buffer2data()
544
+
545
+ def sync_to_buffer(self):
546
+ if self.sel_dirty:
547
+ self.canvas[2].clear()
548
+ self.write_selection_to_buffer()
549
+ self.draw_buffer()
550
+
551
+ def resize(self,width,height,scale=None,**kwargs):
552
+ self.display_width=width
553
+ self.display_height=height
554
+ for canvas in self.canvas:
555
+ prepare_canvas(width=width,height=height,canvas=canvas.canvas)
556
+ setup_overlay(width,height)
557
+ if scale is None:
558
+ scale=1
559
+ self.update_scale(scale)
560
+
561
+
562
+ def save(self):
563
+ self.sync_to_data()
564
+ state={}
565
+ state["width"]=self.display_width
566
+ state["height"]=self.display_height
567
+ state["selection_width"]=self.selection_size_w
568
+ state["selection_height"]=self.selection_size_h
569
+ state["view_pos"]=self.view_pos[:]
570
+ state["cursor"]=self.cursor[:]
571
+ state["scale"]=self.scale
572
+ keys=list(self.data.keys())
573
+ data={}
574
+ for key in keys:
575
+ if self.data[key].sum()>0:
576
+ data[f"{key[0]},{key[1]}"]=self.numpy_to_base64(self.data[key])
577
+ state["data"]=data
578
+ return json.dumps(state)
579
+
580
+ def load(self, state_json):
581
+ self.reset()
582
+ state=json.loads(state_json)
583
+ self.display_width=state["width"]
584
+ self.display_height=state["height"]
585
+ self.selection_size_w=state["selection_width"]
586
+ self.selection_size_h=state["selection_height"]
587
+ self.view_pos=state["view_pos"][:]
588
+ self.cursor=state["cursor"][:]
589
+ self.scale=state["scale"]
590
+ self.resize(state["width"],state["height"],scale=state["scale"])
591
+ for k,v in state["data"].items():
592
+ key=tuple(map(int,k.split(",")))
593
+ self.data[key]=self.base64_to_numpy(v)
594
+ self.data2buffer()
595
+ self.display()
596
+
597
+ def display(self):
598
+ self.clear_background()
599
+ self.draw_buffer()
600
+ self.draw_selection_box()
601
+
602
+ def reset(self):
603
+ self.data.clear()
604
+ self.buffer*=0
605
+ self.buffer_dirty=False
606
+ self.buffer_updated=False
607
+ self.sel_buffer*=0
608
+ self.sel_dirty=False
609
+ self.view_pos = [0, 0]
610
+ self.clear_background()
611
+ for i in range(1,len(self.canvas)-1):
612
+ self.canvas[i].clear()
613
+
614
+ def export(self):
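+ # stitch all non-empty patches into one RGBA image and crop it to the bounding
+ # box of the alpha channel; falls back to an empty selection-sized image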
615
+ self.sync_to_data()
616
+ xmin, xmax, ymin, ymax = 0, 0, 0, 0
617
+ if len(self.data.keys()) == 0:
618
+ return np.zeros(
619
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
620
+ )
621
+ for xi, yi in self.data.keys():
622
+ buf = self.data[(xi, yi)]
623
+ if buf.sum() > 0:
624
+ xmin = min(xi, xmin)
625
+ xmax = max(xi, xmax)
626
+ ymin = min(yi, ymin)
627
+ ymax = max(yi, ymax)
628
+ yn = ymax - ymin + 1
629
+ xn = xmax - xmin + 1
630
+ image = np.zeros(
631
+ (yn * self.patch_size, xn * self.patch_size, 4), dtype=np.uint8
632
+ )
633
+ for xi, yi in self.data.keys():
634
+ buf = self.data[(xi, yi)]
635
+ if buf.sum() > 0:
636
+ y0 = (yi - ymin) * self.patch_size
637
+ x0 = (xi - xmin) * self.patch_size
638
+ image[y0 : y0 + self.patch_size, x0 : x0 + self.patch_size] = buf
639
+ ylst, xlst = image[:, :, -1].nonzero()
640
+ if len(ylst) > 0:
641
+ yt, xt = ylst.min(), xlst.min()
642
+ yb, xb = ylst.max(), xlst.max()
643
+ image = image[yt : yb + 1, xt : xb + 1]
644
+ return image
645
+ else:
646
+ return np.zeros(
647
+ (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
648
+ )
config.yaml ADDED
@@ -0,0 +1,18 @@
1
+ shortcut:
2
+ clear: Escape
3
+ load: Ctrl+o
4
+ save: Ctrl+s
5
+ export: Ctrl+e
6
+ upload: Ctrl+u
7
+ selection: 1
8
+ canvas: 2
9
+ eraser: 3
10
+ outpaint: d
11
+ accept: a
12
+ cancel: c
13
+ retry: r
14
+ prev: q
15
+ next: e
16
+ zoom_in: z
17
+ zoom_out: x
18
+ random_seed: s
convert_checkpoint.py ADDED
@@ -0,0 +1,706 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
16
+ """ Conversion script for the LDM checkpoints. """
17
+
18
+ import argparse
19
+ import os
20
+
21
+ import torch
22
+
23
+
24
+ try:
25
+ from omegaconf import OmegaConf
26
+ except ImportError:
27
+ raise ImportError(
28
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
29
+ )
30
+
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ LDMTextToImagePipeline,
35
+ LMSDiscreteScheduler,
36
+ PNDMScheduler,
37
+ StableDiffusionPipeline,
38
+ UNet2DConditionModel,
39
+ )
40
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
41
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
42
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
43
+
44
+
45
+ def shave_segments(path, n_shave_prefix_segments=1):
46
+ """
47
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
48
+ """
49
+ if n_shave_prefix_segments >= 0:
50
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
51
+ else:
52
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
53
+
54
+
55
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
56
+ """
57
+ Updates paths inside resnets to the new naming scheme (local renaming)
58
+ """
59
+ mapping = []
60
+ for old_item in old_list:
61
+ new_item = old_item.replace("in_layers.0", "norm1")
62
+ new_item = new_item.replace("in_layers.2", "conv1")
63
+
64
+ new_item = new_item.replace("out_layers.0", "norm2")
65
+ new_item = new_item.replace("out_layers.3", "conv2")
66
+
67
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
68
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
69
+
70
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
71
+
72
+ mapping.append({"old": old_item, "new": new_item})
73
+
74
+ return mapping
75
+
76
+
77
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
78
+ """
79
+ Updates paths inside resnets to the new naming scheme (local renaming)
80
+ """
81
+ mapping = []
82
+ for old_item in old_list:
83
+ new_item = old_item
84
+
85
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
86
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
87
+
88
+ mapping.append({"old": old_item, "new": new_item})
89
+
90
+ return mapping
91
+
92
+
93
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
94
+ """
95
+ Updates paths inside attentions to the new naming scheme (local renaming)
96
+ """
97
+ mapping = []
98
+ for old_item in old_list:
99
+ new_item = old_item
100
+
101
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
102
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
103
+
104
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
105
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
106
+
107
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
108
+
109
+ mapping.append({"old": old_item, "new": new_item})
110
+
111
+ return mapping
112
+
113
+
114
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
115
+ """
116
+ Updates paths inside attentions to the new naming scheme (local renaming)
117
+ """
118
+ mapping = []
119
+ for old_item in old_list:
120
+ new_item = old_item
121
+
122
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
123
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
124
+
125
+ new_item = new_item.replace("q.weight", "query.weight")
126
+ new_item = new_item.replace("q.bias", "query.bias")
127
+
128
+ new_item = new_item.replace("k.weight", "key.weight")
129
+ new_item = new_item.replace("k.bias", "key.bias")
130
+
131
+ new_item = new_item.replace("v.weight", "value.weight")
132
+ new_item = new_item.replace("v.bias", "value.bias")
133
+
134
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
135
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
136
+
137
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
138
+
139
+ mapping.append({"old": old_item, "new": new_item})
140
+
141
+ return mapping
142
+
143
+
144
+ def assign_to_checkpoint(
145
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
146
+ ):
147
+ """
148
+ This does the final conversion step: take locally converted weights and apply a global renaming
149
+ to them. It splits attention layers, and takes into account additional replacements
150
+ that may arise.
151
+
152
+ Assigns the weights to the new checkpoint.
153
+ """
154
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
155
+
156
+ # Splits the attention layers into three variables.
157
+ if attention_paths_to_split is not None:
158
+ for path, path_map in attention_paths_to_split.items():
159
+ old_tensor = old_checkpoint[path]
160
+ channels = old_tensor.shape[0] // 3
161
+
162
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
163
+
164
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
165
+
166
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
167
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
168
+
169
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
170
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
171
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
172
+
173
+ for path in paths:
174
+ new_path = path["new"]
175
+
176
+ # These have already been assigned
177
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
178
+ continue
179
+
180
+ # Global renaming happens here
181
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
182
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
183
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
184
+
185
+ if additional_replacements is not None:
186
+ for replacement in additional_replacements:
187
+ new_path = new_path.replace(replacement["old"], replacement["new"])
188
+
189
+ # proj_attn.weight has to be converted from conv 1D to linear
190
+ if "proj_attn.weight" in new_path:
191
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
192
+ else:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]]
194
+
195
+
196
+ def conv_attn_to_linear(checkpoint):
197
+ keys = list(checkpoint.keys())
198
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
199
+ for key in keys:
200
+ if ".".join(key.split(".")[-2:]) in attn_keys:
201
+ if checkpoint[key].ndim > 2:
202
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
203
+ elif "proj_attn.weight" in key:
204
+ if checkpoint[key].ndim > 2:
205
+ checkpoint[key] = checkpoint[key][:, :, 0]
206
+
207
+
208
+ def create_unet_diffusers_config(original_config):
209
+ """
210
+ Creates a config for the diffusers based on the config of the LDM model.
211
+ """
212
+ unet_params = original_config.model.params.unet_config.params
213
+
214
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
215
+
216
+ down_block_types = []
217
+ resolution = 1
218
+ for i in range(len(block_out_channels)):
219
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
220
+ down_block_types.append(block_type)
221
+ if i != len(block_out_channels) - 1:
222
+ resolution *= 2
223
+
224
+ up_block_types = []
225
+ for i in range(len(block_out_channels)):
226
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
227
+ up_block_types.append(block_type)
228
+ resolution //= 2
229
+
230
+ config = dict(
231
+ sample_size=unet_params.image_size,
232
+ in_channels=unet_params.in_channels,
233
+ out_channels=unet_params.out_channels,
234
+ down_block_types=tuple(down_block_types),
235
+ up_block_types=tuple(up_block_types),
236
+ block_out_channels=tuple(block_out_channels),
237
+ layers_per_block=unet_params.num_res_blocks,
238
+ cross_attention_dim=unet_params.context_dim,
239
+ attention_head_dim=unet_params.num_heads,
240
+ )
241
+
242
+ return config
243
+
244
+
245
+ def create_vae_diffusers_config(original_config):
246
+ """
247
+ Creates a config for the diffusers based on the config of the LDM model.
248
+ """
249
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
250
+ _ = original_config.model.params.first_stage_config.params.embed_dim
251
+
252
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
253
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
254
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
255
+
256
+ config = dict(
257
+ sample_size=vae_params.resolution,
258
+ in_channels=vae_params.in_channels,
259
+ out_channels=vae_params.out_ch,
260
+ down_block_types=tuple(down_block_types),
261
+ up_block_types=tuple(up_block_types),
262
+ block_out_channels=tuple(block_out_channels),
263
+ latent_channels=vae_params.z_channels,
264
+ layers_per_block=vae_params.num_res_blocks,
265
+ )
266
+ return config
267
+
268
+
269
+ def create_diffusers_schedular(original_config):
270
+ schedular = DDIMScheduler(
271
+ num_train_timesteps=original_config.model.params.timesteps,
272
+ beta_start=original_config.model.params.linear_start,
273
+ beta_end=original_config.model.params.linear_end,
274
+ beta_schedule="scaled_linear",
275
+ )
276
+ return schedular
277
+
278
+
279
+ def create_ldm_bert_config(original_config):
280
+ bert_params = original_config.model.params.cond_stage_config.params
281
+ config = LDMBertConfig(
282
+ d_model=bert_params.n_embed,
283
+ encoder_layers=bert_params.n_layer,
284
+ encoder_ffn_dim=bert_params.n_embed * 4,
285
+ )
286
+ return config
287
+
288
+
289
+ def convert_ldm_unet_checkpoint(checkpoint, config):
290
+ """
291
+ Takes a state dict and a config, and returns a converted checkpoint.
292
+ """
293
+
294
+ # extract state_dict for UNet
295
+ unet_state_dict = {}
296
+ unet_key = "model.diffusion_model."
297
+ keys = list(checkpoint.keys())
298
+ for key in keys:
299
+ if key.startswith(unet_key):
300
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
301
+
302
+ new_checkpoint = {}
303
+
304
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
305
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
306
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
307
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
308
+
309
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
310
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
311
+
312
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
313
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
314
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
315
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
316
+
317
+ # Retrieves the keys for the input blocks only
318
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
319
+ input_blocks = {
320
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
321
+ for layer_id in range(num_input_blocks)
322
+ }
323
+
324
+ # Retrieves the keys for the middle blocks only
325
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
326
+ middle_blocks = {
327
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
328
+ for layer_id in range(num_middle_blocks)
329
+ }
330
+
331
+ # Retrieves the keys for the output blocks only
332
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
333
+ output_blocks = {
334
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
335
+ for layer_id in range(num_output_blocks)
336
+ }
337
+
338
+ for i in range(1, num_input_blocks):
339
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
340
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
341
+
342
+ resnets = [
343
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
344
+ ]
345
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
346
+
347
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
348
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
349
+ f"input_blocks.{i}.0.op.weight"
350
+ )
351
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
352
+ f"input_blocks.{i}.0.op.bias"
353
+ )
354
+
355
+ paths = renew_resnet_paths(resnets)
356
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
357
+ assign_to_checkpoint(
358
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359
+ )
360
+
361
+ if len(attentions):
362
+ paths = renew_attention_paths(attentions)
363
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
364
+ assign_to_checkpoint(
365
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
366
+ )
367
+
368
+ resnet_0 = middle_blocks[0]
369
+ attentions = middle_blocks[1]
370
+ resnet_1 = middle_blocks[2]
371
+
372
+ resnet_0_paths = renew_resnet_paths(resnet_0)
373
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
374
+
375
+ resnet_1_paths = renew_resnet_paths(resnet_1)
376
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
377
+
378
+ attentions_paths = renew_attention_paths(attentions)
379
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
380
+ assign_to_checkpoint(
381
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
382
+ )
383
+
384
+ for i in range(num_output_blocks):
385
+ block_id = i // (config["layers_per_block"] + 1)
386
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
387
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
388
+ output_block_list = {}
389
+
390
+ for layer in output_block_layers:
391
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
392
+ if layer_id in output_block_list:
393
+ output_block_list[layer_id].append(layer_name)
394
+ else:
395
+ output_block_list[layer_id] = [layer_name]
396
+
397
+ if len(output_block_list) > 1:
398
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
399
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
400
+
401
+ resnet_0_paths = renew_resnet_paths(resnets)
402
+ paths = renew_resnet_paths(resnets)
403
+
404
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
405
+ assign_to_checkpoint(
406
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
407
+ )
408
+
409
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
410
+ index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
411
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
412
+ f"output_blocks.{i}.{index}.conv.weight"
413
+ ]
414
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
415
+ f"output_blocks.{i}.{index}.conv.bias"
416
+ ]
417
+
418
+ # Clear attentions as they have been attributed above.
419
+ if len(attentions) == 2:
420
+ attentions = []
421
+
422
+ if len(attentions):
423
+ paths = renew_attention_paths(attentions)
424
+ meta_path = {
425
+ "old": f"output_blocks.{i}.1",
426
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
427
+ }
428
+ assign_to_checkpoint(
429
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
430
+ )
431
+ else:
432
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
433
+ for path in resnet_0_paths:
434
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
435
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
436
+
437
+ new_checkpoint[new_path] = unet_state_dict[old_path]
438
+
439
+ return new_checkpoint
440
+
441
+
442
+ def convert_ldm_vae_checkpoint(checkpoint, config):
443
+ # extract state dict for VAE
444
+ vae_state_dict = {}
445
+ vae_key = "first_stage_model."
446
+ keys = list(checkpoint.keys())
447
+ for key in keys:
448
+ if key.startswith(vae_key):
449
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
450
+
451
+ new_checkpoint = {}
452
+
453
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
454
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
455
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
456
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
457
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
458
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
459
+
460
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
461
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
462
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
463
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
464
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
465
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
466
+
467
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
468
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
469
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
470
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
471
+
472
+ # Retrieves the keys for the encoder down blocks only
473
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
474
+ down_blocks = {
475
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
476
+ }
477
+
478
+ # Retrieves the keys for the decoder up blocks only
479
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
480
+ up_blocks = {
481
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
482
+ }
483
+
484
+ for i in range(num_down_blocks):
485
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
486
+
487
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
488
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
489
+ f"encoder.down.{i}.downsample.conv.weight"
490
+ )
491
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
492
+ f"encoder.down.{i}.downsample.conv.bias"
493
+ )
494
+
495
+ paths = renew_vae_resnet_paths(resnets)
496
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
497
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
498
+
499
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
500
+ num_mid_res_blocks = 2
501
+ for i in range(1, num_mid_res_blocks + 1):
502
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
503
+
504
+ paths = renew_vae_resnet_paths(resnets)
505
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
506
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
507
+
508
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
509
+ paths = renew_vae_attention_paths(mid_attentions)
510
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
511
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
512
+ conv_attn_to_linear(new_checkpoint)
513
+
514
+ for i in range(num_up_blocks):
515
+ block_id = num_up_blocks - 1 - i
516
+ resnets = [
517
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
518
+ ]
519
+
520
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
521
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
522
+ f"decoder.up.{block_id}.upsample.conv.weight"
523
+ ]
524
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
525
+ f"decoder.up.{block_id}.upsample.conv.bias"
526
+ ]
527
+
528
+ paths = renew_vae_resnet_paths(resnets)
529
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
530
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
531
+
532
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
533
+ num_mid_res_blocks = 2
534
+ for i in range(1, num_mid_res_blocks + 1):
535
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
536
+
537
+ paths = renew_vae_resnet_paths(resnets)
538
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
539
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
540
+
541
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
542
+ paths = renew_vae_attention_paths(mid_attentions)
543
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
544
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
545
+ conv_attn_to_linear(new_checkpoint)
546
+ return new_checkpoint
547
+
548
+
549
+ def convert_ldm_bert_checkpoint(checkpoint, config):
550
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
551
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
552
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
553
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
554
+
555
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
556
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
557
+
558
+ def _copy_linear(hf_linear, pt_linear):
559
+ hf_linear.weight = pt_linear.weight
560
+ hf_linear.bias = pt_linear.bias
561
+
562
+ def _copy_layer(hf_layer, pt_layer):
563
+ # copy layer norms
564
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
565
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
566
+
567
+ # copy attn
568
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
569
+
570
+ # copy MLP
571
+ pt_mlp = pt_layer[1][1]
572
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
573
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
574
+
575
+ def _copy_layers(hf_layers, pt_layers):
576
+ for i, hf_layer in enumerate(hf_layers):
577
+ if i != 0:
578
+ i += i
579
+ pt_layer = pt_layers[i : i + 2]
580
+ _copy_layer(hf_layer, pt_layer)
581
+
582
+ hf_model = LDMBertModel(config).eval()
583
+
584
+ # copy embeds
585
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
586
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
587
+
588
+ # copy layer norm
589
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
590
+
591
+ # copy hidden layers
592
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
593
+
594
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
595
+
596
+ return hf_model
597
+
598
+
599
+ def convert_ldm_clip_checkpoint(checkpoint):
600
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
601
+
602
+ keys = list(checkpoint.keys())
603
+
604
+ text_model_dict = {}
605
+
606
+ for key in keys:
607
+ if key.startswith("cond_stage_model.transformer"):
608
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
609
+
610
+ text_model.load_state_dict(text_model_dict)
611
+
612
+ return text_model
613
+
614
+ import os
615
+ def convert_checkpoint(checkpoint_path, inpainting=False):
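+ # argparse is used only as a defaults container: parse_args([]) below ignores
+ # the real command line, so this behaves like an ordinary function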
616
+ parser = argparse.ArgumentParser()
617
+
618
+ parser.add_argument(
619
+ "--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert."
620
+ )
621
+ # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
622
+ parser.add_argument(
623
+ "--original_config_file",
624
+ default=None,
625
+ type=str,
626
+ help="The YAML config file corresponding to the original architecture.",
627
+ )
628
+ parser.add_argument(
629
+ "--scheduler_type",
630
+ default="pndm",
631
+ type=str,
632
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
633
+ )
634
+ parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")
635
+
636
+ args = parser.parse_args([])
637
+ if args.original_config_file is None:
638
+ if inpainting:
639
+ args.original_config_file = "./models/v1-inpainting-inference.yaml"
640
+ else:
641
+ args.original_config_file = "./models/v1-inference.yaml"
642
+
643
+ original_config = OmegaConf.load(args.original_config_file)
644
+ checkpoint = torch.load(args.checkpoint_path)["state_dict"]
645
+
646
+ num_train_timesteps = original_config.model.params.timesteps
647
+ beta_start = original_config.model.params.linear_start
648
+ beta_end = original_config.model.params.linear_end
649
+ if args.scheduler_type == "pndm":
650
+ scheduler = PNDMScheduler(
651
+ beta_end=beta_end,
652
+ beta_schedule="scaled_linear",
653
+ beta_start=beta_start,
654
+ num_train_timesteps=num_train_timesteps,
655
+ skip_prk_steps=True,
656
+ )
657
+ elif args.scheduler_type == "lms":
658
+ scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
659
+ elif args.scheduler_type == "ddim":
660
+ scheduler = DDIMScheduler(
661
+ beta_start=beta_start,
662
+ beta_end=beta_end,
663
+ beta_schedule="scaled_linear",
664
+ clip_sample=False,
665
+ set_alpha_to_one=False,
666
+ )
667
+ else:
668
+ raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
669
+
670
+ # Convert the UNet2DConditionModel model.
671
+ unet_config = create_unet_diffusers_config(original_config)
672
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
673
+
674
+ unet = UNet2DConditionModel(**unet_config)
675
+ unet.load_state_dict(converted_unet_checkpoint)
676
+
677
+ # Convert the VAE model.
678
+ vae_config = create_vae_diffusers_config(original_config)
679
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
680
+
681
+ vae = AutoencoderKL(**vae_config)
682
+ vae.load_state_dict(converted_vae_checkpoint)
683
+
684
+ # Convert the text model.
685
+ text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
686
+ if text_model_type == "FrozenCLIPEmbedder":
687
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
688
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
689
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
690
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
691
+ pipe = StableDiffusionPipeline(
692
+ vae=vae,
693
+ text_encoder=text_model,
694
+ tokenizer=tokenizer,
695
+ unet=unet,
696
+ scheduler=scheduler,
697
+ safety_checker=safety_checker,
698
+ feature_extractor=feature_extractor,
699
+ )
700
+ else:
701
+ text_config = create_ldm_bert_config(original_config)
702
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
703
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
704
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
705
+
706
+ return pipe
css/w2ui.min.css ADDED
The diff for this file is too large to render. See raw diff
 
index.html ADDED
@@ -0,0 +1,482 @@
1
+ <html>
2
+ <head>
3
+ <title>Stablediffusion Infinity</title>
4
+ <meta charset="utf-8">
5
+ <link rel="icon" type="image/x-icon" href="./favicon.png">
6
+
7
+ <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/css/w2ui.min.css">
8
+ <script type="text/javascript" src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/js/w2ui.min.js"></script>
9
+ <link rel="stylesheet" type="text/css" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
10
+ <script src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/js/fabric.min.js"></script>
11
+ <script defer src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/js/toolbar.js"></script>
12
+
13
+ <link rel="stylesheet" href="https://pyscript.net/releases/2022.05.1/pyscript.css" />
14
+ <script defer src="https://pyscript.net/releases/2022.05.1/pyscript.js"></script>
15
+
16
+ <style>
17
+ #container {
18
+ position: relative;
19
+ margin:auto;
20
+ display: block;
21
+ }
22
+ #container > canvas {
23
+ position: absolute;
24
+ top: 0;
25
+ left: 0;
26
+ }
27
+ .control {
28
+ display: none;
29
+ }
30
+ </style>
31
+
32
+ </head>
33
+ <body>
34
+ <div>
35
+ <button type="button" class="control" id="export">Export</button>
36
+ <button type="button" class="control" id="outpaint">Outpaint</button>
37
+ <button type="button" class="control" id="undo">Undo</button>
38
+ <button type="button" class="control" id="commit">Commit</button>
39
+ <button type="button" class="control" id="transfer">Transfer</button>
40
+ <button type="button" class="control" id="upload">Upload</button>
41
+ <button type="button" class="control" id="draw">Draw</button>
42
+ <input type="text" id="mode" value="selection" class="control">
43
+ <input type="text" id="setup" value="0" class="control">
44
+ <input type="text" id="upload_content" value="0" class="control">
45
+ <textarea rows="1" id="selbuffer" name="selbuffer" class="control"></textarea>
46
+ <fieldset class="control">
47
+ <div>
48
+ <input type="radio" id="mode0" name="mode" value="0" checked>
49
+ <label for="mode0">SelBox</label>
50
+ </div>
51
+ <div>
52
+ <input type="radio" id="mode1" name="mode" value="1">
53
+ <label for="mode1">Image</label>
54
+ </div>
55
+ <div>
56
+ <input type="radio" id="mode2" name="mode" value="2">
57
+ <label for="mode2">Brush</label>
58
+ </div>
59
+ </fieldset>
60
+ </div>
61
+ <div id = "outer_container">
62
+ <div id = "container">
63
+ <canvas id = "canvas0"></canvas>
64
+ <canvas id = "canvas1"></canvas>
65
+ <canvas id = "canvas2"></canvas>
66
+ <canvas id = "canvas3"></canvas>
67
+ <canvas id = "canvas4"></canvas>
68
+ <div id="overlay_container" style="pointer-events: none">
69
+ <canvas id = "overlay_canvas" width="1" height="1"></canvas>
70
+ </div>
71
+ </div>
72
+ <input type="file" name="file" id="upload_file" accept="image/*" hidden>
73
+ <input type="file" name="state" id="upload_state" accept=".sdinf" hidden>
74
+ <div style="position: relative;">
75
+ <div id="toolbar" style></div>
76
+ </div>
77
+ </div>
78
+ <py-env>
79
+ - numpy
80
+ - Pillow
81
+ - paths:
82
+ - ./canvas.py
83
+ </py-env>
84
+
85
+ <py-script>
86
+ from pyodide import to_js, create_proxy
87
+ from PIL import Image
88
+ import io
89
+ import time
90
+ import base64
91
+ from collections import deque
92
+ import numpy as np
93
+ from js import (
94
+ console,
95
+ document,
96
+ parent,
97
+ devicePixelRatio,
98
+ ImageData,
99
+ Uint8ClampedArray,
100
+ CanvasRenderingContext2D as Context2d,
101
+ requestAnimationFrame,
102
+ window,
103
+ encodeURIComponent,
104
+ w2ui,
105
+ update_eraser,
106
+ update_scale,
107
+ adjust_selection,
108
+ update_count,
109
+ enable_result_lst,
110
+ setup_shortcut,
111
+ update_undo_redo,
112
+ )
113
+
114
+
115
+ from canvas import InfCanvas
116
+
117
+
118
+ class History:
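+ # bounded undo/redo stacks of serialized canvas states (the JSON produced by
+ # InfCanvas.save and consumed by InfCanvas.load)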
119
+ def __init__(self,maxlen=10):
120
+ self.idx=-1
121
+ self.undo_lst=deque([],maxlen=maxlen)
122
+ self.redo_lst=deque([],maxlen=maxlen)
123
+ self.state=None
124
+
125
+ def undo(self):
126
+ cur=None
127
+ if len(self.undo_lst):
128
+ cur=self.undo_lst.pop()
129
+ self.redo_lst.appendleft(cur)
130
+ return cur
131
+ def redo(self):
132
+ cur=None
133
+ if len(self.redo_lst):
134
+ cur=self.redo_lst.popleft()
135
+ self.undo_lst.append(cur)
136
+ return cur
137
+
138
+ def check(self):
139
+ return len(self.undo_lst)>0,len(self.redo_lst)>0
140
+
141
+ def append(self,state,update=True):
142
+ self.redo_lst.clear()
143
+ self.undo_lst.append(state)
144
+ if update:
145
+ update_undo_redo(*self.check())
146
+
147
+ history = History()
148
+
149
+ base_lst = [None]
150
+ async def draw_canvas() -> None:
151
+ width=1024
152
+ height=600
153
+ canvas=InfCanvas(1024,600)
154
+ update_eraser(canvas.eraser_size,min(canvas.selection_size_h,canvas.selection_size_w))
155
+ document.querySelector("#container").style.height= f"{height}px"
156
+ document.querySelector("#container").style.width = f"{width}px"
157
+ canvas.setup_mouse()
158
+ canvas.clear_background()
159
+ canvas.draw_buffer()
160
+ canvas.draw_selection_box()
161
+ base_lst[0]=canvas
162
+
163
+ async def draw_canvas_func(event):
164
+ try:
165
+ app=parent.document.querySelector("gradio-app")
166
+ if app.shadowRoot:
167
+ app=app.shadowRoot
168
+ width=app.querySelector("#canvas_width input").value
169
+ height=app.querySelector("#canvas_height input").value
170
+ selection_size=app.querySelector("#selection_size input").value
171
+ except:
172
+ width=1024
173
+ height=768
174
+ selection_size=384
175
+ document.querySelector("#container").style.width = f"{width}px"
176
+ document.querySelector("#container").style.height= f"{height}px"
177
+ canvas=InfCanvas(int(width),int(height),selection_size=int(selection_size))
178
+ canvas.setup_mouse()
179
+ canvas.clear_background()
180
+ canvas.draw_buffer()
181
+ canvas.draw_selection_box()
182
+ base_lst[0]=canvas
183
+
184
+ async def export_func(event):
185
+ base=base_lst[0]
186
+ arr=base.export()
187
+ base.draw_buffer()
188
+ base.canvas[2].clear()
189
+ base64_str = base.numpy_to_base64(arr)
190
+ time_str = time.strftime("%Y%m%d_%H%M%S")
191
+ link = document.createElement("a")
192
+ if len(event.data)>2 and event.data[2]:
193
+ filename = event.data[2]
194
+ else:
195
+ filename = f"outpaint_{time_str}"
196
+ # link.download = f"sdinf_state_{time_str}.json"
197
+ link.download = f"{filename}.png"
198
+ # link.download = f"outpaint_{time_str}.png"
199
+ link.href = "data:image/png;base64,"+base64_str
200
+ link.click()
201
+ console.log(f"Canvas saved to {filename}.png")
202
+
203
+ img_candidate_lst=[None,0]
204
+
205
+ async def outpaint_func(event):
206
+ base=base_lst[0]
207
+ if len(event.data)==2:
208
+ app=parent.document.querySelector("gradio-app")
209
+ if app.shadowRoot:
210
+ app=app.shadowRoot
211
+ base64_str_raw=app.querySelector("#output textarea").value
212
+ base64_str_lst=base64_str_raw.split(",")
213
+ img_candidate_lst[0]=base64_str_lst
214
+ img_candidate_lst[1]=0
215
+ elif event.data[2]=="next":
216
+ img_candidate_lst[1]+=1
217
+ elif event.data[2]=="prev":
218
+ img_candidate_lst[1]-=1
219
+ enable_result_lst()
220
+ if img_candidate_lst[0] is None:
221
+ return
222
+ lst=img_candidate_lst[0]
223
+ idx=img_candidate_lst[1]
224
+ update_count(idx%len(lst)+1,len(lst))
225
+ arr=base.base64_to_numpy(lst[idx%len(lst)])
226
+ base.fill_selection(arr)
227
+ base.draw_selection_box()
228
+
229
+ async def undo_func(event):
230
+ base=base_lst[0]
231
+ img_candidate_lst[0]=None
232
+ if base.sel_dirty:
233
+ base.sel_buffer = np.zeros((base.selection_size_h, base.selection_size_w, 4), dtype=np.uint8)
234
+ base.sel_dirty = False
235
+ base.canvas[2].clear()
236
+
237
+ async def commit_func(event):
238
+ base=base_lst[0]
239
+ img_candidate_lst[0]=None
240
+ if base.sel_dirty:
241
+ base.write_selection_to_buffer()
242
+ base.draw_buffer()
243
+ base.canvas[2].clear()
244
+ if len(event.data)>2:
245
+ history.append(base.save())
246
+
247
+ async def history_undo_func(event):
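+ # when the buffer has no pending change and nothing has been redone, the top of
+ # the undo stack is the current state, so undo is called twice to actually step back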
248
+ base=base_lst[0]
249
+ if base.buffer_dirty or len(history.redo_lst)>0:
250
+ state=history.undo()
251
+ else:
252
+ history.undo()
253
+ state=history.undo()
254
+ if state is not None:
255
+ base.load(state)
256
+ update_undo_redo(*history.check())
257
+
258
+ async def history_setup_func(event):
259
+ base=base_lst[0]
260
+ history.undo_lst.clear()
261
+ history.redo_lst.clear()
262
+ history.append(base.save(),update=False)
263
+
264
+ async def history_redo_func(event):
265
+ base=base_lst[0]
266
+ if len(history.undo_lst)>0:
267
+ state=history.redo()
268
+ else:
269
+ history.redo()
270
+ state=history.redo()
271
+ if state is not None:
272
+ base.load(state)
273
+ update_undo_redo(*history.check())
274
+
275
+
276
+ async def transfer_func(event):
277
+ base=base_lst[0]
278
+ base.read_selection_from_buffer()
279
+ sel_buffer=base.sel_buffer
280
+ sel_buffer_str=base.numpy_to_base64(sel_buffer)
281
+ app=parent.document.querySelector("gradio-app")
282
+ if app.shadowRoot:
283
+ app=app.shadowRoot
284
+ app.querySelector("#input textarea").value=sel_buffer_str
285
+ app.querySelector("#proceed").click()
286
+
287
+ async def upload_func(event):
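+ # decode the uploaded image from the hidden #upload_content textarea, clear the
+ # buffer wherever the image's alpha is non-zero, then add the image on top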
288
+ base=base_lst[0]
289
+ # base64_str=event.data[1]
290
+ base64_str=document.querySelector("#upload_content").value
291
+ base64_str=base64_str.split(",")[-1]
292
+ # base64_str=parent.document.querySelector("gradio-app").shadowRoot.querySelector("#upload textarea").value
293
+ arr=base.base64_to_numpy(base64_str)
294
+ h,w,c=base.buffer.shape
295
+ base.sync_to_buffer()
296
+ base.buffer_dirty=True
297
+ mask=arr[:,:,3:4].repeat(4,axis=2)
298
+ base.buffer[mask>0]=0
299
+ # in case mismatch
300
+ base.buffer[0:h,0:w,:]+=arr
301
+ #base.buffer[yo:yo+h,xo:xo+w,0:3]=arr[:,:,0:3]
302
+ #base.buffer[yo:yo+h,xo:xo+w,-1]=arr[:,:,-1]
303
+ base.draw_buffer()
304
+ if len(event.data)>2:
305
+ history.append(base.save())
306
+
307
+ async def setup_shortcut_func(event):
308
+ setup_shortcut(event.data[1])
309
+
310
+
311
+ document.querySelector("#export").addEventListener("click",create_proxy(export_func))
312
+ document.querySelector("#undo").addEventListener("click",create_proxy(undo_func))
313
+ document.querySelector("#commit").addEventListener("click",create_proxy(commit_func))
314
+ document.querySelector("#outpaint").addEventListener("click",create_proxy(outpaint_func))
315
+ document.querySelector("#upload").addEventListener("click",create_proxy(upload_func))
316
+
317
+ document.querySelector("#transfer").addEventListener("click",create_proxy(transfer_func))
318
+ document.querySelector("#draw").addEventListener("click",create_proxy(draw_canvas_func))
319
+
320
+ async def setup_func():
321
+ document.querySelector("#setup").value="1"
322
+
323
+ async def reset_func(event):
324
+ base=base_lst[0]
325
+ base.reset()
326
+
327
+ async def load_func(event):
328
+ base=base_lst[0]
329
+ base.load(event.data[1])
330
+
331
+ async def save_func(event):
332
+ base=base_lst[0]
333
+ json_str=base.save()
334
+ time_str = time.strftime("%Y%m%d_%H%M%S")
335
+ link = document.createElement("a")
336
+ if len(event.data)>2 and event.data[2]:
337
+ filename = str(event.data[2]).strip()
338
+ else:
339
+ filename = f"outpaint_{time_str}"
340
+ # link.download = f"sdinf_state_{time_str}.json"
341
+ link.download = f"{filename}.sdinf"
342
+ link.href = "data:text/json;charset=utf-8,"+encodeURIComponent(json_str)
343
+ link.click()
344
+
345
+ async def prev_result_func(event):
346
+ base=base_lst[0]
347
+ base.reset()
348
+
349
+ async def next_result_func(event):
350
+ base=base_lst[0]
351
+ base.reset()
352
+
353
+ async def zoom_in_func(event):
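+ # base.scale is an inverse zoom factor (the toolbar shows 100/scale %), so
+ # zooming in decreases it and zooming out increases it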
354
+ base=base_lst[0]
355
+ scale=base.scale
356
+ if scale>=0.2:
357
+ scale-=0.1
358
+ if len(event.data)>2:
359
+ base.update_scale(scale,int(event.data[2]),int(event.data[3]))
360
+ else:
361
+ base.update_scale(scale)
362
+ scale=base.scale
363
+ update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
364
+
365
+ async def zoom_out_func(event):
366
+ base=base_lst[0]
367
+ scale=base.scale
368
+ if scale<10:
369
+ scale+=0.1
370
+ console.log(len(event.data))
371
+ if len(event.data)>2:
372
+ base.update_scale(scale,int(event.data[2]),int(event.data[3]))
373
+ else:
374
+ base.update_scale(scale)
375
+ scale=base.scale
376
+ update_scale(f"{base.width}x{base.height} ({round(100/scale)}%)")
377
+
378
+ async def sync_func(event):
379
+ base=base_lst[0]
380
+ base.sync_to_buffer()
381
+ base.canvas[2].clear()
382
+
383
+ async def eraser_size_func(event):
384
+ base=base_lst[0]
385
+ eraser_size=min(int(event.data[1]),min(base.selection_size_h,base.selection_size_w))
386
+ eraser_size=max(8,eraser_size)
387
+ base.eraser_size=eraser_size
388
+
389
+ async def resize_selection_func(event):
390
+ base=base_lst[0]
391
+ cursor=base.cursor
392
+ if len(event.data)>3:
393
+ console.log(event.data)
394
+ base.cursor[0]=int(event.data[1])
395
+ base.cursor[1]=int(event.data[2])
396
+ base.selection_size_w=int(event.data[3])//8*8
397
+ base.selection_size_h=int(event.data[4])//8*8
398
+ base.refine_selection()
399
+ base.draw_selection_box()
400
+ elif len(event.data)>2:
401
+ base.draw_selection_box()
402
+ else:
403
+ base.canvas[-1].clear()
404
+ adjust_selection(cursor[0],cursor[1],base.selection_size_w,base.selection_size_h)
405
+
406
+ async def eraser_func(event):
407
+ base=base_lst[0]
408
+ if event.data[1]!="eraser":
409
+ base.canvas[-2].clear()
410
+ else:
411
+ x,y=base.mouse_pos
412
+ base.draw_eraser(x,y)
413
+
414
+ async def resize_func(event):
415
+ base=base_lst[0]
416
+ width=int(event.data[1])
417
+ height=int(event.data[2])
418
+ if width>=256 and height>=256:
419
+ if max(base.selection_size_h,base.selection_size_w)>min(width,height):
420
+ base.selection_size_h=256
421
+ base.selection_size_w=256
422
+ base.resize(width,height)
423
+
424
+ async def message_func(event):
425
+ if event.data[0]=="click":
426
+ if event.data[1]=="clear":
427
+ await reset_func(event)
428
+ elif event.data[1]=="save":
429
+ await save_func(event)
430
+ elif event.data[1]=="export":
431
+ await export_func(event)
432
+ elif event.data[1]=="accept":
433
+ await commit_func(event)
434
+ elif event.data[1]=="cancel":
435
+ await undo_func(event)
436
+ elif event.data[1]=="zoom_in":
437
+ await zoom_in_func(event)
438
+ elif event.data[1]=="zoom_out":
439
+ await zoom_out_func(event)
440
+ elif event.data[1]=="redo":
441
+ await history_redo_func(event)
442
+ elif event.data[1]=="undo":
443
+ await history_undo_func(event)
444
+ elif event.data[1]=="history":
445
+ await history_setup_func(event)
446
+ elif event.data[0]=="sync":
447
+ await sync_func(event)
448
+ elif event.data[0]=="load":
449
+ await load_func(event)
450
+ elif event.data[0]=="upload":
451
+ await upload_func(event)
452
+ elif event.data[0]=="outpaint":
453
+ await outpaint_func(event)
454
+ elif event.data[0]=="mode":
455
+ if event.data[1]!="selection":
456
+ await sync_func(event)
457
+ await eraser_func(event)
458
+ document.querySelector("#mode").value=event.data[1]
459
+ elif event.data[0]=="transfer":
460
+ await transfer_func(event)
461
+ elif event.data[0]=="setup":
462
+ await draw_canvas_func(event)
463
+ elif event.data[0]=="eraser_size":
464
+ await eraser_size_func(event)
465
+ elif event.data[0]=="resize_selection":
466
+ await resize_selection_func(event)
467
+ elif event.data[0]=="shortcut":
468
+ await setup_shortcut_func(event)
469
+ elif event.data[0]=="resize":
470
+ await resize_func(event)
471
+
472
+ window.addEventListener("message",create_proxy(message_func))
473
+
474
+ import asyncio
475
+
476
+ _ = await asyncio.gather(
477
+ setup_func(),draw_canvas()
478
+ )
479
+ </py-script>
480
+
481
+ </body>
482
+ </html>
interrogate.py ADDED
@@ -0,0 +1,125 @@
1
+ """
2
+ MIT License
3
+
4
+ Copyright (c) 2022 pharmapsychotic
5
+ https://github.com/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb
6
+ """
7
+
8
+ import numpy as np
9
+ import os
10
+ import torch
11
+ import torchvision.transforms as T
12
+ import torchvision.transforms.functional as TF
13
+
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torchvision import transforms
17
+ from torchvision.transforms.functional import InterpolationMode
18
+ from transformers import CLIPTokenizer, CLIPModel
19
+ from transformers import CLIPProcessor, CLIPModel
20
+
21
+ data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "blip_model", "data")
22
+ def load_list(filename):
23
+ with open(filename, 'r', encoding='utf-8', errors='replace') as f:
24
+ items = [line.strip() for line in f.readlines()]
25
+ return items
26
+
27
+ artists = load_list(os.path.join(data_path, 'artists.txt'))
28
+ flavors = load_list(os.path.join(data_path, 'flavors.txt'))
29
+ mediums = load_list(os.path.join(data_path, 'mediums.txt'))
30
+ movements = load_list(os.path.join(data_path, 'movements.txt'))
31
+
32
+ sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
33
+ trending_list = [site for site in sites]
34
+ trending_list.extend(["trending on "+site for site in sites])
35
+ trending_list.extend(["featured on "+site for site in sites])
36
+ trending_list.extend([site+" contest winner" for site in sites])
37
+
38
+ device="cpu"
39
+ blip_image_eval_size = 384
40
+ clip_name="openai/clip-vit-large-patch14"
41
+
42
+ blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
43
+
44
+ def generate_caption(blip_model, pil_image, device="cpu"):
45
+ gpu_image = transforms.Compose([
46
+ transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
47
+ transforms.ToTensor(),
48
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
49
+ ])(pil_image).unsqueeze(0).to(device)
50
+
51
+ with torch.no_grad():
52
+ caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
53
+ return caption[0]
54
+
55
+ def rank(text_features, image_features, text_array, top_count=1):
56
+ top_count = min(top_count, len(text_array))
57
+ similarity = torch.zeros((1, len(text_array)))
58
+ for i in range(image_features.shape[0]):
59
+ similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
60
+ similarity /= image_features.shape[0]
61
+
62
+ top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
63
+ return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]
64
+
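+ # The "{i}.pth" files loaded in Interrogator below are assumed to hold
+ # normalized CLIP text embeddings for the lists above (mediums, artists,
+ # trending_list, movements, flavors). A minimal sketch of how such a file
+ # could be regenerated (illustrative only, not part of the original pipeline):
+ #
+ # tokens = CLIPTokenizer.from_pretrained(clip_name)(mediums, padding=True, return_tensors="pt")
+ # with torch.no_grad():
+ # text_features = CLIPModel.from_pretrained(clip_name).get_text_features(**tokens)
+ # text_features /= text_features.norm(dim=-1, keepdim=True)
+ # torch.save(text_features, os.path.join(data_path, "0.pth"))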
65
+ class Interrogator:
66
+ def __init__(self) -> None:
67
+ self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
68
+ try:
69
+ self.get_blip()
70
+ except Exception:
71
+ self.blip_model = None
72
+ self.model = CLIPModel.from_pretrained(clip_name)
73
+ self.processor = CLIPProcessor.from_pretrained(clip_name)
74
+ self.text_feature_lst = [torch.load(os.path.join(data_path, f"{i}.pth")) for i in range(5)]
75
+
76
+ def get_blip(self):
77
+ from blip_model.blip import blip_decoder
78
+ blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base')
79
+ blip_model.eval()
80
+ self.blip_model = blip_model
81
+
82
+
83
+ def interrogate(self,image,use_caption=False):
84
+ if self.blip_model:
85
+ caption = generate_caption(self.blip_model, image)
86
+ else:
87
+ caption = ""
88
+ model,processor=self.model,self.processor
89
+ bests = [[('',0)]]*5
90
+ if True:
91
+ print(f"Interrogating with {clip_name}...")
92
+
93
+ inputs = processor(images=image, return_tensors="pt")
94
+ with torch.no_grad():
95
+ image_features = model.get_image_features(**inputs)
96
+ image_features /= image_features.norm(dim=-1, keepdim=True)
97
+ ranks = [
98
+ rank(self.text_feature_lst[0], image_features, mediums),
99
+ rank(self.text_feature_lst[1], image_features, ["by "+artist for artist in artists]),
100
+ rank(self.text_feature_lst[2], image_features, trending_list),
101
+ rank(self.text_feature_lst[3], image_features, movements),
102
+ rank(self.text_feature_lst[4], image_features, flavors, top_count=3)
103
+ ]
104
+
105
+ for i in range(len(ranks)):
106
+ confidence_sum = 0
107
+ for ci in range(len(ranks[i])):
108
+ confidence_sum += ranks[i][ci][1]
109
+ if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
110
+ bests[i] = ranks[i]
111
+
112
+ flaves = ', '.join([f"{x[0]}" for x in bests[4]])
113
+ medium = bests[0][0][0]
114
+ print(ranks)
115
+ if caption.startswith(medium):
116
+ return f"{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
117
+ else:
118
+ return f"{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
119
+
120
+
121
+
122
+
123
+
124
+
125
+
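+ # Usage sketch (illustrative; assumes the blip_model package and the
+ # precomputed blip_model/data files are available next to this module):
+ #
+ # from PIL import Image
+ # prompt = Interrogator().interrogate(Image.open("input.png"))
+ # print(prompt)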
js/fabric.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/keyboard.js ADDED
@@ -0,0 +1,37 @@
1
+
2
+ window.my_setup_keyboard=setInterval(function(){
3
+ let app=document.querySelector("gradio-app");
4
+ app=app.shadowRoot??app;
5
+ let frame=app.querySelector("#sdinfframe").contentWindow;
6
+ console.log("Check iframe...");
7
+ if(frame.setup_shortcut)
8
+ {
9
+ frame.setup_shortcut(json);
10
+ clearInterval(window.my_setup_keyboard);
11
+ }
12
+ }, 1000);
13
+ var config=JSON.parse(json);
14
+ var key_map={};
15
+ Object.keys(config.shortcut).forEach(k=>{
16
+ key_map[config.shortcut[k]]=k;
17
+ });
18
+ document.addEventListener("keydown", e => {
19
+ if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA")
20
+ {
21
+ let key=e.key;
22
+ if(e.ctrlKey)
23
+ {
24
+ key="Ctrl+"+e.key;
25
+ if(key in key_map)
26
+ {
27
+ e.preventDefault();
28
+ }
29
+ }
30
+ let app=document.querySelector("gradio-app");
31
+ app=app.shadowRoot??app;
32
+ let frame=app.querySelector("#sdinfframe").contentDocument;
33
+ frame.dispatchEvent(
34
+ new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey})
35
+ );
36
+ }
37
+ })
js/mode.js ADDED
@@ -0,0 +1,6 @@
1
+ function(mode){
2
+ let app=document.querySelector("gradio-app").shadowRoot;
3
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
4
+ frame.querySelector("#mode").value=mode;
5
+ return mode;
6
+ }
js/outpaint.js ADDED
@@ -0,0 +1,23 @@
1
+ function(a){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ return a;
23
+ }
js/proceed.js ADDED
@@ -0,0 +1,65 @@
1
+ function(sel_buffer_str,
2
+ prompt_text,
3
+ negative_prompt_text,
4
+ strength,
5
+ guidance,
6
+ step,
7
+ resize_check,
8
+ fill_mode,
9
+ enable_safety,
10
+ use_correction,
11
+ enable_img2img,
12
+ use_seed,
13
+ seed_val,
14
+ generate_num,
15
+ scheduler,
16
+ scheduler_eta,
17
+ controlnet_union,
18
+ expand_mask,
19
+ color,
20
+ scheduler_type,
21
+ prompt_weight,
22
+ image_resolution,
23
+ img_height,
24
+ img_width,
25
+ loraA,
26
+ loraAscale,
27
+ interrogate_mode,
28
+ state){
29
+ let app=document.querySelector("gradio-app");
30
+ app=app.shadowRoot??app;
31
+ let sel_buffer=app.querySelector("#input textarea").value;
32
+ let use_correction_bak=false;
33
+ ({resize_check,enable_safety,enable_img2img,use_seed,seed_val,interrogate_mode}=window.config_obj);
34
+ seed_val=Number(seed_val);
35
+ return [
36
+ sel_buffer,
37
+ prompt_text,
38
+ negative_prompt_text,
39
+ strength,
40
+ guidance,
41
+ step,
42
+ resize_check,
43
+ fill_mode,
44
+ enable_safety,
45
+ use_correction,
46
+ enable_img2img,
47
+ use_seed,
48
+ seed_val,
49
+ generate_num,
50
+ scheduler,
51
+ scheduler_eta,
52
+ controlnet_union,
53
+ expand_mask,
54
+ color,
55
+ scheduler_type,
56
+ prompt_weight,
57
+ image_resolution,
58
+ img_height,
59
+ img_width,
60
+ loraA,
61
+ loraAscale,
62
+ interrogate_mode,
63
+ state,
64
+ ]
65
+ }
js/setup.js ADDED
@@ -0,0 +1,28 @@
1
+ function(token_val, width, height, size, model_choice, model_path){
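+ // Resize the result iframe, then poll the PyScript canvas inside it until its
+ // hidden #setup flag flips to "1" (PyScript finished loading) before
+ // triggering the initial #draw click.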
2
+ let app=document.querySelector("gradio-app");
3
+ app=app.shadowRoot??app;
4
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
5
+ // app.querySelector("#setup_row").style.display="none";
6
+ app.querySelector("#model_path_input").style.display="none";
7
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
8
+
9
+ if(frame.querySelector("#setup").value=="0")
10
+ {
11
+ window.my_setup=setInterval(function(){
12
+ let app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
15
+ console.log("Check PyScript...")
16
+ if(frame.querySelector("#setup").value=="1")
17
+ {
18
+ frame.querySelector("#draw").click();
19
+ clearInterval(window.my_setup);
20
+ }
21
+ }, 100)
22
+ }
23
+ else
24
+ {
25
+ frame.querySelector("#draw").click();
26
+ }
27
+ return [token_val, width, height, size, model_choice, model_path];
28
+ }
js/toolbar.js ADDED
@@ -0,0 +1,674 @@
1
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js"
2
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js"
3
+
4
+ // https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript
5
+ function getBase64(file) {
6
+ var reader = new FileReader();
7
+ reader.readAsDataURL(file);
8
+ reader.onload = function () {
9
+ add_image(reader.result);
10
+ // console.log(reader.result);
11
+ };
12
+ reader.onerror = function (error) {
13
+ console.log("Error: ", error);
14
+ };
15
+ }
16
+
17
+ function getText(file) {
18
+ var reader = new FileReader();
19
+ reader.readAsText(file);
20
+ reader.onload = function () {
21
+ window.postMessage(["load",reader.result],"*")
22
+ // console.log(reader.result);
23
+ };
24
+ reader.onerror = function (error) {
25
+ console.log("Error: ", error);
26
+ };
27
+ }
28
+
29
+ document.querySelector("#upload_file").addEventListener("change", (event)=>{
30
+ console.log(event);
31
+ let file = document.querySelector("#upload_file").files[0];
32
+ getBase64(file);
33
+ })
34
+
35
+ document.querySelector("#upload_state").addEventListener("change", (event)=>{
36
+ console.log(event);
37
+ let file = document.querySelector("#upload_state").files[0];
38
+ getText(file);
39
+ })
40
+
41
+ open_setting = function() {
42
+ if (!w2ui.foo) {
43
+ new w2form({
44
+ name: "foo",
45
+ style: "border: 0px; background-color: transparent;",
46
+ fields: [{
47
+ field: "canvas_width",
48
+ type: "int",
49
+ required: true,
50
+ html: {
51
+ label: "Canvas Width"
52
+ }
53
+ },
54
+ {
55
+ field: "canvas_height",
56
+ type: "int",
57
+ required: true,
58
+ html: {
59
+ label: "Canvas Height"
60
+ }
61
+ },
62
+ ],
63
+ record: {
64
+ canvas_width: 1200,
65
+ canvas_height: 600,
66
+ },
67
+ actions: {
68
+ Save() {
69
+ this.validate();
70
+ let record = this.getCleanRecord();
71
+ window.postMessage(["resize",record.canvas_width,record.canvas_height],"*");
72
+ w2popup.close();
73
+ },
74
+ custom: {
75
+ text: "Cancel",
76
+ style: "text-transform: uppercase",
77
+ onClick(event) {
78
+ w2popup.close();
79
+ }
80
+ }
81
+ }
82
+ });
83
+ }
84
+ w2popup.open({
85
+ title: "Form in a Popup",
86
+ body: "<div id='form' style='width: 100%; height: 100%;''></div>",
87
+ style: "padding: 15px 0px 0px 0px",
88
+ width: 500,
89
+ height: 280,
90
+ showMax: true,
91
+ async onToggle(event) {
92
+ await event.complete
93
+ w2ui.foo.resize();
94
+ }
95
+ })
96
+ .then((event) => {
97
+ w2ui.foo.render("#form")
98
+ });
99
+ }
100
+
101
+ var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"];
102
+ var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting", "interrogate"];
103
+ var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting", "interrogate"];
104
+ var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting", "interrogate", "undo", "redo"];
105
+ var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"];
106
+ var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"];
107
+
108
+ function check_button(id,text="",checked=true,tooltip="")
109
+ {
110
+ return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip };
111
+ }
112
+
113
+ var toolbar=new w2toolbar({
114
+ box: "#toolbar",
115
+ name: "toolbar",
116
+ tooltip: "top",
117
+ items: [
118
+ { type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" },
119
+ { type: "break" },
120
+ { type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" },
121
+ { type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" },
122
+ { type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" },
123
+ { type: "break" },
124
+ { type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" },
125
+ { type: "break" },
126
+ { type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true },
127
+ { type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" },
128
+ { type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" },
129
+ { type: "break" },
130
+ { type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" },
131
+ { type: "button", id: "interrogate", text: "Interrogate", tooltip: "Get a prompt with Clip Interrogator ", icon: "fa-solid fa-magnifying-glass" },
132
+ { type: "break" },
133
+ { type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disabled:true,},
134
+ { type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true},
135
+ { type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disabled:true,},
136
+ { type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disabled:true,},
137
+ { type: "html", id: "current", hidden: true, disabled:true,
138
+ async onRefresh(event) {
139
+ await event.complete
140
+ let fragment = query.html(`
141
+ <div class="w2ui-tb-text">
142
+ <div class="w2ui-tb-count">
143
+ <span>${this.sel_value ?? "1/1"}</span>
144
+ </div> </div>`)
145
+ query(this.box).find("#tb_toolbar_item_current").append(fragment)
146
+ }
147
+ },
148
+ { type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disabled:true,},
149
+ { type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disabled:true,},
150
+ { type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disabled:true,},
151
+ { type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disabled:true,},
152
+ { type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disabled:true,},
153
+ { type: "break" },
154
+ { type: "spacer" },
155
+ { type: "break" },
156
+ { type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32},
157
+ { type: "html", id: "eraser_size", hidden: true,
158
+ async onRefresh(event) {
159
+ await event.complete
160
+ // let fragment = query.html(`
161
+ // <input type="number" size="${this.eraser_size ? this.eraser_size.length:"2"}" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
162
+ // <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">`)
163
+ let fragment = query.html(`
164
+ <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
165
+ `)
166
+ fragment.filter("input").on("change", event => {
167
+ this.eraser_size = event.target.value;
168
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
169
+ this.setCount("eraser_size_btn", event.target.value);
170
+ window.postMessage(["eraser_size", event.target.value],"*")
171
+ this.refresh();
172
+ })
173
+ query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment)
174
+ }
175
+ },
176
+ // { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" },
177
+ { type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" },
178
+ { type: "break" },
179
+ { type: "html", id: "scale",
180
+ async onRefresh(event) {
181
+ await event.complete
182
+ let fragment = query.html(`
183
+ <div class="">
184
+ <div style="padding: 4px; border: 1px solid silver">
185
+ <span>${this.scale_value ?? "100%"}</span>
186
+ </div></div>`)
187
+ query(this.box).find("#tb_toolbar_item_scale").append(fragment)
188
+ }
189
+ },
190
+ { type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" },
191
+ { type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" },
192
+ { type: "break" },
193
+ { type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" },
194
+ { type: "new-line"},
195
+ { type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" },
196
+ { type: "break" },
197
+ check_button("enable_history","Enable History:",false, "Enable Canvas History"),
198
+ { type: "button", id: "undo", tooltip: "Undo last erasing/last outpainting", icon: "fa-solid fa-rotate-left", disabled: true },
199
+ { type: "button", id: "redo", tooltip: "Redo", icon: "fa-solid fa-rotate-right", disabled: true },
200
+ { type: "break" },
201
+ check_button("enable_img2img","Enable Img2Img",false),
202
+ // check_button("use_correction","Photometric Correction",false),
203
+ check_button("resize_check","Resize Small Input",true),
204
+ check_button("enable_safety","Enable Safety Checker",true),
205
+ check_button("square_selection","Square Selection Only",false),
206
+ {type: "break"},
207
+ check_button("use_seed","Use Seed:",false),
208
+ { type: "html", id: "seed_val",
209
+ async onRefresh(event) {
210
+ await event.complete
211
+ let fragment = query.html(`
212
+ <input type="number" style="margin: 0px 3px; padding: 4px; width:100px;" value="${this.config_obj.seed_val ?? "0"}">`)
213
+ fragment.filter("input").on("change", event => {
214
+ this.config_obj.seed_val = event.target.value;
215
+ parent.config_obj=this.config_obj;
216
+ this.refresh();
217
+ })
218
+ query(this.box).find("#tb_toolbar_item_seed_val").append(fragment)
219
+ }
220
+ },
221
+ { type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" },
222
+ ],
223
+ onClick(event) {
224
+ switch(event.target){
225
+ case "setting":
226
+ open_setting();
227
+ break;
228
+ case "upload":
229
+ this.upload_mode=true
230
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
231
+ this.click("canvas");
232
+ this.click("selection");
233
+ this.show("confirm","cancel_overlay","add_image","delete_image");
234
+ this.enable("confirm","cancel_overlay","add_image","delete_image");
235
+ this.disable(...upload_button_lst);
236
+ this.disable("undo","redo")
237
+ query("#upload_file").click();
238
+ if(this.upload_tip)
239
+ {
240
+ this.upload_tip=false;
241
+ w2utils.notify("Note that only visible images will be added to the canvas",{timeout:10000,where:query("#container")})
242
+ }
243
+ break;
244
+ case "resize_selection":
245
+ this.resize_mode=true;
246
+ this.disable(...resize_button_lst);
247
+ this.enable("confirm","cancel_overlay");
248
+ this.show("confirm","cancel_overlay");
249
+ window.postMessage(["resize_selection",""],"*");
250
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
251
+ break;
252
+ case "confirm":
253
+ if(this.upload_mode)
254
+ {
255
+ export_image();
256
+ }
257
+ else
258
+ {
259
+ let sel_box=this.selection_box;
260
+ if(sel_box.width*sel_box.height>512*512)
261
+ {
262
+ w2utils.notify("Note that outpainting will be much slower when the selection area is larger than 512x512",{timeout:2000,where:query("#container")})
263
+ }
264
+ window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*");
265
+ }
266
+ case "cancel_overlay":
267
+ end_overlay();
268
+ this.hide("confirm","cancel_overlay","add_image","delete_image");
269
+ if(this.upload_mode){
270
+ this.enable(...upload_button_lst);
271
+ }
272
+ else
273
+ {
274
+ this.enable(...resize_button_lst);
275
+ window.postMessage(["resize_selection","",""],"*");
276
+ if(event.target=="cancel_overlay")
277
+ {
278
+ this.selection_box=this.selection_box_bak;
279
+ }
280
+ }
281
+ if(this.selection_box)
282
+ {
283
+ this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`);
284
+ }
285
+ this.disable("confirm","cancel_overlay","add_image","delete_image");
286
+ this.upload_mode=false;
287
+ this.resize_mode=false;
288
+ this.click("selection");
289
+ window.update_undo_redo(window.undo_redo_state.undo, window.undo_redo_state.redo);
290
+ break;
291
+ case "add_image":
292
+ query("#upload_file").click();
293
+ break;
294
+ case "delete_image":
295
+ let active_obj = window.overlay.getActiveObject();
296
+ if(active_obj)
297
+ {
298
+ window.overlay.remove(active_obj);
299
+ window.overlay.renderAll();
300
+ }
301
+ else
302
+ {
303
+ w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")})
304
+ }
305
+ break;
306
+ case "load":
307
+ query("#upload_state").click();
308
+ this.selection_box=null;
309
+ this.setCount("resize_selection","");
310
+ break;
311
+ case "next":
312
+ case "prev":
313
+ window.postMessage(["outpaint", "", event.target], "*");
314
+ break;
315
+ case "outpaint":
316
+ this.click("selection");
317
+ this.disable(...outpaint_button_lst);
318
+ this.show(...outpaint_result_lst);
319
+ this.disable("undo","redo");
320
+ if(this.outpaint_tip)
321
+ {
322
+ this.outpaint_tip=false;
323
+ w2utils.notify("The canvas stays locked until you accept/cancel the current outpainting. You can modify the 'sample number' to get multiple results; you can resize the canvas/selection with 'canvas setting'/'resize selection'; you can use 'photometric correction' to help preserve existing content",{timeout:15000,where:query("#container")})
324
+ }
325
+ document.querySelector("#container").style.pointerEvents="none";
326
+ case "retry":
327
+ this.disable(...outpaint_result_func_lst);
328
+ parent.config_obj["interrogate_mode"]=false;
329
+ window.postMessage(["transfer",""],"*")
330
+ break;
331
+ case "interrogate":
332
+ if(this.interrogate_tip)
333
+ {
334
+ this.interrogate_tip=false;
335
+ w2utils.notify("ClipInterrogator v1 will be dynamically loaded the first time it runs, which may take a while",{timeout:10000,where:query("#container")})
336
+ }
337
+ parent.config_obj["interrogate_mode"]=true;
338
+ window.postMessage(["transfer",""],"*")
339
+ break
340
+ case "accept":
341
+ case "cancel":
342
+ this.hide(...outpaint_result_lst);
343
+ this.disable(...outpaint_result_func_lst);
344
+ this.enable(...outpaint_button_lst);
345
+ document.querySelector("#container").style.pointerEvents = "auto";
346
+
347
+ if (this.config_obj.enable_history) {
348
+ window.postMessage(["click", event.target, ""], "*");
349
+ } else {
350
+ window.postMessage(["click", event.target], "*");
351
+ }
352
+
353
+ // Instead of directly accessing shadowRoot, send a message to iframe
354
+ let frame = parent.document.querySelector("#sdinfframe");
355
+ if (frame && frame.contentWindow) {
356
+ frame.contentWindow.postMessage(["click", "cancel"], "*");
357
+ }
358
+
359
+ window.update_undo_redo(
360
+ window.undo_redo_state.undo,
361
+ window.undo_redo_state.redo
362
+ );
363
+ break;
364
+ case "eraser":
365
+ case "selection":
366
+ case "canvas":
367
+ if(event.target=="eraser")
368
+ {
369
+ this.show("eraser_size","eraser_size_btn");
370
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
371
+ window.overlay.isDrawingMode = true;
372
+ }
373
+ else
374
+ {
375
+ this.hide("eraser_size","eraser_size_btn");
376
+ window.overlay.isDrawingMode = false;
377
+ }
378
+ if(this.upload_mode)
379
+ {
380
+ if(event.target=="canvas")
381
+ {
382
+ window.postMessage(["mode", event.target],"*")
383
+ document.querySelector("#overlay_container").style.pointerEvents="none";
384
+ document.querySelector("#overlay_container").style.opacity = 0.5;
385
+ }
386
+ else
387
+ {
388
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
389
+ document.querySelector("#overlay_container").style.opacity = 1.0;
390
+ }
391
+ }
392
+ else
393
+ {
394
+ window.postMessage(["mode", event.target],"*")
395
+ }
396
+ break;
397
+ case "help":
398
+ w2popup.open({
399
+ title: "Document",
400
+ body: "Usage: <a href='https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md' target='_blank'>https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md</a>"
401
+ })
402
+ break;
403
+ case "clear":
404
+ w2confirm("Reset canvas?").yes(() => {
405
+ window.postMessage(["click", event.target],"*");
406
+ }).no(() => {})
407
+ break;
408
+ case "random_seed":
409
+ this.config_obj.seed_val=Math.floor(Math.random() * 3000000000);
410
+ parent.config_obj=this.config_obj;
411
+ this.refresh();
412
+ break;
413
+ case "enable_history":
414
+ case "enable_img2img":
415
+ case "use_correction":
416
+ case "resize_check":
417
+ case "enable_safety":
418
+ case "use_seed":
419
+ case "square_selection":
420
+ let target=this.get(event.target);
421
+ if(event.target=="enable_history")
422
+ {
423
+ if(!target.checked)
424
+ {
425
+ w2utils.notify("Enabling canvas history might increase resource usage / slow down the canvas", {error:true,timeout:3000,where:query("#container")})
426
+ window.postMessage(["click","history"],"*");
427
+ }
428
+ else
429
+ {
430
+ window.undo_redo_state.undo=false;
431
+ window.undo_redo_state.redo=false;
432
+ this.disable("undo","redo");
433
+ }
434
+ }
435
+ target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check";
436
+ this.config_obj[event.target]=!target.checked;
437
+ parent.config_obj=this.config_obj;
438
+ this.refresh();
439
+ break;
440
+ case "save":
441
+ case "export":
442
+ ask_filename(event.target);
443
+ break;
444
+ default:
445
+ // clear, save, export, outpaint, retry
446
+ // break, save, export, accept, retry, outpaint
447
+ window.postMessage(["click", event.target],"*")
448
+ }
449
+ console.log("Target: "+ event.target, event)
450
+ }
451
+ })
452
+ window.w2ui=w2ui;
453
+ w2ui.toolbar.config_obj={
454
+ resize_check: true,
455
+ enable_safety: true,
456
+ use_correction: false,
457
+ enable_img2img: false,
458
+ use_seed: false,
459
+ seed_val: 0,
460
+ square_selection: false,
461
+ enable_history: false,
462
+ };
463
+ w2ui.toolbar.outpaint_tip=true;
464
+ w2ui.toolbar.upload_tip=true;
465
+ w2ui.toolbar.interrogate_tip=true;
466
+ window.update_count=function(cur,total){
467
+ w2ui.toolbar.sel_value=`${cur}/${total}`;
468
+ w2ui.toolbar.refresh();
469
+ }
470
+ window.update_eraser=function(val,max_val){
471
+ w2ui.toolbar.eraser_size=`${val}`;
472
+ w2ui.toolbar.eraser_max=`${max_val}`;
473
+ w2ui.toolbar.setCount("eraser_size_btn", `${val}`);
474
+ w2ui.toolbar.refresh();
475
+ }
476
+ window.update_scale=function(val){
477
+ w2ui.toolbar.scale_value=`${val}`;
478
+ w2ui.toolbar.refresh();
479
+ }
480
+ window.enable_result_lst=function(){
481
+ w2ui.toolbar.enable(...outpaint_result_lst);
482
+ }
483
+ function onObjectScaled(e)
484
+ {
485
+ let object = e.target;
486
+ if(object.isType("rect"))
487
+ {
488
+ let width=object.getScaledWidth();
489
+ let height=object.getScaledHeight();
490
+ object.scale(1);
491
+ width=Math.max(Math.min(width,window.overlay.width-object.left),256);
492
+ height=Math.max(Math.min(height,window.overlay.height-object.top),256);
493
+ let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0);
494
+ let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0);
495
+ if(window.w2ui.toolbar.config_obj.square_selection)
496
+ {
497
+ let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height);
498
+ width=max_val;
499
+ height=max_val;
500
+ }
501
+ object.set({ width: width, height: height, left:l,top:t})
502
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top};
503
+ window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`);
504
+ window.w2ui.toolbar.refresh();
505
+ }
506
+ }
507
+ function onObjectMoved(e)
508
+ {
509
+ let object = e.target;
510
+ if(object.isType("rect"))
511
+ {
512
+ let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0);
513
+ let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0);
514
+ object.set({left:l,top:t});
515
+ window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top};
516
+ }
517
+ }
518
+ window.setup_overlay = function (width, height) {
519
+ if (window.overlay) {
520
+ window.overlay.setDimensions({ width: width, height: height });
521
+
522
+ // Find iframe safely
523
+ let frameEl = parent.document.querySelector("#sdinfframe");
524
+ if (frameEl) {
525
+ frameEl.style.height = 80 + Number(height) + "px";
526
+ } else {
527
+ console.warn("#sdinfframe not found in parent document.");
528
+ }
529
+
530
+ document.querySelector("#container").style.height = height + "px";
531
+ document.querySelector("#container").style.width = width + "px";
532
+ } else {
533
+ canvas = new fabric.Canvas("overlay_canvas");
534
+ canvas.setDimensions({ width: width, height: height });
535
+
536
+ // Find iframe safely
537
+ let frameEl = parent.document.querySelector("#sdinfframe");
538
+ if (frameEl) {
539
+ frameEl.style.height = 80 + Number(height) + "px";
540
+ } else {
541
+ console.warn("#sdinfframe not found in parent document.");
542
+ }
543
+
544
+ canvas.freeDrawingBrush = new fabric.EraserBrush(canvas);
545
+ canvas.on("object:scaling", onObjectScaled);
546
+ canvas.on("object:moving", onObjectMoved);
547
+ window.overlay = canvas;
548
+ }
549
+
550
+ document.querySelector("#overlay_container").style.pointerEvents = "none";
551
+ };
552
+
553
+ window.update_overlay=function(width,height)
554
+ {
555
+ window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true});
556
+ // document.querySelector("#overlay_container").style.pointerEvents="none";
557
+ }
558
+ window.adjust_selection=function(x,y,width,height)
559
+ {
560
+ var rect = new fabric.Rect({
561
+ left: x,
562
+ top: y,
563
+ fill: "rgba(0,0,0,0)",
564
+ strokeWidth: 3,
565
+ stroke: "rgba(0,0,0,0.7)",
566
+ cornerColor: "red",
567
+ cornerStrokeColor: "red",
568
+ borderColor: "rgba(255, 0, 0, 1.0)",
569
+ width: width,
570
+ height: height,
571
+ lockRotation: true,
572
+ });
573
+ rect.setControlsVisibility({ mtr: false });
574
+ window.overlay.add(rect);
575
+ window.overlay.setActiveObject(window.overlay.item(0));
576
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y};
577
+ window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y};
578
+ }
579
+ function add_image(url)
580
+ {
581
+ fabric.Image.fromURL(url,function(img){
582
+ window.overlay.add(img);
583
+ window.overlay.setActiveObject(img);
584
+ },{left:100,top:100});
585
+ }
586
+ function export_image()
587
+ {
588
+ let data=window.overlay.toDataURL();
589
+ document.querySelector("#upload_content").value=data;
590
+ if(window.w2ui.toolbar.config_obj.enable_history)
591
+ {
592
+ window.postMessage(["upload","",""],"*");
593
+ window.w2ui.toolbar.enable("undo");
594
+ window.w2ui.toolbar.disable("redo");
595
+ }
596
+ else
597
+ {
598
+ window.postMessage(["upload",""],"*");
599
+ }
600
+ end_overlay();
601
+ }
602
+ function end_overlay()
603
+ {
604
+ window.overlay.clear();
605
+ document.querySelector("#overlay_container").style.opacity = 1.0;
606
+ document.querySelector("#overlay_container").style.pointerEvents="none";
607
+ }
608
+ function ask_filename(target)
609
+ {
610
+ w2prompt({
611
+ label: "Enter filename",
612
+ value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`,
613
+ })
614
+ .change((event) => {
615
+ console.log("change", event.detail.originalEvent.target.value);
616
+ })
617
+ .ok((event) => {
618
+ console.log("value=", event.detail.value);
619
+ window.postMessage(["click",target,event.detail.value],"*");
620
+ })
621
+ .cancel((event) => {
622
+ console.log("cancel");
623
+ });
624
+ }
625
+
626
+ document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()})
627
+ window.setup_shortcut=function(json)
628
+ {
629
+ var config=JSON.parse(json);
630
+ var key_map={};
631
+ Object.keys(config.shortcut).forEach(k=>{
632
+ key_map[config.shortcut[k]]=k;
633
+ })
634
+ document.addEventListener("keydown",(e)=>{
635
+ if(e.target.tagName!="INPUT")
636
+ {
637
+ let key=e.key;
638
+ if(e.ctrlKey)
639
+ {
640
+ key="Ctrl+"+e.key;
641
+ if(key in key_map)
642
+ {
643
+ e.preventDefault();
644
+ }
645
+ }
646
+ if(key in key_map)
647
+ {
648
+ w2ui.toolbar.click(key_map[key]);
649
+ }
650
+ }
651
+ })
652
+ }
653
+ window.undo_redo_state={undo:false,redo:false};
654
+ window.update_undo_redo=function(s0,s1)
655
+ {
656
+ if(s0)
657
+ {
658
+ w2ui.toolbar.enable("undo");
659
+ }
660
+ else
661
+ {
662
+ w2ui.toolbar.disable("undo");
663
+ }
664
+ if(s1)
665
+ {
666
+ w2ui.toolbar.enable("redo");
667
+ }
668
+ else
669
+ {
670
+ w2ui.toolbar.disable("redo");
671
+ }
672
+ window.undo_redo_state.undo=s0;
673
+ window.undo_redo_state.redo=s1;
674
+ }
js/upload.js ADDED
@@ -0,0 +1,19 @@
1
+ function(a,b){
2
+ if(!window.my_observe_upload)
3
+ {
4
+ console.log("setup upload here");
5
+ window.my_observe_upload = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document;
8
+ frame.querySelector("#upload").click();
9
+ });
10
+ window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span");
11
+ window.my_observe_upload.observe(window.my_observe_upload_target, {
12
+ attributes: false,
13
+ subtree: true,
14
+ childList: true,
15
+ characterData: true
16
+ });
17
+ }
18
+ return [a,b];
19
+ }
js/w2ui.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/xss.js ADDED
@@ -0,0 +1,32 @@
1
+ var setup_outpaint=function(){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ };
23
+ window.config_obj={
24
+ resize_check: true,
25
+ enable_safety: true,
26
+ use_correction: false,
27
+ enable_img2img: false,
28
+ use_seed: false,
29
+ seed_val: 0,
30
+ interrogate_mode: false,
31
+ };
32
+ setup_outpaint();
models/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
models/v1-inpainting-inference.yaml ADDED
@@ -0,0 +1,70 @@
1
+ model:
2
+ base_learning_rate: 7.5e-05
3
+ target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: hybrid # important
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ finetune_keys: null
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
opencv.pc ADDED
@@ -0,0 +1,11 @@
1
+ prefix=/usr
2
+ exec_prefix=${prefix}
3
+ includedir=${prefix}/include
4
+ libdir=${exec_prefix}/lib
5
+
6
+ Name: opencv
7
+ Description: The opencv library
8
+ Version: 2.x.x
9
+ Cflags: -I${includedir}/opencv4
10
+ #Cflags: -I${includedir}/opencv -I${includedir}/opencv2
11
+ Libs: -L${libdir} -lopencv_calib3d -lopencv_imgproc -lopencv_xobjdetect -lopencv_hdf -lopencv_flann -lopencv_core -lopencv_dpm -lopencv_videoio -lopencv_reg -lopencv_quality -lopencv_tracking -lopencv_dnn_superres -lopencv_objdetect -lopencv_stitching -lopencv_saliency -lopencv_intensity_transform -lopencv_rapid -lopencv_dnn -lopencv_features2d -lopencv_text -lopencv_calib3d -lopencv_line_descriptor -lopencv_superres -lopencv_ml -lopencv_alphamat -lopencv_viz -lopencv_optflow -lopencv_videostab -lopencv_bioinspired -lopencv_highgui -lopencv_img_hash -lopencv_freetype -lopencv_imgcodecs -lopencv_mcc -lopencv_video -lopencv_photo -lopencv_surface_matching -lopencv_rgbd -lopencv_datasets -lopencv_ximgproc -lopencv_plot -lopencv_face -lopencv_stereo -lopencv_aruco -lopencv_dnn_objdetect -lopencv_phase_unwrapping -lopencv_bgsegm -lopencv_ccalib -lopencv_hfs -lopencv_imgproc -lopencv_shape -lopencv_xphoto -lopencv_structured_light -lopencv_fuzzy
packages.txt ADDED
@@ -0,0 +1,5 @@
1
+ build-essential
2
+ python3-opencv
3
+ libopencv-dev
4
+ cmake
5
+ pkg-config
perlin2d.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+
3
+ ##########
4
+ # https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
5
+ def perlin(x, y, seed=0):
6
+ # permutation table
7
+ np.random.seed(seed)
8
+ p = np.arange(256, dtype=int)
9
+ np.random.shuffle(p)
10
+ p = np.stack([p, p]).flatten()
11
+ # coordinates of the top-left
12
+ xi, yi = x.astype(int), y.astype(int)
13
+ # internal coordinates
14
+ xf, yf = x - xi, y - yi
15
+ # fade factors
16
+ u, v = fade(xf), fade(yf)
17
+ # noise components
18
+ n00 = gradient(p[p[xi] + yi], xf, yf)
19
+ n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
20
+ n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
21
+ n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
22
+ # combine noises
23
+ x1 = lerp(n00, n10, u)
24
+ x2 = lerp(n01, n11, u) # FIX1: I was using n10 instead of n01
25
+ return lerp(x1, x2, v) # FIX2: I also had to reverse x1 and x2 here
26
+
27
+
28
+ def lerp(a, b, x):
29
+ "linear interpolation"
30
+ return a + x * (b - a)
31
+
32
+
33
+ def fade(t):
34
+ "6t^5 - 15t^4 + 10t^3"
35
+ return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
36
+
37
+
38
+ def gradient(h, x, y):
39
+ "grad converts h to the right gradient vector and return the dot product with (x,y)"
40
+ vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
41
+ g = vectors[h % 4]
42
+ return g[:, :, 0] * x + g[:, :, 1] * y
43
+
44
+
45
+ ##########
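+ # Usage sketch (illustrative values): perlin() expects 2-D coordinate grids,
+ # e.g.
+ # lin = np.linspace(0, 4, 256, endpoint=False)
+ # x, y = np.meshgrid(lin, lin)
+ # noise = perlin(x, y, seed=0) # -> (256, 256) noise array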
postprocess.py ADDED
@@ -0,0 +1,249 @@
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+
26
+ import time
27
+ import argparse
28
+ import os
29
+ import fpie
30
+ from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND
31
+ from fpie.io import read_images, write_image
32
+ from process import BaseProcessor, EquProcessor, GridProcessor
33
+
34
+ from PIL import Image
35
+ import numpy as np
36
+ import skimage
37
+ import skimage.measure
38
+ import scipy
39
+ import scipy.signal
40
+
41
+
42
+ class PhotometricCorrection:
43
+ def __init__(self,quite=False):
44
+ self.get_parser("cli")
45
+ args=self.parser.parse_args(["--method","grid","-g","src","-s","a","-t","a","-o","a"])
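+ # (parses a synthetic command line: grid solver with source-only gradients;
+ # the "a" values for -s/-t/-o are placeholders, since images are passed
+ # in memory via run() rather than read from disk)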
46
+ args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0)
47
+ self.backend=args.backend
48
+ self.args=args
49
+ self.quite=quite
50
+ proc: BaseProcessor
51
+ proc = GridProcessor(
52
+ args.gradient,
53
+ args.backend,
54
+ args.cpu,
55
+ args.mpi_sync_interval,
56
+ args.block_size,
57
+ args.grid_x,
58
+ args.grid_y,
59
+ )
60
+ print(
61
+ f"[PIE]Successfully initialize PIE {args.method} solver "
62
+ f"with {args.backend} backend"
63
+ )
64
+ self.proc=proc
65
+
66
+ def run(self, original_image, inpainted_image, mode="mask_mode"):
67
+ print("[PIE] start")
68
+ if mode=="disabled":
69
+ return inpainted_image
70
+ input_arr=np.array(original_image)
71
+ if input_arr[:,:,-1].sum()<1:
72
+ return inpainted_image
73
+ output_arr=np.array(inpainted_image)
74
+ mask=input_arr[:,:,-1]
75
+ mask=255-mask
76
+ if mask.sum()<1 and mode=="mask_mode":
77
+ mode=""
78
+ if mode=="mask_mode":
79
+ mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
80
+ mask = mask.repeat(8, axis=0).repeat(8, axis=1)
81
+ else:
82
+ mask[8:-9,8:-9]=255
83
+ mask = mask[:,:,np.newaxis].repeat(3,axis=2)
84
+ nmask=mask.copy()
85
+ output_arr2=output_arr[:,:,0:3].copy()
86
+ input_arr2=input_arr[:,:,0:3].copy()
87
+ output_arr2[nmask<128]=0
88
+ input_arr2[nmask>=128]=0
89
+ output_arr2+=input_arr2
90
+ src = output_arr2[:,:,0:3]
91
+ tgt = src.copy()
92
+ proc=self.proc
93
+ args=self.args
94
+ if proc.root:
95
+ n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1))
96
+ proc.sync()
97
+ if proc.root:
98
+ result = tgt
99
+ t = time.time()
100
+ if args.p == 0:
101
+ args.p = args.n
102
+
103
+ for i in range(0, args.n, args.p):
104
+ if proc.root:
105
+ result, err = proc.step(args.p) # type: ignore
106
+ print(f"[PIE] Iter {i + args.p}, abs_err {err}")
107
+ else:
108
+ proc.step(args.p)
109
+
110
+ if proc.root:
111
+ dt = time.time() - t
112
+ print(f"[PIE] Time elapsed: {dt:.4f}s")
113
+ # make sure consistent with dummy process
114
+ return Image.fromarray(result)
115
+
116
+
117
+ def get_parser(self,gen_type: str) -> argparse.Namespace:
118
+ parser = argparse.ArgumentParser()
119
+ parser.add_argument(
120
+ "-v", "--version", action="store_true", help="show the version and exit"
121
+ )
122
+ parser.add_argument(
123
+ "--check-backend", action="store_true", help="print all available backends"
124
+ )
125
+ if gen_type == "gui" and "mpi" in ALL_BACKEND:
126
+ # gui doesn't support MPI backend
127
+ ALL_BACKEND.remove("mpi")
128
+ parser.add_argument(
129
+ "-b",
130
+ "--backend",
131
+ type=str,
132
+ choices=ALL_BACKEND,
133
+ default=DEFAULT_BACKEND,
134
+ help="backend choice",
135
+ )
136
+ parser.add_argument(
137
+ "-c",
138
+ "--cpu",
139
+ type=int,
140
+ default=CPU_COUNT,
141
+ help="number of CPU used",
142
+ )
143
+ parser.add_argument(
144
+ "-z",
145
+ "--block-size",
146
+ type=int,
147
+ default=1024,
148
+ help="cuda block size (only for equ solver)",
149
+ )
150
+ parser.add_argument(
151
+ "--method",
152
+ type=str,
153
+ choices=["equ", "grid"],
154
+ default="equ",
155
+ help="how to parallelize computation",
156
+ )
157
+ parser.add_argument("-s", "--source", type=str, help="source image filename")
158
+ if gen_type == "cli":
159
+ parser.add_argument(
160
+ "-m",
161
+ "--mask",
162
+ type=str,
163
+ help="mask image filename (default is to use the whole source image)",
164
+ default="",
165
+ )
166
+ parser.add_argument("-t", "--target", type=str, help="target image filename")
167
+ parser.add_argument("-o", "--output", type=str, help="output image filename")
168
+ if gen_type == "cli":
169
+ parser.add_argument(
170
+ "-h0", type=int, help="mask position (height) on source image", default=0
171
+ )
172
+ parser.add_argument(
173
+ "-w0", type=int, help="mask position (width) on source image", default=0
174
+ )
175
+ parser.add_argument(
176
+ "-h1", type=int, help="mask position (height) on target image", default=0
177
+ )
178
+ parser.add_argument(
179
+ "-w1", type=int, help="mask position (width) on target image", default=0
180
+ )
181
+ parser.add_argument(
182
+ "-g",
183
+ "--gradient",
184
+ type=str,
185
+ choices=["max", "src", "avg"],
186
+ default="max",
187
+ help="how to calculate gradient for PIE",
188
+ )
189
+ parser.add_argument(
190
+ "-n",
191
+ type=int,
192
+ help="how many iterations you would prefer; the more the better",
193
+ default=5000,
194
+ )
195
+ if gen_type == "cli":
196
+ parser.add_argument(
197
+ "-p", type=int, help="output result every P iteration", default=0
198
+ )
199
+ if "mpi" in ALL_BACKEND:
200
+ parser.add_argument(
201
+ "--mpi-sync-interval",
202
+ type=int,
203
+ help="MPI sync iteration interval",
204
+ default=100,
205
+ )
206
+ parser.add_argument(
207
+ "--grid-x", type=int, help="x axis stride for grid solver", default=8
208
+ )
209
+ parser.add_argument(
210
+ "--grid-y", type=int, help="y axis stride for grid solver", default=8
211
+ )
212
+ self.parser=parser
213
+
214
+ if __name__ =="__main__":
215
+ import sys
216
+ import io
217
+ import base64
218
+ from PIL import Image
219
+ def base64_to_pil(base64_str):
220
+ data = base64.b64decode(str(base64_str))
221
+ pil = Image.open(io.BytesIO(data))
222
+ return pil
223
+
224
+ def pil_to_base64(out_pil):
225
+ out_buffer = io.BytesIO()
226
+ out_pil.save(out_buffer, format="PNG")
227
+ out_buffer.seek(0)
228
+ base64_bytes = base64.b64encode(out_buffer.read())
229
+ base64_str = base64_bytes.decode("ascii")
230
+ return base64_str
231
+ correction_func=PhotometricCorrection(quite=True)
232
+ while True:
233
+ buffer = sys.stdin.readline()
234
+ print(f"[PIE] subprocess {len(buffer)} {type(buffer)}")
235
+ if len(buffer)==0:
236
+ break
237
+ if isinstance(buffer,str):
238
+ lst=buffer.strip().split(",")
239
+ else:
240
+ lst=buffer.decode("ascii").strip().split(",")
241
+ img0=base64_to_pil(lst[0])
242
+ img1=base64_to_pil(lst[1])
243
+ ret=correction_func.run(img0,img1,mode=lst[2])
244
+ ret_base64=pil_to_base64(ret)
245
+ if isinstance(buffer,str):
246
+ sys.stdout.write(f"{ret_base64}\n")
247
+ else:
248
+ sys.stdout.write(f"{ret_base64}\n".encode())
249
+ sys.stdout.flush()
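+ # Protocol sketch (hypothetical caller, for illustration only): the stdin loop
+ # above expects one "original_b64,inpainted_b64,mode" line per request and
+ # replies with one base64-encoded PNG line, e.g.
+ #
+ # import subprocess
+ # p = subprocess.Popen(["python", "postprocess.py"],
+ # stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)
+ # p.stdin.write(f"{original_b64},{inpainted_b64},mask_mode\n"); p.stdin.flush()
+ # corrected_b64 = p.stdout.readline().strip()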
process.py ADDED
@@ -0,0 +1,395 @@
1
+ """
2
+ https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
3
+ MIT License
4
+
5
+ Copyright (c) 2022 Jiayi Weng
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+ """
25
+ import os
26
+ from abc import ABC, abstractmethod
27
+ from typing import Any, Optional, Tuple
28
+
29
+ import numpy as np
30
+
31
+ from fpie import np_solver
32
+
33
+ import scipy
34
+ import scipy.signal
35
+
36
+ CPU_COUNT = os.cpu_count() or 1
37
+ DEFAULT_BACKEND = "numpy"
38
+ ALL_BACKEND = ["numpy"]
39
+
40
+ try:
41
+ from fpie import numba_solver
42
+ ALL_BACKEND += ["numba"]
43
+ DEFAULT_BACKEND = "numba"
44
+ except ImportError:
45
+ numba_solver = None # type: ignore
46
+
47
+ try:
48
+ from fpie import taichi_solver
49
+ ALL_BACKEND += ["taichi-cpu", "taichi-gpu"]
50
+ DEFAULT_BACKEND = "taichi-cpu"
51
+ except ImportError:
52
+ taichi_solver = None # type: ignore
53
+
54
+ # try:
55
+ # from fpie import core_gcc # type: ignore
56
+ # DEFAULT_BACKEND = "gcc"
57
+ # ALL_BACKEND.append("gcc")
58
+ # except ImportError:
59
+ # core_gcc = None
60
+
61
+ # try:
62
+ # from fpie import core_openmp # type: ignore
63
+ # DEFAULT_BACKEND = "openmp"
64
+ # ALL_BACKEND.append("openmp")
65
+ # except ImportError:
66
+ # core_openmp = None
67
+
68
+ # try:
69
+ # from mpi4py import MPI
70
+
71
+ # from fpie import core_mpi # type: ignore
72
+ # ALL_BACKEND.append("mpi")
73
+ # except ImportError:
74
+ # MPI = None # type: ignore
75
+ # core_mpi = None
76
+
77
+ try:
78
+ from fpie import core_cuda # type: ignore
79
+ DEFAULT_BACKEND = "cuda"
80
+ ALL_BACKEND.append("cuda")
81
+ except ImportError:
82
+ core_cuda = None
83
+
84
+
85
+ class BaseProcessor(ABC):
86
+ """API definition for processor class."""
87
+
88
+ def __init__(
89
+ self, gradient: str, rank: int, backend: str, core: Optional[Any]
90
+ ):
91
+ if core is None:
92
+ error_msg = {
93
+ "numpy":
94
+ "Please run `pip install numpy`.",
95
+ "numba":
96
+ "Please run `pip install numba`.",
97
+ "gcc":
98
+ "Please install cmake and gcc in your operating system.",
99
+ "openmp":
100
+ "Please make sure your gcc is compatible with `-fopenmp` option.",
101
+ "mpi":
102
+ "Please install MPI and run `pip install mpi4py`.",
103
+ "cuda":
104
+ "Please make sure nvcc and cuda-related libraries are available.",
105
+ "taichi":
106
+ "Please run `pip install taichi`.",
107
+ }
108
+ print(error_msg[backend.split("-")[0]])
109
+
110
+ raise AssertionError(f"Invalid backend {backend}.")
111
+
112
+ self.gradient = gradient
113
+ self.rank = rank
114
+ self.backend = backend
115
+ self.core = core
116
+ self.root = rank == 0
117
+
118
+ def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
119
+ if self.gradient == "src":
120
+ return a
121
+ if self.gradient == "avg":
122
+ return (a + b) / 2
123
+ # mix gradient, see Equ. 12 in PIE paper
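+ # per pixel and channel, keep whichever of the source (a) or target (b)
+ # gradient has the larger absolute value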
124
+ mask = np.abs(a) < np.abs(b)
125
+ a[mask] = b[mask]
126
+ return a
127
+
128
+ @abstractmethod
129
+ def reset(
130
+ self,
131
+ src: np.ndarray,
132
+ mask: np.ndarray,
133
+ tgt: np.ndarray,
134
+ mask_on_src: Tuple[int, int],
135
+ mask_on_tgt: Tuple[int, int],
136
+ ) -> int:
137
+ pass
138
+
139
+ def sync(self) -> None:
140
+ self.core.sync()
141
+
142
+ @abstractmethod
143
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
144
+ pass
145
+
146
+
147
+ class EquProcessor(BaseProcessor):
148
+ """PIE Jacobi equation processor."""
149
+
150
+ def __init__(
151
+ self,
152
+ gradient: str = "max",
153
+ backend: str = DEFAULT_BACKEND,
154
+ n_cpu: int = CPU_COUNT,
155
+ min_interval: int = 100,
156
+ block_size: int = 1024,
157
+ ):
158
+ core: Optional[Any] = None
159
+ rank = 0
160
+
161
+ if backend == "numpy":
162
+ core = np_solver.EquSolver()
163
+ elif backend == "numba" and numba_solver is not None:
164
+ core = numba_solver.EquSolver()
165
+ elif backend == "gcc":
166
+ core = core_gcc.EquSolver()
167
+ elif backend == "openmp" and core_openmp is not None:
168
+ core = core_openmp.EquSolver(n_cpu)
169
+ elif backend == "mpi" and core_mpi is not None:
170
+ core = core_mpi.EquSolver(min_interval)
171
+ rank = MPI.COMM_WORLD.Get_rank()
172
+ elif backend == "cuda" and core_cuda is not None:
173
+ core = core_cuda.EquSolver(block_size)
174
+ elif backend.startswith("taichi") and taichi_solver is not None:
175
+ core = taichi_solver.EquSolver(backend, n_cpu, block_size)
176
+
177
+ super().__init__(gradient, rank, backend, core)
178
+
179
+ def mask2index(
180
+ self, mask: np.ndarray
181
+ ) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]:
182
+ x, y = np.nonzero(mask)
183
+ max_id = x.shape[0] + 1
184
+ index = np.zeros((max_id, 3))
185
+ ids = self.core.partition(mask)
186
+ ids[mask == 0] = 0 # reserve id=0 for constant
187
+ index = ids[x, y].argsort()
188
+ return ids, max_id, x[index], y[index]
189
+
190
+ def reset(
191
+ self,
192
+ src: np.ndarray,
193
+ mask: np.ndarray,
194
+ tgt: np.ndarray,
195
+ mask_on_src: Tuple[int, int],
196
+ mask_on_tgt: Tuple[int, int],
197
+ ) -> int:
198
+ assert self.root
199
+ # check validity
200
+ # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
201
+ # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
202
+ # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
203
+ # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
204
+ # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
205
+
206
+ if len(mask.shape) == 3:
207
+ mask = mask.mean(-1)
208
+ mask = (mask >= 128).astype(np.int32)
209
+
210
+ # zero-out edge
211
+ mask[0] = 0
212
+ mask[-1] = 0
213
+ mask[:, 0] = 0
214
+ mask[:, -1] = 0
215
+
216
+ x, y = np.nonzero(mask)
217
+ x0, x1 = x.min() - 1, x.max() + 2
218
+ y0, y1 = y.min() - 1, y.max() + 2
219
+ mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1])
220
+ mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1])
221
+ mask = mask[x0:x1, y0:y1]
222
+ ids, max_id, index_x, index_y = self.mask2index(mask)
223
+
224
+ src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1]
225
+ tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]
226
+
227
+ src_C = src[src_x, src_y].astype(np.float32)
228
+ src_U = src[src_x - 1, src_y].astype(np.float32)
229
+ src_D = src[src_x + 1, src_y].astype(np.float32)
230
+ src_L = src[src_x, src_y - 1].astype(np.float32)
231
+ src_R = src[src_x, src_y + 1].astype(np.float32)
232
+ tgt_C = tgt[tgt_x, tgt_y].astype(np.float32)
233
+ tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32)
234
+ tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32)
235
+ tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32)
236
+ tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32)
237
+
238
+ grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \
239
+ + self.mixgrad(src_C - src_R, tgt_C - tgt_R) \
240
+ + self.mixgrad(src_C - src_U, tgt_C - tgt_U) \
241
+ + self.mixgrad(src_C - src_D, tgt_C - tgt_D)
242
+
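+ # assemble the sparse Jacobi system: for every masked pixel i,
+ # 4 * x_i = x_up + x_down + x_left + x_right + B_i,
+ # where A stores the solver ids of the four neighbours (0 = fixed border),
+ # X is initialised with the current target pixels, and B holds the mixed
+ # gradient plus the fixed target value of any neighbour outside the mask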
243
+ A = np.zeros((max_id, 4), np.int32)
244
+ X = np.zeros((max_id, 3), np.float32)
245
+ B = np.zeros((max_id, 3), np.float32)
246
+
247
+ X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]]
248
+ # four-way
249
+ A[1:, 0] = ids[index_x - 1, index_y]
250
+ A[1:, 1] = ids[index_x + 1, index_y]
251
+ A[1:, 2] = ids[index_x, index_y - 1]
252
+ A[1:, 3] = ids[index_x, index_y + 1]
253
+ B[1:] = grad
254
+ m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1)
255
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]]
256
+ m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1)
257
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1]
258
+ m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1)
259
+ B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1]
260
+ m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1)
261
+ B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]]
262
+
263
+ self.tgt = tgt.copy()
264
+ self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1])
265
+ self.core.reset(max_id, A, X, B)
266
+ return max_id
267
+
268
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
269
+ result = self.core.step(iteration)
270
+ if self.root:
271
+ x, err = result
272
+ self.tgt[self.tgt_index] = x[1:]
273
+ return self.tgt, err
274
+ return None
275
+
276
+
277
+ class GridProcessor(BaseProcessor):
278
+ """PIE grid processor."""
279
+
280
+ def __init__(
281
+ self,
282
+ gradient: str = "max",
283
+ backend: str = DEFAULT_BACKEND,
284
+ n_cpu: int = CPU_COUNT,
285
+ min_interval: int = 100,
286
+ block_size: int = 1024,
287
+ grid_x: int = 8,
288
+ grid_y: int = 8,
289
+ ):
290
+ core: Optional[Any] = None
291
+ rank = 0
292
+
293
+ if backend == "numpy":
294
+ core = np_solver.GridSolver()
295
+ elif backend == "numba" and numba_solver is not None:
296
+ core = numba_solver.GridSolver()
297
+ elif backend == "gcc":
298
+ core = core_gcc.GridSolver(grid_x, grid_y)
299
+ elif backend == "openmp" and core_openmp is not None:
300
+ core = core_openmp.GridSolver(grid_x, grid_y, n_cpu)
301
+ elif backend == "mpi" and core_mpi is not None:
302
+ core = core_mpi.GridSolver(min_interval)
303
+ rank = MPI.COMM_WORLD.Get_rank()
304
+ elif backend == "cuda" and core_cuda is not None:
305
+ core = core_cuda.GridSolver(grid_x, grid_y)
306
+ elif backend.startswith("taichi") and taichi_solver is not None:
307
+ core = taichi_solver.GridSolver(
308
+ grid_x, grid_y, backend, n_cpu, block_size
309
+ )
310
+
311
+ super().__init__(gradient, rank, backend, core)
312
+
313
+ def reset(
314
+ self,
315
+ src: np.ndarray,
316
+ mask: np.ndarray,
317
+ tgt: np.ndarray,
318
+ mask_on_src: Tuple[int, int],
319
+ mask_on_tgt: Tuple[int, int],
320
+ ) -> int:
321
+ assert self.root
322
+ # check validity
323
+ # assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
324
+ # assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
325
+ # assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
326
+ # assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
327
+ # assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
328
+
329
+ if len(mask.shape) == 3:
330
+ mask = mask.mean(-1)
331
+ mask = (mask >= 128).astype(np.int32)
332
+
333
+ # zero-out edge
334
+ mask[0] = 0
335
+ mask[-1] = 0
336
+ mask[:, 0] = 0
337
+ mask[:, -1] = 0
338
+
339
+ x, y = np.nonzero(mask)
340
+ x0, x1 = x.min() - 1, x.max() + 2
341
+ y0, y1 = y.min() - 1, y.max() + 2
342
+ mask = mask[x0:x1, y0:y1]
343
+ max_id = np.prod(mask.shape)
344
+
345
+ src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1,
346
+ mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32)
347
+ tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1,
348
+ mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32)
349
+ grad = np.zeros([*mask.shape, 3], np.float32)
350
+ grad[1:] += self.mixgrad(
351
+ src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1]
352
+ )
353
+ grad[:-1] += self.mixgrad(
354
+ src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:]
355
+ )
356
+ grad[:, 1:] += self.mixgrad(
357
+ src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1]
358
+ )
359
+ grad[:, :-1] += self.mixgrad(
360
+ src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:]
361
+ )
362
+
363
+ grad[mask == 0] = 0
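+ # additionally zero the gradient on the one-pixel ring just inside the
+ # mask: the 3x3 convolution marks mask pixels touching a non-mask pixel
+ # (0 < res < 9), so the blend defers to the target's own boundary values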
364
+ if True:
365
+ kernel = [[1] * 3 for _ in range(3)]
366
+ nmask = mask.copy()
367
+ nmask[nmask > 0] = 1
368
+ res = scipy.signal.convolve2d(
369
+ nmask, kernel, mode="same", boundary="fill", fillvalue=1
370
+ )
371
+ res[nmask < 1] = 0
372
+ res[res == 9] = 0
373
+ res[res > 0] = 1
374
+ grad[res > 0] = 0
375
+ # ylst, xlst = res.nonzero()
376
+ # for y, x in zip(ylst, xlst):
377
+ # grad[y,x]=0
378
+ # for yi in range(-1,2):
379
+ # for xi in range(-1,2):
380
+ # grad[y+yi,x+xi]=0
381
+ self.x0 = mask_on_tgt[0] + x0
382
+ self.x1 = mask_on_tgt[0] + x1
383
+ self.y0 = mask_on_tgt[1] + y0
384
+ self.y1 = mask_on_tgt[1] + y1
385
+ self.tgt = tgt.copy()
386
+ self.core.reset(max_id, mask, tgt_crop, grad)
387
+ return max_id
388
+
389
+ def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
390
+ result = self.core.step(iteration)
391
+ if self.root:
392
+ tgt, err = result
393
+ self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt
394
+ return self.tgt, err
395
+ return None
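For orientation, a minimal sketch of how these processors are typically driven. The file names, offsets, iteration count, and the availability of the fpie numpy solver are assumptions for illustration, not part of this commit:

import numpy as np
from PIL import Image

from process import GridProcessor

# Hypothetical inputs: source/target RGB images of equal size and a
# single-channel mask whose non-zero pixels mark the region to blend.
src = np.array(Image.open("src.png").convert("RGB"))
tgt = np.array(Image.open("tgt.png").convert("RGB"))
msk = np.array(Image.open("mask.png").convert("L"))

proc = GridProcessor(gradient="max", backend="numpy")
proc.reset(src, msk, tgt, mask_on_src=(0, 0), mask_on_tgt=(0, 0))
proc.sync()
blended, err = proc.step(5000)  # run 5000 solver iterations
Image.fromarray(blended.astype(np.uint8)).save("blended.png")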
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --extra-index-url https://download.pytorch.org/whl/cu118
2
+ imageio
3
+ imageio-ffmpeg
4
+ numpy<=1.22.4
5
+ torch==2.5.1
6
+ torchvision
7
+ torchaudio
8
+ pydantic==2.10.6
9
+ Pillow
10
+ scipy
11
+ scikit-image
12
+ diffusers
13
+ stablepy
14
+ transformers==4.49.0
15
+ # ftfy
16
+ # fpie
17
+ accelerate
18
+ ninja
19
+ opencv-python
20
+ opencv-python-headless
utils.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from PIL import ImageFilter
3
+ import cv2
4
+ import numpy as np
5
+ import scipy
6
+ import scipy.signal
7
+ from scipy.spatial import cKDTree
8
+
9
+ import os
10
+ from perlin2d import *
11
+
12
+ patch_match_compiled = True
13
+
14
+ try:
15
+ from PyPatchMatch import patch_match
16
+ except Exception as e:
17
+ try:
18
+ import patch_match
19
+ except Exception as e:
20
+ patch_match_compiled = False
21
+
22
+ try:
23
+ patch_match
24
+ except NameError:
25
+ print("patch_match compiling failed, will fall back to edge_pad")
26
+ patch_match_compiled = False
27
+
28
+
29
+
30
+
31
+ def edge_pad(img, mask, mode=1):
32
+ if mode == 0:
33
+ nmask = mask.copy()
34
+ nmask[nmask > 0] = 1
35
+ res0 = 1 - nmask
36
+ res1 = nmask
37
+ p0 = np.stack(res0.nonzero(), axis=0).transpose()
38
+ p1 = np.stack(res1.nonzero(), axis=0).transpose()
39
+ min_dists, min_dist_idx = cKDTree(p1).query(p0, 1)
40
+ loc = p1[min_dist_idx]
41
+ for (a, b), (c, d) in zip(p0, loc):
42
+ img[a, b] = img[c, d]
43
+ elif mode == 1:
44
+ record = {}
45
+ kernel = [[1] * 3 for _ in range(3)]
46
+ nmask = mask.copy()
47
+ nmask[nmask > 0] = 1
48
+ res = scipy.signal.convolve2d(
49
+ nmask, kernel, mode="same", boundary="fill", fillvalue=1
50
+ )
51
+ res[nmask < 1] = 0
52
+ res[res == 9] = 0
53
+ res[res > 0] = 1
54
+ ylst, xlst = res.nonzero()
55
+ queue = [(y, x) for y, x in zip(ylst, xlst)]
56
+ # bfs here
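+ # breadth-first flood fill: starting from the pixels on the known-region
+ # boundary, propagate values outward one ring per step into the masked-out
+ # area, averaging the contributions that reach a pixel during the same step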
57
+ cnt = res.astype(np.float32)
58
+ acc = img.astype(np.float32)
59
+ step = 1
60
+ h = acc.shape[0]
61
+ w = acc.shape[1]
62
+ offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
63
+ while queue:
64
+ target = []
65
+ for y, x in queue:
66
+ val = acc[y][x]
67
+ for yo, xo in offset:
68
+ yn = y + yo
69
+ xn = x + xo
70
+ if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
71
+ if record.get((yn, xn), step) == step:
72
+ acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
73
+ cnt[yn][xn] += 1
74
+ acc[yn][xn] /= cnt[yn][xn]
75
+ if (yn, xn) not in record:
76
+ record[(yn, xn)] = step
77
+ target.append((yn, xn))
78
+ step += 1
79
+ queue = target
80
+ img = acc.astype(np.uint8)
81
+ else:
82
+ nmask = mask.copy()
83
+ ylst, xlst = nmask.nonzero()
84
+ yt, xt = ylst.min(), xlst.min()
85
+ yb, xb = ylst.max(), xlst.max()
86
+ content = img[yt : yb + 1, xt : xb + 1]
87
+ img = np.pad(
88
+ content,
89
+ ((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
90
+ mode="edge",
91
+ )
92
+ return img, mask
93
+
94
+
95
+ def perlin_noise(img, mask):
96
+ lin_x = np.linspace(0, 5, mask.shape[1], endpoint=False)
97
+ lin_y = np.linspace(0, 5, mask.shape[0], endpoint=False)
98
+ x, y = np.meshgrid(lin_x, lin_y)
99
+ avg = img.mean(axis=0).mean(axis=0)
100
+ # noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
101
+ noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)]
102
+ noise = np.stack(noise, axis=-1)
103
+ # mask=skimage.measure.block_reduce(mask,(8,8),np.min)
104
+ # mask=mask.repeat(8, axis=0).repeat(8, axis=1)
105
+ # mask_image=Image.fromarray(mask)
106
+ # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4))
107
+ # mask=np.array(mask_image)
108
+ nmask = mask.copy()
109
+ # nmask=nmask/255.0
110
+ nmask[mask > 0] = 1
111
+ img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
112
+ # img=img.astype(np.uint8)
113
+ return img, mask
114
+
115
+
116
+ def gaussian_noise(img, mask):
117
+ noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
118
+ noise = (noise + 1) / 2 * 255
119
+ noise = noise.astype(np.uint8)
120
+ nmask = mask.copy()
121
+ nmask[mask > 0] = 1
122
+ img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
123
+ return img, mask
124
+
125
+
126
+ def cv2_telea(img, mask):
127
+ ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA)
128
+ return ret, mask
129
+
130
+
131
+ def cv2_ns(img, mask):
132
+ ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS)
133
+ return ret, mask
134
+
135
+
136
+ def patch_match_func(img, mask):
137
+ ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
138
+ return ret, mask
139
+
140
+
141
+ def mean_fill(img, mask):
142
+ avg = img.mean(axis=0).mean(axis=0)
143
+ img[mask < 1] = avg
144
+ return img, mask
145
+
146
+ """
147
+ Apache-2.0 license
148
+ https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/main/sdgrpcserver/services/generate.py
149
+ https://github.com/parlance-zz/g-diffuser-bot/tree/g-diffuser-bot-beta2
150
+ _handleImageAdjustment
151
+ """
152
+ try:
153
+ from sd_grpcserver.sdgrpcserver import images
154
+ import torch
155
+ from math import sqrt
156
+ def handleImageAdjustment(array, adjustments):
157
+ tensor = images.fromPIL(Image.fromarray(array))
158
+ for adjustment in adjustments:
159
+ which = adjustment[0]
160
+
161
+ if which == "blur":
162
+ sigma = adjustment[1]
163
+ direction = adjustment[2]
164
+
165
+ if direction == "DOWN" or direction == "UP":
166
+ orig = tensor
167
+ repeatCount=256
168
+ sigma /= sqrt(repeatCount)
169
+
170
+ for _ in range(repeatCount):
171
+ tensor = images.gaussianblur(tensor, sigma)
172
+ if direction == "DOWN":
173
+ tensor = torch.minimum(tensor, orig)
174
+ else:
175
+ tensor = torch.maximum(tensor, orig)
176
+ else:
177
+ tensor = images.gaussianblur(tensor, sigma)  # adjustment is a plain list here, not a protobuf
178
+ elif which == "invert":
179
+ tensor = images.invert(tensor)
180
+ elif which == "levels":
181
+ tensor = images.levels(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])
182
+ elif which == "channels":
183
+ tensor = images.channelmap(tensor, [adjustment.channels.r, adjustment.channels.g, adjustment.channels.b, adjustment.channels.a])
184
+ elif which == "rescale":
185
+ raise NotImplementedError("Rescale")  # no `self` in this standalone function
186
+ elif which == "crop":
187
+ tensor = images.crop(tensor, adjustment.crop.top, adjustment.crop.left, adjustment.crop.height, adjustment.crop.width)
188
+ return np.array(images.toPIL(tensor)[0])
189
+
190
+ def g_diffuser(img,mask):
191
+ adjustments=[["blur",32,"UP"],["level",0,0.05,0,1]]
192
+ mask=handleImageAdjustment(mask,adjustments)
193
+ out_mask = handleImageAdjustment(mask, adjustments)  # currently unused
194
+ return img, mask
195
+ except:
196
+ def g_diffuser(img,mask):
197
+ return img,mask
198
+
199
+ def dummy_fill(img,mask):
200
+ return img,mask
201
+ functbl = {
202
+ "gaussian": gaussian_noise,
203
+ "perlin": perlin_noise,
204
+ "edge_pad": edge_pad,
205
+ "patchmatch": patch_match_func if patch_match_compiled else edge_pad,
206
+ "cv2_ns": cv2_ns,
207
+ "cv2_telea": cv2_telea,
208
+ "g_diffuser": g_diffuser,
209
+ "g_diffuser_lib": dummy_fill,
210
+ }
211
+
212
+ try:
213
+ from postprocess import PhotometricCorrection
214
+ correction_func = PhotometricCorrection()
215
+ except Exception as e:
216
+ print(e, "so PhotometricCorrection is disabled")
217
+ class DummyCorrection:
218
+ def __init__(self):
219
+ self.backend=""
220
+ pass
221
+ def run(self,a,b,**kwargs):
222
+ return b
223
+ correction_func=DummyCorrection()
224
+
225
+ class DummyInterrogator:
226
+ def __init__(self) -> None:
227
+ pass
228
+ def interrogate(self,pil):
229
+ return "Interrogator init failed"
230
+
231
+ if "taichi" in correction_func.backend:
232
+ import sys
233
+ import io
234
+ import base64
235
+ from PIL import Image
236
+ def base64_to_pil(base64_str):
237
+ data = base64.b64decode(str(base64_str))
238
+ pil = Image.open(io.BytesIO(data))
239
+ return pil
240
+
241
+ def pil_to_base64(out_pil):
242
+ out_buffer = io.BytesIO()
243
+ out_pil.save(out_buffer, format="PNG")
244
+ out_buffer.seek(0)
245
+ base64_bytes = base64.b64encode(out_buffer.read())
246
+ base64_str = base64_bytes.decode("ascii")
247
+ return base64_str
248
+ from subprocess import Popen, PIPE, STDOUT
249
+ class SubprocessCorrection:
250
+ def __init__(self):
251
+ self.backend=correction_func.backend
252
+ self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
253
+ def run(self,img_input,img_inpainted,mode):
254
+ if mode=="disabled":
255
+ return img_inpainted
256
+ base64_str_input = pil_to_base64(img_input)
257
+ base64_str_inpainted = pil_to_base64(img_inpainted)
258
+ try:
259
+ if self.child.poll() is not None:  # restart the worker if it has exited
260
+ self.child= Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
261
+ self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode())
262
+ self.child.stdin.flush()
263
+ out = self.child.stdout.readline()
264
+ base64_str=out.decode().strip()
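+ # the child may print log lines (e.g. "[Taichi] ...") before the payload;
+ # skip anything starting with "[" until the base64 reply arrives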
265
+ while base64_str and base64_str[0]=="[":
266
+ print(base64_str)
267
+ out = self.child.stdout.readline()
268
+ base64_str=out.decode().strip()
269
+ ret=base64_to_pil(base64_str)
270
+ except:
271
+ print("[PIE] not working, photometric correction is disabled")
272
+ ret=img_inpainted
273
+ return ret
274
+ correction_func = SubprocessCorrection()
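Downstream, the app is expected to pick an initialisation function from functbl and to run photometric correction on the inpainted result. A rough sketch; the file names, the chosen fill strategy, and the mode string are illustrative assumptions:

import numpy as np
from PIL import Image

from utils import functbl, correction_func

img = np.array(Image.open("canvas.png").convert("RGB"))  # H x W x 3 uint8
mask = np.array(Image.open("known.png").convert("L"))     # H x W, >0 marks kept pixels

fill = functbl["edge_pad"]       # or "patchmatch", "perlin", "cv2_telea", ...
filled, mask = fill(img, mask)   # initialise the region where mask == 0

# After diffusion inpainting has produced `inpainted_pil`, blend it back
# towards the original exposure/colour (any mode other than "disabled"
# triggers the Poisson-based correction when it is available).
inpainted_pil = Image.fromarray(filled.astype(np.uint8))  # placeholder stand-in
corrected = correction_func.run(Image.fromarray(img), inpainted_pil, mode="disabled")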