Super-squash branch 'main' using huggingface_hub
Co-authored-by: lnyan <lnyan@users.noreply.huggingface.co>
- .gitattributes +31 -0
- PyPatchMatch/.gitignore +4 -0
- PyPatchMatch/LICENSE +21 -0
- PyPatchMatch/Makefile +54 -0
- PyPatchMatch/README.md +64 -0
- PyPatchMatch/csrc/inpaint.cpp +234 -0
- PyPatchMatch/csrc/inpaint.h +27 -0
- PyPatchMatch/csrc/masked_image.cpp +138 -0
- PyPatchMatch/csrc/masked_image.h +112 -0
- PyPatchMatch/csrc/nnf.cpp +268 -0
- PyPatchMatch/csrc/nnf.h +133 -0
- PyPatchMatch/csrc/pyinterface.cpp +107 -0
- PyPatchMatch/csrc/pyinterface.h +38 -0
- PyPatchMatch/examples/.gitignore +2 -0
- PyPatchMatch/examples/cpp_example.cpp +31 -0
- PyPatchMatch/examples/cpp_example_run.sh +18 -0
- PyPatchMatch/examples/images/forest.bmp +0 -0
- PyPatchMatch/examples/images/forest_pruned.bmp +0 -0
- PyPatchMatch/examples/py_example.py +21 -0
- PyPatchMatch/examples/py_example_global_mask.py +27 -0
- PyPatchMatch/patch_match.py +263 -0
- PyPatchMatch/travis.sh +9 -0
- README.md +13 -0
- app.py +1043 -0
- canvas.py +648 -0
- config.yaml +18 -0
- convert_checkpoint.py +706 -0
- css/w2ui.min.css +0 -0
- index.html +482 -0
- interrogate.py +125 -0
- js/fabric.min.js +0 -0
- js/keyboard.js +37 -0
- js/mode.js +6 -0
- js/outpaint.js +23 -0
- js/proceed.js +65 -0
- js/setup.js +28 -0
- js/toolbar.js +674 -0
- js/upload.js +19 -0
- js/w2ui.min.js +0 -0
- js/xss.js +32 -0
- models/v1-inference.yaml +70 -0
- models/v1-inpainting-inference.yaml +70 -0
- opencv.pc +11 -0
- packages.txt +5 -0
- perlin2d.py +45 -0
- postprocess.py +249 -0
- process.py +395 -0
- requirements.txt +20 -0
- utils.py +274 -0
.gitattributes
ADDED
@@ -0,0 +1,31 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
PyPatchMatch/.gitignore
ADDED
@@ -0,0 +1,4 @@
+/build/
+/*.so
+__pycache__
+*.py[cod]
PyPatchMatch/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 Jiayuan Mao
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
PyPatchMatch/Makefile
ADDED
@@ -0,0 +1,54 @@
+#
+# Makefile
+# Jiayuan Mao, 2019-01-09 13:59
+#
+
+SRC_DIR = csrc
+INC_DIR = csrc
+OBJ_DIR = build/obj
+TARGET = libpatchmatch.so
+
+LIB_TARGET = $(TARGET)
+INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
+
+CXX = $(ENVIRONMENT_OPTIONS) g++
+CXXFLAGS = -std=c++14
+CXXFLAGS += -Ofast -ffast-math -w
+# CXXFLAGS += -g
+CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
+CXXFLAGS += $(INCLUDE_DIR)
+LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
+
+
+CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
+OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
+DEPFILES = $(OBJS:.o=.d)
+
+.PHONY: all clean rebuild test
+
+all: $(LIB_TARGET)
+
+$(OBJ_DIR)/%.o: %.cpp
+	@echo "[CC] $< ..."
+	@$(CXX) -c $< $(CXXFLAGS) -o $@
+
+$(OBJ_DIR)/%.d: %.cpp
+	@mkdir -pv $(dir $@)
+	@echo "[dep] $< ..."
+	@$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
+
+sinclude $(DEPFILES)
+
+$(LIB_TARGET): $(OBJS)
+	@echo "[link] $(LIB_TARGET) ..."
+	@$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
+
+clean:
+	rm -rf $(OBJ_DIR) $(LIB_TARGET)
+
+rebuild:
+	+@make clean
+	+@make
+
+# vim:ft=make
+#
PyPatchMatch/README.md
ADDED
@@ -0,0 +1,64 @@
+PatchMatch based Inpainting
+=====================================
+This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
+This implementation is heavily based on the implementation by Younesse ANDAM:
+[younesse-cv/PatchMatch](https://github.com/younesse-cv/PatchMatch), with some bug fixes.
+
+Usage
+-------------------------------------
+
+You need to first install OpenCV to compile the C++ libraries. Then, run `make` to compile the
+shared library `libpatchmatch.so`.
+
+For Python users (example available at `examples/py_example.py`)
+
+```python
+import patch_match
+
+image = ...  # either a numpy ndarray or a PIL Image object.
+mask = ...  # either a numpy ndarray or a PIL Image object.
+result = patch_match.inpaint(image, mask, patch_size=5)
+```
+
+For C++ users (examples available at `examples/cpp_example.cpp`)
+
+```cpp
+#include "inpaint.h"
+
+int main() {
+    cv::Mat image = ...
+    cv::Mat mask = ...
+
+    cv::Mat result = Inpainting(image, mask, 5).run();
+
+    return 0;
+}
+```
+
+
+README and COPYRIGHT by Younesse ANDAM
+-------------------------------------
+@Author: Younesse ANDAM
+
+@Contact: younesse.andam@gmail.com
+
+Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
+The algorithm is presented in the following paper:
+PatchMatch: A Randomized Correspondence Algorithm
+for Structural Image Editing
+by C. Barnes, E. Shechtman, A. Finkelstein and Dan B. Goldman
+ACM Transactions on Graphics (Proc. SIGGRAPH), vol. 28, Aug. 2009
+
+For more information, please refer to
+http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
+
+Copyright (c) 2010-2011
+
+
+Requirements
+-------------------------------------
+
+To run the project you need to install the OpenCV library and link it to your project.
+OpenCV can be downloaded here:
+http://opencv.org/downloads.html
+
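To make the README's Python snippet above concrete: a minimal runnable sketch, assuming it is executed from the `PyPatchMatch` directory after `make` has built `libpatchmatch.so`. The rectangle, the non-zero-means-hole mask convention, and the ndarray return type are illustrative assumptions, not taken from `examples/py_example.py`:

```python
import numpy as np
from PIL import Image
import patch_match

# Assumptions: run from the PyPatchMatch directory after `make`; non-zero mask
# pixels mark the region to inpaint; patch_match.inpaint returns a numpy array.
image = Image.open("examples/images/forest.bmp")
mask = np.zeros((image.height, image.width), dtype=np.uint8)
mask[80:120, 100:160] = 255  # arbitrary rectangle to remove and fill

result = patch_match.inpaint(image, mask, patch_size=5)
Image.fromarray(result).save("forest_inpainted.bmp")
```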
PyPatchMatch/csrc/inpaint.cpp
ADDED
@@ -0,0 +1,234 @@
+#include <algorithm>
+#include <iostream>
+#include <opencv2/imgcodecs.hpp>
+#include <opencv2/imgproc.hpp>
+#include <opencv2/highgui.hpp>
+
+#include "inpaint.h"
+
+namespace {
+static std::vector<double> kDistance2Similarity;
+
+void init_kDistance2Similarity() {
+    double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
+    int length = (PatchDistanceMetric::kDistanceScale + 1);
+    kDistance2Similarity.resize(length);
+    for (int i = 0; i < length; ++i) {
+        double t = (double) i / length;
+        int j = (int) (100 * t);
+        int k = j + 1;
+        double vj = (j < 11) ? base[j] : 0;
+        double vk = (k < 11) ? base[k] : 0;
+        kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
+    }
+}
+
+
+inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
+    if (source.is_masked(ys, xs)) return;
+    if (source.is_globally_masked(ys, xs)) return;
+
+    auto source_ptr = source.get_image(ys, xs);
+    auto target_ptr = target.ptr<double>(yt, xt);
+
+#pragma unroll
+    for (int c = 0; c < 3; ++c)
+        target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
+    target_ptr[3] += weight;
+}
+}
+
+/**
+ * This algorithm uses a version proposed by Xavier Philippeau.
+ */
+
+Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
+    : m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
+    _initialize_pyramid();
+}
+
+Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
+    : m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
+    _initialize_pyramid();
+}
+
+void Inpainting::_initialize_pyramid() {
+    auto source = m_initial;
+    m_pyramid.push_back(source);
+    while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
+        source = source.downsample();
+        m_pyramid.push_back(source);
+    }
+
+    if (kDistance2Similarity.size() == 0) {
+        init_kDistance2Similarity();
+    }
+}
+
+cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
+    srand(random_seed);
+    const int nr_levels = m_pyramid.size();
+
+    MaskedImage source, target;
+    for (int level = nr_levels - 1; level >= 0; --level) {
+        if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
+
+        source = m_pyramid[level];
+
+        if (level == nr_levels - 1) {
+            target = source.clone();
+            target.clear_mask();
+            m_source2target = NearestNeighborField(source, target, m_distance_metric);
+            m_target2source = NearestNeighborField(target, source, m_distance_metric);
+        } else {
+            m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
+            m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
+        }
+
+        if (verbose) std::cerr << "Initialization done." << std::endl;
+
+        if (verbose_visualize) {
+            auto visualize_size = m_initial.size();
+            cv::Mat source_visualize(visualize_size, m_initial.image().type());
+            cv::resize(source.image(), source_visualize, visualize_size);
+            cv::imshow("Source", source_visualize);
+            cv::Mat target_visualize(visualize_size, m_initial.image().type());
+            cv::resize(target.image(), target_visualize, visualize_size);
+            cv::imshow("Target", target_visualize);
+            cv::waitKey(0);
+        }
+
+        target = _expectation_maximization(source, target, level, verbose);
+    }
+
+    return target.image();
+}
+
+// EM-Like algorithm (see "PatchMatch" - page 6).
+// Returns a double sized target image (unless level = 0).
+MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
+    const int nr_iters_em = 1 + 2 * level;
+    const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
+    const int patch_size = m_distance_metric->patch_size();
+
+    MaskedImage new_source, new_target;
+
+    for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
+        if (iter_em != 0) {
+            m_source2target.set_target(new_target);
+            m_target2source.set_source(new_target);
+            target = new_target;
+        }
+
+        if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
+
+        auto size = source.size();
+        for (int i = 0; i < size.height; ++i) {
+            for (int j = 0; j < size.width; ++j) {
+                if (!source.contains_mask(i, j, patch_size)) {
+                    m_source2target.set_identity(i, j);
+                    m_target2source.set_identity(i, j);
+                }
+            }
+        }
+        if (verbose) std::cerr << " NNF minimization started." << std::endl;
+        m_source2target.minimize(nr_iters_nnf);
+        m_target2source.minimize(nr_iters_nnf);
+        if (verbose) std::cerr << " NNF minimization finished." << std::endl;
+
+        // Instead of upsizing the final target, we build the last target from the next level source image.
+        // Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
+        bool upscaled = false;
+        if (level >= 1 && iter_em == nr_iters_em - 1) {
+            new_source = m_pyramid[level - 1];
+            new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
+            upscaled = true;
+        } else {
+            new_source = m_pyramid[level];
+            new_target = target.clone();
+        }
+
+        auto vote = cv::Mat(new_target.size(), CV_64FC4);
+        vote.setTo(cv::Scalar::all(0));
+
+        // Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
+        _expectation_step(m_source2target, 1, vote, new_source, upscaled);
+        if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
+        _expectation_step(m_target2source, 0, vote, new_source, upscaled);
+        if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
+
+        // Compile votes and update pixel values.
+        _maximization_step(new_target, vote);
+        if (verbose) std::cerr << " Minimization step finished." << std::endl;
+    }
+
+    return new_target;
+}
+
+// Expectation step: vote for best estimations of each pixel.
+void Inpainting::_expectation_step(
+    const NearestNeighborField &nnf, bool source2target,
+    cv::Mat &vote, const MaskedImage &source, bool upscaled
+) {
+    auto source_size = nnf.source_size();
+    auto target_size = nnf.target_size();
+    const int patch_size = m_distance_metric->patch_size();
+
+    for (int i = 0; i < source_size.height; ++i) {
+        for (int j = 0; j < source_size.width; ++j) {
+            if (nnf.source().is_globally_masked(i, j)) continue;
+            int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
+            double w = kDistance2Similarity[dp];
+
+            for (int di = -patch_size; di <= patch_size; ++di) {
+                for (int dj = -patch_size; dj <= patch_size; ++dj) {
+                    int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
+                    if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
+                    if (nnf.source().is_globally_masked(ys, xs)) continue;
+                    if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
+                    if (nnf.target().is_globally_masked(yt, xt)) continue;
+
+                    if (!source2target) {
+                        std::swap(ys, yt);
+                        std::swap(xs, xt);
+                    }
+
+                    if (upscaled) {
+                        for (int uy = 0; uy < 2; ++uy) {
+                            for (int ux = 0; ux < 2; ++ux) {
+                                _weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
+                            }
+                        }
+                    } else {
+                        _weighted_copy(source, ys, xs, vote, yt, xt, w);
+                    }
+                }
+            }
+        }
+    }
+}
+
+// Maximization Step: maximum likelihood of target pixel.
+void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
+    auto target_size = target.size();
+    for (int i = 0; i < target_size.height; ++i) {
+        for (int j = 0; j < target_size.width; ++j) {
+            const double *source_ptr = vote.ptr<double>(i, j);
+            unsigned char *target_ptr = target.get_mutable_image(i, j);
+
+            if (target.is_globally_masked(i, j)) {
+                continue;
+            }
+
+            if (source_ptr[3] > 0) {
+                unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
+                unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
+                unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
+                target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
+            } else {
+                target.set_mask(i, j, 0);
+            }
+        }
+    }
+}
+
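A schematic restatement of the control flow in `Inpainting::run` and `_expectation_maximization` above, with the OpenCV plumbing stripped out. This is pseudocode: every helper name (`build_pyramid`, `minimize_nnf`, and the rest) is illustrative and none of them exist in this repository.

```python
def inpaint_schematic(image, mask, metric):
    # Pseudocode mirror of Inpainting::run / _expectation_maximization above.
    pyramid = build_pyramid(image, mask, metric.patch_size)  # downsample while larger than a patch
    target = None
    for level in reversed(range(len(pyramid))):              # coarse to fine
        source = pyramid[level]
        if target is None:
            target = clone_with_cleared_mask(source)         # starting guess at the coarsest level
        for it in range(1 + 2 * level):                      # more EM iterations at finer levels
            minimize_nnf(source, target)                     # E-step: PatchMatch search, both directions
            last = (it == 2 * level)
            if level >= 1 and last:
                # Final iteration of a level: vote straight onto the next finer
                # grid, so the upsampled target is built from the finer source.
                source = pyramid[level - 1]
            votes = collect_weighted_votes(source, target, upscaled=(level >= 1 and last))
            target = weighted_mean_per_pixel(votes)          # M-step
    return target
```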
PyPatchMatch/csrc/inpaint.h
ADDED
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <vector>
+
+#include "masked_image.h"
+#include "nnf.h"
+
+class Inpainting {
+public:
+    Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
+    Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
+    cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
+
+private:
+    void _initialize_pyramid(void);
+    MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
+    void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
+    void _maximization_step(MaskedImage &target, const cv::Mat &vote);
+
+    MaskedImage m_initial;
+    std::vector<MaskedImage> m_pyramid;
+
+    NearestNeighborField m_source2target;
+    NearestNeighborField m_target2source;
+    const PatchDistanceMetric *m_distance_metric;
+};
+
PyPatchMatch/csrc/masked_image.cpp
ADDED
@@ -0,0 +1,138 @@
+#include "masked_image.h"
+#include <algorithm>
+#include <iostream>
+
+const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
+const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
+
+bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
+    auto mask_size = size();
+    for (int dy = -patch_size; dy <= patch_size; ++dy) {
+        for (int dx = -patch_size; dx <= patch_size; ++dx) {
+            int yy = y + dy, xx = x + dx;
+            if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
+                if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
+            }
+        }
+    }
+    return false;
+}
+
+MaskedImage MaskedImage::downsample() const {
+    const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
+    const auto &kernel = MaskedImage::kDownsampleKernel;
+
+    const auto size = this->size();
+    const auto new_size = cv::Size(size.width / 2, size.height / 2);
+
+    auto ret = MaskedImage(new_size.width, new_size.height);
+    if (!m_global_mask.empty()) ret.init_global_mask_mat();
+    for (int y = 0; y < size.height - 1; y += 2) {
+        for (int x = 0; x < size.width - 1; x += 2) {
+            int r = 0, g = 0, b = 0, ksum = 0;
+            bool is_gmasked = true;
+
+            for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
+                for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
+                    int yy = y + dy, xx = x + dx;
+                    if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
+                        if (!is_globally_masked(yy, xx)) {
+                            is_gmasked = false;
+                        }
+                        if (!is_masked(yy, xx)) {
+                            auto source_ptr = get_image(yy, xx);
+                            int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
+                            r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
+                            ksum += k;
+                        }
+                    }
+                }
+            }
+
+            if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
+
+            if (!m_global_mask.empty()) {
+                ret.set_global_mask(y / 2, x / 2, is_gmasked);
+            }
+            if (ksum > 0) {
+                auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
+                target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
+                ret.set_mask(y / 2, x / 2, 0);
+            } else {
+                ret.set_mask(y / 2, x / 2, 1);
+            }
+        }
+    }
+
+    return ret;
+}
+
+MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
+    const auto size = this->size();
+    auto ret = MaskedImage(new_w, new_h);
+    if (!m_global_mask.empty()) ret.init_global_mask_mat();
+    for (int y = 0; y < new_h; ++y) {
+        for (int x = 0; x < new_w; ++x) {
+            int yy = y * size.height / new_h;
+            int xx = x * size.width / new_w;
+
+            if (is_globally_masked(yy, xx)) {
+                ret.set_global_mask(y, x, 1);
+                ret.set_mask(y, x, 1);
+            } else {
+                if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
+
+                if (is_masked(yy, xx)) {
+                    ret.set_mask(y, x, 1);
+                } else {
+                    auto source_ptr = get_image(yy, xx);
+                    auto target_ptr = ret.get_mutable_image(y, x);
+                    for (int c = 0; c < 3; ++c)
+                        target_ptr[c] = source_ptr[c];
+                    ret.set_mask(y, x, 0);
+                }
+            }
+        }
+    }
+
+    return ret;
+}
+
+MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
+    auto ret = upsample(new_w, new_h);
+    ret.set_global_mask_mat(new_global_mask);
+    return ret;
+}
+
+void MaskedImage::compute_image_gradients() {
+    if (m_image_grad_computed) {
+        return;
+    }
+
+    const auto size = m_image.size();
+    m_image_grady = cv::Mat(size, CV_8UC3);
+    m_image_gradx = cv::Mat(size, CV_8UC3);
+    m_image_grady = cv::Scalar::all(0);
+    m_image_gradx = cv::Scalar::all(0);
+
+    for (int i = 1; i < size.height - 1; ++i) {
+        const auto *ptr = m_image.ptr<unsigned char>(i, 0);
+        const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
+        const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
+        const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
+        const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
+        auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
+        auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
+        for (int j = 3; j < size.width * 3 - 3; ++j) {
+            mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
+            mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
+        }
+    }
+
+    m_image_grad_computed = true;
+}
+
+void MaskedImage::compute_image_gradients() const {
+    const_cast<MaskedImage *>(this)->compute_image_gradients();
+}
+
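The kernel `{1, 5, 10, 10, 5, 1}` used by `downsample()` is a row of binomial coefficients, i.e. a cheap Gaussian approximation. A NumPy sketch of the same masked, weighted 2x decimation for interior pixels (border handling is simplified relative to the C++ above; the boolean True-means-masked convention is an assumption carried over from `is_masked`):

```python
import numpy as np

KERNEL = np.array([1, 5, 10, 10, 5, 1], dtype=np.float64)  # binomial taps, ~Gaussian
K2D = np.outer(KERNEL, KERNEL)                             # separable 6x6 kernel

def downsample_masked(image, mask):
    # 2x decimation weighting each unmasked pixel by the binomial kernel,
    # mirroring MaskedImage::downsample above. Interior pixels only; the C++
    # additionally handles windows that overhang the image border.
    h, w = mask.shape                                      # mask: True = masked
    out = np.zeros((h // 2, w // 2, 3), dtype=np.float64)
    out_mask = np.ones((h // 2, w // 2), dtype=bool)
    for y in range(2, h - 4, 2):
        for x in range(2, w - 4, 2):
            win = image[y - 2:y + 4, x - 2:x + 4].astype(np.float64)
            weights = K2D * ~mask[y - 2:y + 4, x - 2:x + 4]
            wsum = weights.sum()
            if wsum > 0:                                   # at least one valid pixel
                out[y // 2, x // 2] = (win * weights[..., None]).sum(axis=(0, 1)) / wsum
                out_mask[y // 2, x // 2] = False
    return out.astype(np.uint8), out_mask
```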
PyPatchMatch/csrc/masked_image.h
ADDED
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <opencv2/core.hpp>
+
+class MaskedImage {
+public:
+    MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
+        // pass
+    }
+    MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
+        // pass
+    }
+    MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
+        // pass
+    }
+    MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
+        m_image(image), m_mask(mask), m_global_mask(global_mask),
+        m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
+        // pass
+    }
+    MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
+        m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
+        m_image = cv::Scalar::all(0);
+
+        m_mask = cv::Mat(cv::Size(width, height), CV_8U);
+        m_mask = cv::Scalar::all(0);
+    }
+    inline MaskedImage clone() {
+        return MaskedImage(
+            m_image.clone(), m_mask.clone(), m_global_mask.clone(),
+            m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
+        );
+    }
+
+    inline cv::Size size() const {
+        return m_image.size();
+    }
+    inline const cv::Mat &image() const {
+        return m_image;
+    }
+    inline const cv::Mat &mask() const {
+        return m_mask;
+    }
+    inline const cv::Mat &global_mask() const {
+        return m_global_mask;
+    }
+    inline const cv::Mat &grady() const {
+        assert(m_image_grad_computed);
+        return m_image_grady;
+    }
+    inline const cv::Mat &gradx() const {
+        assert(m_image_grad_computed);
+        return m_image_gradx;
+    }
+
+    inline void init_global_mask_mat() {
+        m_global_mask = cv::Mat(m_mask.size(), CV_8U);
+        m_global_mask.setTo(cv::Scalar(0));
+    }
+    inline void set_global_mask_mat(const cv::Mat &other) {
+        m_global_mask = other;
+    }
+
+    inline bool is_masked(int y, int x) const {
+        return static_cast<bool>(m_mask.at<unsigned char>(y, x));
+    }
+    inline bool is_globally_masked(int y, int x) const {
+        return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
+    }
+    inline void set_mask(int y, int x, bool value) {
+        m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
+    }
+    inline void set_global_mask(int y, int x, bool value) {
+        m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
+    }
+    inline void clear_mask() {
+        m_mask.setTo(cv::Scalar(0));
+    }
+
+    inline const unsigned char *get_image(int y, int x) const {
+        return m_image.ptr<unsigned char>(y, x);
+    }
+    inline unsigned char *get_mutable_image(int y, int x) {
+        return m_image.ptr<unsigned char>(y, x);
+    }
+
+    inline unsigned char get_image(int y, int x, int c) const {
+        return m_image.ptr<unsigned char>(y, x)[c];
+    }
+    inline int get_image_int(int y, int x, int c) const {
+        return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
+    }
+
+    bool contains_mask(int y, int x, int patch_size) const;
+    MaskedImage downsample() const;
+    MaskedImage upsample(int new_w, int new_h) const;
+    MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
+    void compute_image_gradients();
+    void compute_image_gradients() const;
+
+    static const cv::Size kDownsampleKernelSize;
+    static const int kDownsampleKernel[6];
+
+private:
+    cv::Mat m_image;
+    cv::Mat m_mask;
+    cv::Mat m_global_mask;
+    cv::Mat m_image_grady;
+    cv::Mat m_image_gradx;
+    bool m_image_grad_computed = false;
+};
+
PyPatchMatch/csrc/nnf.cpp
ADDED
@@ -0,0 +1,268 @@
+#include <algorithm>
+#include <iostream>
+#include <cmath>
+
+#include "masked_image.h"
+#include "nnf.h"
+
+/**
+ * Nearest-Neighbor Field (see PatchMatch algorithm).
+ * This algorithm uses a version proposed by Xavier Philippeau.
+ *
+ */
+
+template <typename T>
+T clamp(T value, T min_value, T max_value) {
+    return std::min(std::max(value, min_value), max_value);
+}
+
+void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
+    auto this_size = source_size();
+    for (int i = 0; i < this_size.height; ++i) {
+        for (int j = 0; j < this_size.width; ++j) {
+            if (m_source.is_globally_masked(i, j)) continue;
+
+            auto this_ptr = mutable_ptr(i, j);
+            int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
+            if (distance < PatchDistanceMetric::kDistanceScale) {
+                continue;
+            }
+
+            int i_target = 0, j_target = 0;
+            for (int t = 0; t < max_retry; ++t) {
+                i_target = rand() % this_size.height;
+                j_target = rand() % this_size.width;
+                if (m_target.is_globally_masked(i_target, j_target)) continue;
+
+                distance = _distance(i, j, i_target, j_target);
+                if (distance < PatchDistanceMetric::kDistanceScale)
+                    break;
+            }
+
+            this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
+        }
+    }
+}
+
+void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
+    const auto &this_size = source_size();
+    const auto &other_size = other.source_size();
+    double fi = static_cast<double>(this_size.height) / other_size.height;
+    double fj = static_cast<double>(this_size.width) / other_size.width;
+
+    for (int i = 0; i < this_size.height; ++i) {
+        for (int j = 0; j < this_size.width; ++j) {
+            if (m_source.is_globally_masked(i, j)) continue;
+
+            int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
+            int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
+            auto this_value = mutable_ptr(i, j);
+            auto other_value = other.ptr(ilow, jlow);
+
+            this_value[0] = static_cast<int>(other_value[0] * fi);
+            this_value[1] = static_cast<int>(other_value[1] * fj);
+            this_value[2] = _distance(i, j, this_value[0], this_value[1]);
+        }
+    }
+
+    _randomize_field(max_retry, false);
+}
+
+void NearestNeighborField::minimize(int nr_pass) {
+    const auto &this_size = source_size();
+    while (nr_pass--) {
+        for (int i = 0; i < this_size.height; ++i)
+            for (int j = 0; j < this_size.width; ++j) {
+                if (m_source.is_globally_masked(i, j)) continue;
+                if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
+            }
+        for (int i = this_size.height - 1; i >= 0; --i)
+            for (int j = this_size.width - 1; j >= 0; --j) {
+                if (m_source.is_globally_masked(i, j)) continue;
+                if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
+            }
+    }
+}
+
+void NearestNeighborField::_minimize_link(int y, int x, int direction) {
+    const auto &this_size = source_size();
+    const auto &this_target_size = target_size();
+    auto this_ptr = mutable_ptr(y, x);
+
+    // propagation along the y direction.
+    if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
+        int yp = at(y - direction, x, 0) + direction;
+        int xp = at(y - direction, x, 1);
+        int dp = _distance(y, x, yp, xp);
+        if (dp < at(y, x, 2)) {
+            this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
+        }
+    }
+
+    // propagation along the x direction.
+    if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
+        int yp = at(y, x - direction, 0);
+        int xp = at(y, x - direction, 1) + direction;
+        int dp = _distance(y, x, yp, xp);
+        if (dp < at(y, x, 2)) {
+            this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
+        }
+    }
+
+    // random search with a progressive step size.
+    int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
+    while (random_scale > 0) {
+        int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
+        int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
+        yp = clamp(yp, 0, target_size().height - 1);
+        xp = clamp(xp, 0, target_size().width - 1);
+
+        if (m_target.is_globally_masked(yp, xp)) {
+            random_scale /= 2;
+        }
+
+        int dp = _distance(y, x, yp, xp);
+        if (dp < at(y, x, 2)) {
+            this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
+        }
+        random_scale /= 2;
+    }
+}
+
+const int PatchDistanceMetric::kDistanceScale = 65535;
+const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
+
+namespace {
+
+inline int pow2(int i) {
+    return i * i;
+}
+
+int distance_masked_images(
+    const MaskedImage &source, int ys, int xs,
+    const MaskedImage &target, int yt, int xt,
+    int patch_size
+) {
+    long double distance = 0;
+    long double wsum = 0;
+
+    source.compute_image_gradients();
+    target.compute_image_gradients();
+
+    auto source_size = source.size();
+    auto target_size = target.size();
+
+    for (int dy = -patch_size; dy <= patch_size; ++dy) {
+        const int yys = ys + dy, yyt = yt + dy;
+
+        if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
+            distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
+            wsum += 2 * patch_size + 1;
+            continue;
+        }
+
+        const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
+        const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
+        const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
+        const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
+
+        const unsigned char *p_sgm = nullptr;
+        const unsigned char *p_tgm = nullptr;
+        if (!source.global_mask().empty()) {
+            p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
+            p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
+        }
+
+        const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
+        const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
+        const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
+        const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
+
+        for (int dx = -patch_size; dx <= patch_size; ++dx) {
+            int xxs = xs + dx, xxt = xt + dx;
+            wsum += 1;
+
+            if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
+                distance += PatchSSDDistanceMetric::kSSDScale;
+                continue;
+            }
+
+            if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt])) {
+                distance += PatchSSDDistanceMetric::kSSDScale;
+                continue;
+            }
+
+            int ssd = 0;
+            for (int c = 0; c < 3; ++c) {
+                int s_value = p_si[xxs * 3 + c];
+                int t_value = p_ti[xxt * 3 + c];
+                int s_gy = p_sgy[xxs * 3 + c];
+                int t_gy = p_tgy[xxt * 3 + c];
+                int s_gx = p_sgx[xxs * 3 + c];
+                int t_gx = p_tgx[xxt * 3 + c];
+
+                ssd += pow2(static_cast<int>(s_value) - t_value);
+                ssd += pow2(static_cast<int>(s_gx) - t_gx);
+                ssd += pow2(static_cast<int>(s_gy) - t_gy);
+            }
+            distance += ssd;
+        }
+    }
+
+    distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
+
+    int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
+    if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
+    return res;
+}
+
+}
+
+int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
+    return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
+}
+
+int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
+    fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
+    return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
+}
+
+int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
+    double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
+    double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
+
+    double score1 = sqrt(dx * dx + dy * dy) / m_scale;
+    if (score1 < 0 || score1 > 1) score1 = 1;
+    score1 *= PatchDistanceMetric::kDistanceScale;
+
+    double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
+    double score = score1 * m_weight + score2 / (1 + m_weight);
+    return static_cast<int>(score / (1 + m_weight));
+}
+
+int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
+    if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
+        return PatchDistanceMetric::kDistanceScale;
+
+    int source_scale = m_ijmap.size().height / source.size().height;
+    int target_scale = m_ijmap.size().height / target.size().height;
+
+    // fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
+
+    double score1 = PatchDistanceMetric::kDistanceScale;
+    if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
+        auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
+        auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
+
+        float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
+        float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
+        score1 = sqrt(di * di + dj * dj) / 0.707;
+        if (score1 < 0 || score1 > 1) score1 = 1;
+        score1 *= PatchDistanceMetric::kDistanceScale;
+    }
+
+    double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
+    double score = score1 * m_weight + score2;
+    return int(score / (1 + m_weight));
+}
+
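`_minimize_link` above combines the two classic PatchMatch moves: propagation from an already-processed neighbour, and random search around the current best match with an exponentially shrinking radius. A schematic Python version of the random-search move (illustrative only; `distance` stands in for the configured `PatchDistanceMetric`):

```python
import random

def random_search(field, y, x, target_h, target_w, distance):
    # Random-search move at the end of NearestNeighborField::_minimize_link,
    # restated on a plain list-of-lists field where field[y][x] = (yt, xt, dist).
    # Samples around the current best match with a radius that halves each round.
    best_y, best_x, best_d = field[y][x]
    radius = (min(target_h, target_w) - 1) // 2
    while radius > 0:
        yp = min(max(best_y + random.randint(-radius, radius), 0), target_h - 1)
        xp = min(max(best_x + random.randint(-radius, radius), 0), target_w - 1)
        d = distance(y, x, yp, xp)
        if d < best_d:
            best_y, best_x, best_d = yp, xp, d
        radius //= 2
    field[y][x] = (best_y, best_x, best_d)
```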
PyPatchMatch/csrc/nnf.h
ADDED
@@ -0,0 +1,133 @@
+#pragma once
+
+#include <opencv2/core.hpp>
+#include "masked_image.h"
+
+class PatchDistanceMetric {
+public:
+    PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
+    virtual ~PatchDistanceMetric() = default;
+
+    inline int patch_size() const { return m_patch_size; }
+    virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
+    static const int kDistanceScale;
+
+protected:
+    int m_patch_size;
+};
+
+class NearestNeighborField {
+public:
+    NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
+        // pass
+    }
+    NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
+        : m_source(source), m_target(target), m_distance_metric(metric) {
+        m_field = cv::Mat(m_source.size(), CV_32SC3);
+        _randomize_field(max_retry);
+    }
+    NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
+        : m_source(source), m_target(target), m_distance_metric(metric) {
+        m_field = cv::Mat(m_source.size(), CV_32SC3);
+        _initialize_field_from(other, max_retry);
+    }
+
+    const MaskedImage &source() const {
+        return m_source;
+    }
+    const MaskedImage &target() const {
+        return m_target;
+    }
+    inline cv::Size source_size() const {
+        return m_source.size();
+    }
+    inline cv::Size target_size() const {
+        return m_target.size();
+    }
+    inline void set_source(const MaskedImage &source) {
+        m_source = source;
+    }
+    inline void set_target(const MaskedImage &target) {
+        m_target = target;
+    }
+
+    inline int *mutable_ptr(int y, int x) {
+        return m_field.ptr<int>(y, x);
+    }
+    inline const int *ptr(int y, int x) const {
+        return m_field.ptr<int>(y, x);
+    }
+
+    inline int at(int y, int x, int c) const {
+        return m_field.ptr<int>(y, x)[c];
+    }
+    inline int &at(int y, int x, int c) {
+        return m_field.ptr<int>(y, x)[c];
+    }
+    inline void set_identity(int y, int x) {
+        auto ptr = mutable_ptr(y, x);
+        ptr[0] = y, ptr[1] = x, ptr[2] = 0;
+    }
+
+    void minimize(int nr_pass);
+
+private:
+    inline int _distance(int source_y, int source_x, int target_y, int target_x) {
+        return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
+    }
+
+    void _randomize_field(int max_retry = 20, bool reset = true);
+    void _initialize_field_from(const NearestNeighborField &other, int max_retry);
+    void _minimize_link(int y, int x, int direction);
+
+    MaskedImage m_source;
+    MaskedImage m_target;
+    cv::Mat m_field;  // { y_target, x_target, distance_scaled }
+    const PatchDistanceMetric *m_distance_metric;
+};
+
+
+class PatchSSDDistanceMetric : public PatchDistanceMetric {
+public:
+    using PatchDistanceMetric::PatchDistanceMetric;
+    virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
+    static const int kSSDScale;
+};
+
+class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
+public:
+    DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
+    virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
+protected:
+    int m_width, m_height;
+};
+
+class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
+public:
+    RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
+        : PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
+
+        assert(m_dy1 == 0);
+        assert(m_dx2 == 0);
+        m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
+    }
+    virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
+
+protected:
+    double m_dx1, m_dy1, m_dx2, m_dy2;
+    double m_scale, m_weight;
+};
+
+class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
+public:
+    RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
+        : PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
+
+    }
+    virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
+
+protected:
+    cv::Mat m_ijmap;
+    double m_width, m_height, m_weight;
+};
+
PyPatchMatch/csrc/pyinterface.cpp
ADDED
@@ -0,0 +1,107 @@
#include "pyinterface.h"
#include "inpaint.h"

static unsigned int PM_seed = 1212;
static bool PM_verbose = false;

int _dtype_py_to_cv(int dtype_py);
int _dtype_cv_to_py(int dtype_cv);
cv::Mat _py_to_cv2(PM_mat_t pymat);
PM_mat_t _cv2_to_py(cv::Mat cvmat);

void PM_set_random_seed(unsigned int seed) {
    PM_seed = seed;
}

void PM_set_verbose(int value) {
    PM_verbose = static_cast<bool>(value);
}

void PM_free_pymat(PM_mat_t pymat) {
    free(pymat.data_ptr);
}

PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
    cv::Mat source = _py_to_cv2(source_py);
    cv::Mat mask = _py_to_cv2(mask_py);
    auto metric = PatchSSDDistanceMetric(patch_size);
    cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
    return _cv2_to_py(result);
}

PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
    cv::Mat source = _py_to_cv2(source_py);
    cv::Mat mask = _py_to_cv2(mask_py);
    cv::Mat ijmap = _py_to_cv2(ijmap_py);

    auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
    cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
    return _cv2_to_py(result);
}

PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
    cv::Mat source = _py_to_cv2(source_py);
    cv::Mat mask = _py_to_cv2(mask_py);
    cv::Mat global_mask = _py_to_cv2(global_mask_py);

    auto metric = PatchSSDDistanceMetric(patch_size);
    cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
    return _cv2_to_py(result);
}

PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
    cv::Mat source = _py_to_cv2(source_py);
    cv::Mat mask = _py_to_cv2(mask_py);
    cv::Mat global_mask = _py_to_cv2(global_mask_py);
    cv::Mat ijmap = _py_to_cv2(ijmap_py);

    auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
    cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
    return _cv2_to_py(result);
}

int _dtype_py_to_cv(int dtype_py) {
    switch (dtype_py) {
        case PM_UINT8: return CV_8U;
        case PM_INT8: return CV_8S;
        case PM_UINT16: return CV_16U;
        case PM_INT16: return CV_16S;
        case PM_INT32: return CV_32S;
        case PM_FLOAT32: return CV_32F;
        case PM_FLOAT64: return CV_64F;
    }

    return CV_8U;
}

int _dtype_cv_to_py(int dtype_cv) {
    switch (dtype_cv) {
        case CV_8U: return PM_UINT8;
        case CV_8S: return PM_INT8;
        case CV_16U: return PM_UINT16;
        case CV_16S: return PM_INT16;
        case CV_32S: return PM_INT32;
        case CV_32F: return PM_FLOAT32;
        case CV_64F: return PM_FLOAT64;
    }

    return PM_UINT8;
}
cv::Mat _py_to_cv2(PM_mat_t pymat) {
    int dtype = _dtype_py_to_cv(pymat.dtype);
    // Use the converted OpenCV depth here. The original passed pymat.dtype to
    // CV_MAKETYPE directly, which only works because PM_dtype_e happens to
    // mirror OpenCV's depth constants one-to-one.
    dtype = CV_MAKETYPE(dtype, pymat.shape.channels);
    return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
}
PM_mat_t _cv2_to_py(cv::Mat cvmat) {
    PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
    int dtype = _dtype_cv_to_py(cvmat.depth());
    size_t dsize = cvmat.total() * cvmat.elemSize();

    void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
    memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);

    return PM_mat_t {data_ptr, shape, dtype};
}
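Note the ownership contract implied above: _py_to_cv2 clones the caller's buffer, and _cv2_to_py hands back a fresh malloc'd copy, so every PM_mat_t returned by the PM_inpaint* family must be released with PM_free_pymat. A sketch of the call pattern, in the terms the Python wrapper below actually uses (PMLIB, np_to_pymat, pymat_to_np are defined in patch_match.py later in this commit):

ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(15))
ret_npmat = pymat_to_np(ret_pymat)  # copies the pixels out of the C buffer
PMLIB.PM_free_pymat(ret_pymat)      # then frees the malloc'd C-side copy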
PyPatchMatch/csrc/pyinterface.h
ADDED
@@ -0,0 +1,38 @@
#include <opencv2/core.hpp>
#include <cstdlib>
#include <cstdio>
#include <cstring>

extern "C" {

struct PM_shape_t {
    int width, height, channels;
};

enum PM_dtype_e {
    PM_UINT8,
    PM_INT8,
    PM_UINT16,
    PM_INT16,
    PM_INT32,
    PM_FLOAT32,
    PM_FLOAT64,
};

struct PM_mat_t {
    void *data_ptr;
    PM_shape_t shape;
    int dtype;
};

void PM_set_random_seed(unsigned int seed);
void PM_set_verbose(int value);

void PM_free_pymat(PM_mat_t pymat);
PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);

} /* extern "C" */
PyPatchMatch/examples/.gitignore
ADDED
@@ -0,0 +1,2 @@
/cpp_example.exe
/images/*recovered.bmp
PyPatchMatch/examples/cpp_example.cpp
ADDED
@@ -0,0 +1,31 @@
#include <iostream>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/highgui.hpp>

#include "masked_image.h"
#include "nnf.h"
#include "inpaint.h"

int main() {
    auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);

    auto mask = cv::Mat(source.size(), CV_8UC1);
    mask = cv::Scalar::all(0);
    for (int i = 0; i < source.size().height; ++i) {
        for (int j = 0; j < source.size().width; ++j) {
            auto source_ptr = source.ptr<unsigned char>(i, j);
            if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
                mask.at<unsigned char>(i, j) = 1;
            }
        }
    }

    auto metric = PatchSSDDistanceMetric(3);
    auto result = Inpainting(source, mask, &metric).run(true, true);
    // cv::imwrite("./images/forest_recovered.bmp", result);
    // cv::imshow("Result", result);
    // cv::waitKey();

    return 0;
}
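The nested loop above marks pure-white pixels as holes; a numpy equivalent of the same rule (assuming source is an HxWx3 uint8 array) is a one-liner:

import numpy as np

# mark pixels that are 255 in all three channels as holes, as the C++ loop does
mask = (source == 255).all(axis=2).astype(np.uint8)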
PyPatchMatch/examples/cpp_example_run.sh
ADDED
@@ -0,0 +1,18 @@
#! /bin/bash
#
# cpp_example_run.sh
# Copyright (C) 2020 Jiayuan Mao <maojiayuan@gmail.com>
#
# Distributed under terms of the MIT license.
#

set -x

CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
LDFLAGS="$(pkg-config --libs opencv)"
g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe

export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH  # For macOS
export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH  # For Linux
time ./cpp_example.exe
PyPatchMatch/examples/images/forest.bmp
ADDED
PyPatchMatch/examples/images/forest_pruned.bmp
ADDED
PyPatchMatch/examples/py_example.py
ADDED
@@ -0,0 +1,21 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : test.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/09/2020
#
# Distributed under terms of the MIT license.

from PIL import Image

import sys
sys.path.insert(0, '../')
import patch_match


if __name__ == '__main__':
    source = Image.open('./images/forest_pruned.bmp')
    result = patch_match.inpaint(source, patch_size=3)
    Image.fromarray(result).save('./images/forest_recovered.bmp')
PyPatchMatch/examples/py_example_global_mask.py
ADDED
@@ -0,0 +1,27 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : test.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/09/2020
#
# Distributed under terms of the MIT license.

import numpy as np
from PIL import Image

import sys
sys.path.insert(0, '../')
import patch_match


if __name__ == '__main__':
    patch_match.set_verbose(True)
    source = Image.open('./images/forest_pruned.bmp')
    source = np.array(source)
    source[:100, :100] = 255
    global_mask = np.zeros_like(source[..., 0])
    global_mask[:100, :100] = 1
    result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
    Image.fromarray(result).save('./images/forest_recovered.bmp')
PyPatchMatch/patch_match.py
ADDED
@@ -0,0 +1,263 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : patch_match.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/09/2020
#
# Distributed under terms of the MIT license.

import ctypes
import os.path as osp
from typing import Optional, Union

import numpy as np
from PIL import Image


import os
if os.name != "nt":
    # Otherwise, fall back to the subprocess.
    import subprocess
    print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
    # subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
    subprocess.check_call("make clean && make", cwd=osp.dirname(__file__), shell=True)


__all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']


class CShapeT(ctypes.Structure):
    _fields_ = [
        ('width', ctypes.c_int),
        ('height', ctypes.c_int),
        ('channels', ctypes.c_int),
    ]


class CMatT(ctypes.Structure):
    _fields_ = [
        ('data_ptr', ctypes.c_void_p),
        ('shape', CShapeT),
        ('dtype', ctypes.c_int)
    ]

import tempfile
from urllib.request import urlopen, Request
import shutil
from pathlib import Path
from tqdm import tqdm

def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
        hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
            Default: None
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: True
    https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url
    """
    file_size = None
    req = Request(url)
    u = urlopen(req)
    meta = u.info()
    if hasattr(meta, 'getheaders'):
        content_length = meta.getheaders("Content-Length")
    else:
        content_length = meta.get_all("Content-Length")
    if content_length is not None and len(content_length) > 0:
        file_size = int(content_length[0])

    # We deliberately save it in a temp file and move it after
    # download is complete. This prevents a local working checkpoint
    # being overridden by a broken download.
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)

    try:
        with tqdm(total=file_size, disable=not progress,
                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            while True:
                buffer = u.read(8192)
                if len(buffer) == 0:
                    break
                f.write(buffer)
                pbar.update(len(buffer))

        f.close()
        shutil.move(f.name, dst)
    finally:
        f.close()
        if os.path.exists(f.name):
            os.remove(f.name)

if os.name != "nt":
    PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.so'))
else:
    if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
        download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll", dst=osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
    if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
        download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll", dst=osp.join(osp.dirname(__file__), 'opencv_world460.dll'))
    if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
        print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll and put it into the PyPatchMatch folder")
    if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
        print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll and put it into the PyPatchMatch folder")
    PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))

PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
PMLIB.PM_free_pymat.argtypes = [CMatT]
PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
PMLIB.PM_inpaint.restype = CMatT
PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
PMLIB.PM_inpaint_regularity.restype = CMatT
PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
PMLIB.PM_inpaint2.restype = CMatT
PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
PMLIB.PM_inpaint2_regularity.restype = CMatT


def set_random_seed(seed: int):
    PMLIB.PM_set_random_seed(ctypes.c_uint(seed))


def set_verbose(verbose: bool):
    PMLIB.PM_set_verbose(ctypes.c_int(verbose))


def inpaint(
    image: Union[np.ndarray, Image.Image],
    mask: Optional[Union[np.ndarray, Image.Image]] = None,
    *,
    global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
    patch_size: int = 15
) -> np.ndarray:
    """
    PatchMatch based inpainting proposed in:

        PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
        C. Barnes, E. Shechtman, A. Finkelstein and Dan B. Goldman
        SIGGRAPH 2009

    Args:
        image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
        mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
            If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
        global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
        patch_size (int): the patch size for the inpainting algorithm.

    Return:
        result (np.ndarray): the repaired image, of the same size as the input image.
    """

    if isinstance(image, Image.Image):
        image = np.array(image)
    image = np.ascontiguousarray(image)
    assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'

    if mask is None:
        mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
        mask = np.ascontiguousarray(mask)
    else:
        mask = _canonize_mask_array(mask)

    if global_mask is None:
        ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
    else:
        global_mask = _canonize_mask_array(global_mask)
        ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))

    ret_npmat = pymat_to_np(ret_pymat)
    PMLIB.PM_free_pymat(ret_pymat)

    return ret_npmat


def inpaint_regularity(
    image: Union[np.ndarray, Image.Image],
    mask: Optional[Union[np.ndarray, Image.Image]],
    ijmap: np.ndarray,
    *,
    global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
    patch_size: int = 15, guide_weight: float = 0.25
) -> np.ndarray:
    if isinstance(image, Image.Image):
        image = np.array(image)
    image = np.ascontiguousarray(image)

    assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
    ijmap = np.ascontiguousarray(ijmap)

    assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
    if mask is None:
        mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
        mask = np.ascontiguousarray(mask)
    else:
        mask = _canonize_mask_array(mask)

    if global_mask is None:
        ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
    else:
        global_mask = _canonize_mask_array(global_mask)
        ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))

    ret_npmat = pymat_to_np(ret_pymat)
    PMLIB.PM_free_pymat(ret_pymat)

    return ret_npmat


def _canonize_mask_array(mask):
    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    if mask.ndim == 2 and mask.dtype == 'uint8':
        mask = mask[..., np.newaxis]
    assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
    return np.ascontiguousarray(mask)


dtype_pymat_to_ctypes = [
    ctypes.c_uint8,
    ctypes.c_int8,
    ctypes.c_uint16,
    ctypes.c_int16,
    ctypes.c_int32,
    ctypes.c_float,
    ctypes.c_double,
]


dtype_np_to_pymat = {
    'uint8': 0,
    'int8': 1,
    'uint16': 2,
    'int16': 3,
    'int32': 4,
    'float32': 5,
    'float64': 6,
}


def np_to_pymat(npmat):
    assert npmat.ndim == 3
    return CMatT(
        ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
        CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
        dtype_np_to_pymat[str(npmat.dtype)]
    )


def pymat_to_np(pymat):
    npmat = np.ctypeslib.as_array(
        ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
        (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
    )
    ret = np.empty(npmat.shape, npmat.dtype)
    ret[:] = npmat
    return ret
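A usage sketch for inpaint_regularity, which has no example script of its own in this commit. The zero-filled ijmap is a placeholder: the wrapper only asserts an HxWx3 float32 array, and what a meaningful guidance map contains is up to the caller.

import numpy as np
from PIL import Image
import patch_match

source = Image.open('./images/forest_pruned.bmp')
ijmap = np.zeros((source.height, source.width, 3), dtype='float32')  # placeholder
result = patch_match.inpaint_regularity(source, None, ijmap,
                                        patch_size=3, guide_weight=0.25)
Image.fromarray(result).save('./images/forest_recovered.bmp')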
PyPatchMatch/travis.sh
ADDED
@@ -0,0 +1,9 @@
#! /bin/bash
#
# travis.sh
# Copyright (C) 2020 Jiayuan Mao <maojiayuan@gmail.com>
#
# Distributed under terms of the MIT license.
#

make clean && make
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: Stablediffusion Infinity
emoji: ♾️
colorFrom: blue
colorTo: purple
sdk: gradio
app_file: app.py
pinned: true
license: apache-2.0
sdk_version: 4.31.3
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,1043 @@
import os
os.system("cp opencv.pc /usr/local/lib/pkgconfig/")
os.system("pip install 'numpy<2'")
os.system("pip uninstall triton -y")
import spaces
import io
import base64
import sys
import numpy as np
import torch
from PIL import Image, ImageOps
import gradio as gr
import skimage
import skimage.measure
import yaml
import json
from enum import Enum
from utils import *
from collections import Counter
import argparse
from stablepy import Model_Diffusers, scheduler_names, ALL_PROMPT_WEIGHT_OPTIONS, SCHEDULE_TYPE_OPTIONS
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
from datetime import datetime

parser = argparse.ArgumentParser(description="stablediffusion-infinity")
parser.add_argument("--port", type=int, help="listen port", dest="server_port")
parser.add_argument("--host", type=str, help="host", dest="server_name")
parser.add_argument("--share", action="store_true", help="share this app?")
parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--fp32", action="store_true", help="using full precision")
parser.add_argument("--lowvram", action="store_true", help="using lowvram mode")
parser.add_argument("--encrypt", action="store_true", help="using https?")
parser.add_argument("--ssl_keyfile", type=str, help="path to ssl_keyfile")
parser.add_argument("--ssl_certfile", type=str, help="path to ssl_certfile")
parser.add_argument("--ssl_keyfile_password", type=str, help="ssl_keyfile_password")
parser.add_argument(
    "--auth", nargs=2, metavar=("username", "password"), help="use username password"
)
parser.add_argument(
    "--remote_model",
    type=str,
    help="use a model (e.g. dreambooth fined) from huggingface hub",
    default="",
)
parser.add_argument(
    "--local_model", type=str, help="use a model stored on your PC", default=""
)
parser.add_argument(
    "--stablepy_model",
    type=str,
    help="Model source: can be a Hugging Face Diffusers repo or a local .safetensors file path",
    default="SG161222/RealVisXL_V5.0_Lightning"
)

try:
    abspath = os.path.abspath(__file__)
    dirname = os.path.dirname(abspath)
    os.chdir(dirname)
except:
    pass

Interrogator = DummyInterrogator

START_DEVICE_STABLEPY = "cpu" if os.getenv("SPACES_ZERO_GPU") else None
DEBUG_MODE = False
AUTO_SETUP = True


with open("config.yaml", "r") as yaml_in:
    yaml_object = yaml.safe_load(yaml_in)
    config_json = json.dumps(yaml_object)


def parse_color(color):
    """
    Convert color to Pillow-friendly (R, G, B, A) tuple in 0–255 range.
    Supports:
      - tuple/list of floats or ints
      - 'rgba(r, g, b, a)' string
      - 'rgb(r, g, b)' string
      - hex colors: '#RRGGBB' or '#RRGGBBAA'
    """
    if isinstance(color, (tuple, list)):
        parts = [float(c) for c in color]

    elif isinstance(color, str):
        c = color.strip().lower()

        # Hex color
        if c.startswith("#"):
            c = c.lstrip("#")
            if len(c) == 6:  # RRGGBB
                r, g, b = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16)
                return (r, g, b, 255)
            elif len(c) == 8:  # RRGGBBAA
                r, g, b, a = int(c[0:2], 16), int(c[2:4], 16), int(c[4:6], 16), int(c[6:8], 16)
                return (r, g, b, a)
            else:
                raise ValueError(f"Invalid hex color: {color}")

        # RGB / RGBA string
        c = c.replace("rgba", "").replace("rgb", "").replace("(", "").replace(")", "")
        parts = [float(x.strip()) for x in c.split(",")]

    else:
        raise ValueError(f"Unsupported color format: {color}")

    # Ensure alpha
    if len(parts) == 3:
        parts.append(1.0)  # default alpha = 1.0

    return (
        int(round(parts[0])),
        int(round(parts[1])),
        int(round(parts[2])),
        int(round(parts[3] * 255 if parts[3] <= 1 else parts[3]))
    )
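A few illustrative inputs and outputs for parse_color as defined above (the values follow directly from the code):

parse_color("#ff8000")                 # -> (255, 128, 0, 255)
parse_color("#ff800080")               # -> (255, 128, 0, 128)
parse_color("rgb(255, 128, 0)")        # -> (255, 128, 0, 255)  alpha defaults to 1.0
parse_color("rgba(255, 128, 0, 0.5)")  # -> (255, 128, 0, 128)  fractional alpha scaled to 0-255
parse_color((255, 128, 0))             # -> (255, 128, 0, 255)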
def is_not_dark(color, threshold=30):
    return not all(c <= threshold for c in color)


def get_dominant_color_exclude_dark(image_pil):
    img_small = image_pil.convert("RGB").resize((50, 50))
    pixels = list(img_small.getdata())
    filtered_pixels = [p for p in pixels if is_not_dark(p)]
    if not filtered_pixels:
        filtered_pixels = pixels
    most_common = Counter(filtered_pixels).most_common(1)[0][0]
    return most_common


def replace_color_in_mask(image_pil, mask_pil, target_color=None):
    img = np.array(image_pil.convert("RGB"))
    mask = np.array(mask_pil.convert("L"))

    mask_white = mask == 255
    mask_nonwhite = ~mask_white

    if target_color in [None, ""]:
        nonwhite_pixels = img[mask_nonwhite]
        nonwhite_img = Image.fromarray(nonwhite_pixels.reshape((-1, 1, 3)))  # , mode="RGB"
        target_color = get_dominant_color_exclude_dark(nonwhite_img)
    else:
        parsed = parse_color(target_color)  # (R, G, B, A)
        target_color = parsed[:3]  # ignore alpha for replacement

    img[mask_white] = target_color
    return Image.fromarray(img)


def expand_white_around_black(image: Image.Image, expand_ratio=0.1) -> Image.Image:
    """
    Expand the white areas around the black region by a percentage of the black region size.

    Args:
        image: PIL grayscale image (mode "L").
        expand_ratio: Fraction of black region size to expand white sides (default 0.1 = 10%).

    Returns:
        PIL Image with white expanded around black.
    """
    arr = np.array(image)

    black_mask = arr == 0

    height, width = arr.shape
    coords = np.argwhere(black_mask)

    if coords.size == 0:
        # No black pixels, return original image
        return image.copy()

    y_min, x_min = coords.min(axis=0)
    y_max, x_max = coords.max(axis=0)

    expand_x = int((x_max - x_min + 1) * expand_ratio)
    expand_y = int((y_max - y_min + 1) * expand_ratio)

    # Shrink black bounding box to expand white sides
    if y_min > 0 and np.all(arr[:y_min, :] == 255):
        y_min = min(height - 1, y_min + expand_y)

    if y_max < height - 1 and np.all(arr[y_max + 1:, :] == 255):
        y_max = max(0, y_max - expand_y)

    if x_min > 0 and np.all(arr[:, :x_min] == 255):
        x_min = min(width - 1, x_min + expand_x)

    if x_max < width - 1 and np.all(arr[:, x_max + 1:] == 255):
        x_max = max(0, x_max - expand_x)

    # Create new white canvas
    expanded_arr = np.full_like(arr, 255)

    # Paint black inside adjusted bounding box
    expanded_arr[y_min:y_max+1, x_min:x_max+1] = 0

    return Image.fromarray(expanded_arr)
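A worked example of the geometry above, run in the context of this module: with a 100x100 mask whose black box spans rows/columns 20..79 (60 pixels on a side) and expand_ratio=0.1, expand_x = expand_y = int(60 * 0.1) = 6, and every side that is entirely white moves inward by 6 pixels:

import numpy as np
from PIL import Image

arr = np.full((100, 100), 255, dtype=np.uint8)
arr[20:80, 20:80] = 0                      # black box: rows/cols 20..79
m = Image.fromarray(arr)

out = np.array(expand_white_around_black(m, expand_ratio=0.1))
# all four borders are pure white, so the box shrinks to rows/cols 26..73
assert (out[26:74, 26:74] == 0).all() and (out[:26, :] == 255).all()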
def load_html():
    body, canvaspy = "", ""
    with open("index.html", encoding="utf8") as f:
        body = f.read()
    with open("canvas.py", encoding="utf8") as f:
        canvaspy = f.read()
    body = body.replace("- paths:\n", "")
    body = body.replace("  - ./canvas.py\n", "")
    body = body.replace("from canvas import InfCanvas", canvaspy)
    return body


def test(x):
    x = load_html()
    return f"""<iframe id="sdinfframe" style="width: 100%; height: 600px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media; vertical-scroll 'none'" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{x}'></iframe>"""


try:
    SAMPLING_MODE = Image.Resampling.LANCZOS
except Exception as e:
    SAMPLING_MODE = Image.LANCZOS

try:
    contain_func = ImageOps.contain
except Exception as e:
    def contain_func(image, size, method=SAMPLING_MODE):
        # from PIL: https://pillow.readthedocs.io/en/stable/reference/ImageOps.html#PIL.ImageOps.contain
        im_ratio = image.width / image.height
        dest_ratio = size[0] / size[1]
        if im_ratio != dest_ratio:
            if im_ratio > dest_ratio:
                new_height = int(image.height / image.width * size[0])
                if new_height != size[1]:
                    size = (size[0], new_height)
            else:
                new_width = int(image.width / image.height * size[1])
                if new_width != size[0]:
                    size = (new_width, size[1])
        return image.resize(size, resample=method)


if __name__ == "__main__":
    args = parser.parse_args()
else:
    args = parser.parse_args(["--debug"])
# args = parser.parse_args(["--debug"])
if args.auth is not None:
    args.auth = tuple(args.auth)

model = {}


def get_token():
    token = ""
    if os.path.exists(".token"):
        with open(".token", "r") as f:
            token = f.read()
    token = os.environ.get("hftoken", token)
    return token


def save_token(token):
    with open(".token", "w") as f:
        f.write(token)


def my_resize(width, height):
    if width >= 512 and height >= 512:
        return width, height
    if width == height:
        return 512, 512
    smaller = min(width, height)
    larger = max(width, height)
    if larger >= 608:
        return width, height
    factor = 1
    if smaller < 290:
        factor = 2
    elif smaller < 330:
        factor = 1.75
    elif smaller < 384:
        factor = 1.375
    elif smaller < 400:
        factor = 1.25
    elif smaller < 450:
        factor = 1.125
    return int(factor * width) // 8 * 8, int(factor * height) // 8 * 8
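my_resize above only upsamples small canvases: anything already at least 512px on both sides (or 608px+ on the long side) passes through, a square gets snapped to 512x512, and otherwise the shorter side picks a factor and both dimensions are scaled then rounded down to multiples of 8. For instance:

my_resize(512, 768)  # -> (512, 768)  both sides >= 512, unchanged
my_resize(400, 400)  # -> (512, 512)  square special case
my_resize(300, 608)  # -> (300, 608)  long side >= 608, unchanged
my_resize(256, 320)  # -> (512, 640)  smaller side < 290, factor 2
my_resize(350, 500)  # -> (480, 680)  factor 1.375, snapped down to multiples of 8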
def load_learned_embed_in_clip(
    learned_embeds_path, text_encoder, tokenizer, token=None
):
    # https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_conceptualizer_inference.ipynb
    loaded_learned_embeds = torch.load(learned_embeds_path, map_location="cpu")

    # separate token and the embeds
    trained_token = list(loaded_learned_embeds.keys())[0]
    embeds = loaded_learned_embeds[trained_token]

    # cast to dtype of text_encoder (Tensor.to is not in-place, so keep the
    # result; the original `embeds.to(dtype)` discarded it)
    dtype = text_encoder.get_input_embeddings().weight.dtype
    embeds = embeds.to(dtype)

    # add the token in tokenizer
    token = token if token is not None else trained_token
    num_added_tokens = tokenizer.add_tokens(token)
    if num_added_tokens == 0:
        raise ValueError(
            f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer."
        )

    # resize the token embeddings
    text_encoder.resize_token_embeddings(len(tokenizer))

    # get the id for the token and assign the embeds
    token_id = tokenizer.convert_tokens_to_ids(token)
    text_encoder.get_input_embeddings().weight.data[token_id] = embeds
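A hypothetical invocation of the loader above; the path, pipeline object, and token are placeholders, since nothing in this repo ships a learned_embeds.bin:

# pipe is assumed to be a diffusers StableDiffusionPipeline-like object
load_learned_embed_in_clip(
    "./concepts/cat-toy/learned_embeds.bin",  # hypothetical path
    pipe.text_encoder,
    pipe.tokenizer,
    token="<cat-toy>",                        # placeholder token name
)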
MODEL_NAME = args.stablepy_model
print(f"Loading model {MODEL_NAME}. This may take some time if it is a Diffusers-format model.")

LOAD_PIPE_ARGS = dict(
    vae_model=None,
    retain_task_model_in_cache=True,
    controlnet_model="Automatic",
    type_model_precision=torch.float16,
)

disable_progress_bars()
base_model = Model_Diffusers(
    base_model_id=MODEL_NAME,
    task_name="repaint",
    device=START_DEVICE_STABLEPY,
    **LOAD_PIPE_ARGS,
)
enable_progress_bars()
if START_DEVICE_STABLEPY:
    base_model.device = torch.device("cuda:0")
    base_model.pipe.to(torch.device("cuda:0"), torch.float16)

# maybe a second base_model for anime images

class StableDiffusion:
    def __init__(
        self,
        token: str = "",
        model_name: str = "stable-diffusion-v1-5/stable-diffusion-v1-5",
        model_path: str = None,
        inpainting_model: bool = False,
        **kwargs,
    ):
        if DEBUG_MODE:
            print("sd task selection")

    def run(
        self,
        image_pil,
        prompt="",
        negative_prompt="",
        guidance_scale=7.5,
        resize_check=True,
        enable_safety=True,
        fill_mode="patchmatch",
        strength=0.75,
        step=50,
        enable_img2img=False,
        use_seed=False,
        seed_val=-1,
        generate_num=1,
        scheduler="",
        scheduler_eta=0.0,
        controlnet_union=True,
        expand_mask_percent=0.1,
        color_selector_=None,
        scheduler_type="Automatic",
        prompt_weight="Classic",
        image_resolution=1024,
        img_height=1024,
        img_width=1024,
        loraA=None,
        loraAscale=1.,
        **kwargs,
    ):
        global base_model

        width, height = image_pil.size

        if DEBUG_MODE:
            image_pil.save(
                f"output_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
            )
            print(image_pil.size)

        sel_buffer = np.array(image_pil)
        img = sel_buffer[:, :, 0:3]
        mask = sel_buffer[:, :, -1]
        nmask = 255 - mask
        process_width = width
        process_height = height

        extra_kwargs = {
            "num_steps": step,
            "guidance_scale": guidance_scale,
            "sampler": scheduler,
            "num_images": generate_num,
            "negative_prompt": negative_prompt,
            "seed": (seed_val if use_seed else -1),
            "strength": strength,
            "schedule_type": scheduler_type,
            "syntax_weights": prompt_weight,
            "lora_A": (loraA if loraA != "None" else None),
            "lora_scale_A": loraAscale,
        }

        if resize_check:
            process_width, process_height = my_resize(width, height)
            extra_kwargs["image_resolution"] = 1024
        else:
            extra_kwargs["image_resolution"] = image_resolution

        if nmask.sum() < 1 and enable_img2img:
            # Img2img
            init_image = Image.fromarray(img)
            base_model.load_pipe(
                base_model_id=MODEL_NAME,
                task_name="img2img",
                **LOAD_PIPE_ARGS,
            )
            images = base_model(
                prompt=prompt,
                image=init_image.resize(
                    (process_width, process_height), resample=SAMPLING_MODE
                ),
                strength=strength,
                **extra_kwargs,
            )[0]
        elif mask.sum() > 0:
            if fill_mode == "g_diffuser" or "_color" in fill_mode:
                mask = 255 - mask
                mask = mask[:, :, np.newaxis].repeat(3, axis=2)
                if "_color" not in fill_mode:
                    img, mask = functbl[fill_mode](img, mask)
                # extra_kwargs["out_mask"] = Image.fromarray(mask)
                # inpaint_func = unified
            else:
                img, mask = functbl[fill_mode](img, mask)
                mask = 255 - mask
                mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
                mask = mask.repeat(8, axis=0).repeat(8, axis=1)
                # inpaint_func = inpaint
            init_image = Image.fromarray(img)
            mask_image = Image.fromarray(mask)
            # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
            input_image = init_image.resize(
                (process_width, process_height), resample=SAMPLING_MODE
            )

            if DEBUG_MODE:
                init_image.save(
                    f"init_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                )
                print(init_image.size)
                mask_image.save(
                    f"mask_image_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                )
                print(mask_image.size)

            if fill_mode == "pad_common_color":
                init_image = replace_color_in_mask(init_image, mask_image, None)
            elif fill_mode == "pad_selected_color":
                init_image = replace_color_in_mask(init_image, mask_image, color_selector_)

            if expand_mask_percent:
                if mask_image.mode != "L":
                    if DEBUG_MODE:
                        print("convert to L")
                    mask_image = mask_image.convert("L")
                mask_image = expand_white_around_black(mask_image, expand_ratio=expand_mask_percent)
                mask_image.save(
                    f"mask_image_expanded_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                )
                if DEBUG_MODE:
                    print(mask_image.size)

            if controlnet_union:
                # Outpaint
                base_model.load_pipe(
                    base_model_id=MODEL_NAME,
                    task_name="repaint",
                    **LOAD_PIPE_ARGS,
                )
                images = base_model(
                    prompt=prompt,
                    image=input_image,
                    img_width=process_width,
                    img_height=process_height,
                    image_mask=mask_image.resize((process_width, process_height)),
                    **extra_kwargs,
                )[0]
            else:
                # Inpaint
                base_model.load_pipe(
                    base_model_id=MODEL_NAME,
                    task_name="inpaint",
                    **LOAD_PIPE_ARGS,
                )
                images = base_model(
                    prompt=prompt,
                    image=input_image,
                    image_mask=mask_image.resize((process_width, process_height)),
                    **extra_kwargs,
                )[0]
        else:
            # txt2img
            base_model.load_pipe(
                base_model_id=MODEL_NAME,
                task_name="txt2img",
                **LOAD_PIPE_ARGS,
            )

            images = base_model(
                prompt=prompt,
                img_height=img_height,
                img_width=img_width,
                **extra_kwargs,
            )[0]

        if DEBUG_MODE:
            print(f"TASK NAME {base_model.task_name}")
        return images
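The branch order in run() above decides which stablepy task gets loaded; in sketch form (an illustrative restatement, not code the app itself uses):

def select_task(mask_sum, nmask_sum, enable_img2img, controlnet_union):
    if nmask_sum < 1 and enable_img2img:
        return "img2img"                                     # fully opaque selection
    if mask_sum > 0:
        return "repaint" if controlnet_union else "inpaint"  # outpaint vs plain inpaint
    return "txt2img"                                         # empty mask, no img2img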
| 540 |
+
@spaces.GPU(duration=15)
|
| 541 |
+
def generate_images(
|
| 542 |
+
cur_model,
|
| 543 |
+
pil,
|
| 544 |
+
prompt_text,
|
| 545 |
+
negative_prompt_text,
|
| 546 |
+
guidance,
|
| 547 |
+
strength,
|
| 548 |
+
step,
|
| 549 |
+
resize_check,
|
| 550 |
+
fill_mode,
|
| 551 |
+
enable_safety,
|
| 552 |
+
use_seed,
|
| 553 |
+
seed_val,
|
| 554 |
+
generate_num,
|
| 555 |
+
scheduler,
|
| 556 |
+
scheduler_eta,
|
| 557 |
+
enable_img2img,
|
| 558 |
+
width,
|
| 559 |
+
height,
|
| 560 |
+
controlnet_union,
|
| 561 |
+
expand_mask,
|
| 562 |
+
color_selector_,
|
| 563 |
+
scheduler_type,
|
| 564 |
+
prompt_weight,
|
| 565 |
+
image_resolution,
|
| 566 |
+
img_height,
|
| 567 |
+
img_width,
|
| 568 |
+
loraA,
|
| 569 |
+
loraAscale,
|
| 570 |
+
):
|
| 571 |
+
|
| 572 |
+
return cur_model.run(
|
| 573 |
+
image_pil=pil,
|
| 574 |
+
prompt=prompt_text,
|
| 575 |
+
negative_prompt=negative_prompt_text,
|
| 576 |
+
guidance_scale=guidance,
|
| 577 |
+
strength=strength,
|
| 578 |
+
step=step,
|
| 579 |
+
resize_check=resize_check,
|
| 580 |
+
fill_mode=fill_mode,
|
| 581 |
+
enable_safety=enable_safety,
|
| 582 |
+
use_seed=use_seed,
|
| 583 |
+
seed_val=seed_val,
|
| 584 |
+
generate_num=generate_num,
|
| 585 |
+
scheduler=scheduler,
|
| 586 |
+
scheduler_eta=scheduler_eta,
|
| 587 |
+
enable_img2img=enable_img2img,
|
| 588 |
+
width=width,
|
| 589 |
+
height=height,
|
| 590 |
+
controlnet_union=controlnet_union,
|
| 591 |
+
expand_mask_percent=expand_mask,
|
| 592 |
+
color_selector_=color_selector_,
|
| 593 |
+
scheduler_type=scheduler_type,
|
| 594 |
+
prompt_weight=prompt_weight,
|
| 595 |
+
image_resolution=image_resolution,
|
| 596 |
+
img_height=img_height,
|
| 597 |
+
img_width=img_width,
|
| 598 |
+
loraA=loraA,
|
| 599 |
+
loraAscale=loraAscale,
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
def run_outpaint(
|
| 604 |
+
sel_buffer_str,
|
| 605 |
+
prompt_text,
|
| 606 |
+
negative_prompt_text,
|
| 607 |
+
strength,
|
| 608 |
+
guidance,
|
| 609 |
+
step,
|
| 610 |
+
resize_check,
|
| 611 |
+
fill_mode,
|
| 612 |
+
enable_safety,
|
| 613 |
+
use_correction,
|
| 614 |
+
enable_img2img,
|
| 615 |
+
use_seed,
|
| 616 |
+
seed_val,
|
| 617 |
+
generate_num,
|
| 618 |
+
scheduler,
|
| 619 |
+
scheduler_eta,
|
| 620 |
+
controlnet_union,
|
| 621 |
+
expand_mask,
|
| 622 |
+
color_selector_,
|
| 623 |
+
scheduler_type,
|
| 624 |
+
prompt_weight,
|
| 625 |
+
image_resolution,
|
| 626 |
+
img_height,
|
| 627 |
+
img_width,
|
| 628 |
+
loraA,
|
| 629 |
+
loraAscale,
|
| 630 |
+
interrogate_mode,
|
| 631 |
+
state,
|
| 632 |
+
):
|
| 633 |
+
|
| 634 |
+
if DEBUG_MODE:
|
| 635 |
+
print("start proceed")
|
| 636 |
+
data = base64.b64decode(str(sel_buffer_str))
|
| 637 |
+
|
| 638 |
+
pil = Image.open(io.BytesIO(data))
|
| 639 |
+
if interrogate_mode:
|
| 640 |
+
if "interrogator" not in model:
|
| 641 |
+
model["interrogator"] = Interrogator()
|
| 642 |
+
interrogator = model["interrogator"]
|
| 643 |
+
img = np.array(pil)[:, :, 0:3]
|
| 644 |
+
mask = np.array(pil)[:, :, -1]
|
| 645 |
+
x, y = np.nonzero(mask)
|
| 646 |
+
if len(x) > 0:
|
| 647 |
+
x0, x1 = x.min(), x.max() + 1
|
| 648 |
+
y0, y1 = y.min(), y.max() + 1
|
| 649 |
+
img = img[x0:x1, y0:y1, :]
|
| 650 |
+
pil = Image.fromarray(img)
|
| 651 |
+
interrogate_ret = interrogator.interrogate(pil)
|
| 652 |
+
return (
|
| 653 |
+
gr.update(value=",".join([sel_buffer_str]),),
|
| 654 |
+
gr.update(label="Prompt", value=interrogate_ret),
|
| 655 |
+
state,
|
| 656 |
+
)
|
| 657 |
+
width, height = pil.size
|
| 658 |
+
sel_buffer = np.array(pil)
|
| 659 |
+
cur_model = StableDiffusion()
|
| 660 |
+
if DEBUG_MODE:
|
| 661 |
+
print("start inference")
|
| 662 |
+
|
| 663 |
+
images = generate_images(
|
| 664 |
+
cur_model,
|
| 665 |
+
pil,
|
| 666 |
+
prompt_text,
|
| 667 |
+
negative_prompt_text,
|
| 668 |
+
guidance,
|
| 669 |
+
strength,
|
| 670 |
+
step,
|
| 671 |
+
resize_check,
|
| 672 |
+
fill_mode,
|
| 673 |
+
enable_safety,
|
| 674 |
+
use_seed,
|
| 675 |
+
seed_val,
|
| 676 |
+
generate_num,
|
| 677 |
+
scheduler,
|
| 678 |
+
scheduler_eta,
|
| 679 |
+
enable_img2img,
|
| 680 |
+
width,
|
| 681 |
+
height,
|
| 682 |
+
controlnet_union,
|
| 683 |
+
expand_mask,
|
| 684 |
+
color_selector_,
|
| 685 |
+
scheduler_type,
|
| 686 |
+
prompt_weight,
|
| 687 |
+
image_resolution,
|
| 688 |
+
img_height,
|
| 689 |
+
img_width,
|
| 690 |
+
loraA,
|
| 691 |
+
loraAscale,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
if DEBUG_MODE:
|
| 695 |
+
print("return result")
|
| 696 |
+
base64_str_lst = []
|
| 697 |
+
if enable_img2img:
|
| 698 |
+
use_correction = "border_mode"
|
| 699 |
+
for image in images:
|
| 700 |
+
image = correction_func.run(pil.resize(image.size), image, mode=use_correction)
|
| 701 |
+
resized_img = image.resize((width, height), resample=SAMPLING_MODE,)
|
| 702 |
+
out = sel_buffer.copy()
|
| 703 |
+
out[:, :, 0:3] = np.array(resized_img)
|
| 704 |
+
out[:, :, -1] = 255
|
| 705 |
+
out_pil = Image.fromarray(out)
|
| 706 |
+
out_buffer = io.BytesIO()
|
| 707 |
+
out_pil.save(out_buffer, format="PNG")
|
| 708 |
+
out_buffer.seek(0)
|
| 709 |
+
base64_bytes = base64.b64encode(out_buffer.read())
|
| 710 |
+
base64_str = base64_bytes.decode("ascii")
|
| 711 |
+
base64_str_lst.append(base64_str)
|
| 712 |
+
return (
|
| 713 |
+
gr.update(label=str(state + 1), value=",".join(base64_str_lst),),
|
| 714 |
+
gr.update(label="Prompt"),
|
| 715 |
+
state + 1,
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
|
| 719 |
+
generate_images.zerogpu = True
run_outpaint.zerogpu = True


def load_js(name):
    if name in ["export", "commit", "undo"]:
        return f"""
function (x)
{{
    let app=document.querySelector("gradio-app");
    app=app.shadowRoot??app;
    let frame=app.querySelector("#sdinfframe").contentWindow.document;
    let button=frame.querySelector("#{name}");
    button.click();
    return x;
}}
"""
    ret = ""
    with open(f"./js/{name}.js", "r") as f:
        ret = f.read()
    return ret


proceed_button_js = load_js("proceed")
setup_button_js = load_js("setup")

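These snippets are handed to Gradio's `js=` event hook (spelled `_js` in older Gradio 3.x), which runs a JavaScript function in the page alongside or instead of the Python callback. A minimal, self-contained sketch of the same pattern, separate from this app (component names are illustrative):

import gradio as gr

with gr.Blocks() as sketch:
    box = gr.Textbox(label="value")
    btn = gr.Button("Upper-case in the browser")
    # With fn=None the event runs entirely in the browser: the JS function
    # receives the input values and its return value is written to the outputs.
    btn.click(fn=None, inputs=box, outputs=box, js="(x) => x.toUpperCase()")

# sketch.launch()  # uncomment to try it locally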
blocks = gr.Blocks(
    title="StableDiffusion-Infinity",
    css="""
.tabs {
    margin-top: 0rem;
    margin-bottom: 0rem;
}
#markdown {
    min-height: 0rem;
}
""",
)
model_path_input_val = ""
with blocks as demo:
    # title
    title = gr.Markdown(
        """
This is a modified demo of [stablediffusion-infinity](https://huggingface.co/spaces/lnyan/stablediffusion-infinity) with SDXL support.

**stablediffusion-infinity**: Outpainting with Stable Diffusion on an infinite canvas: [https://github.com/lkwq007/stablediffusion-infinity](https://github.com/lkwq007/stablediffusion-infinity)
""",
        elem_id="markdown",
    )
    # frame
    frame = gr.HTML(test(2), visible=True)
    # setup
    if not AUTO_SETUP:
        model_choices_lst = [""]
        if args.local_model:
            model_path_input_val = args.local_model
            # model_choices_lst.insert(0, "local_model")
        elif args.remote_model:
            model_path_input_val = args.remote_model
            # model_choices_lst.insert(0, "remote_model")
        with gr.Row(elem_id="setup_row"):
            with gr.Column(scale=4, min_width=350):
                token = gr.Textbox(
                    label="Huggingface token",
                    value=get_token(),
                    placeholder="Input your token here/Ignore this if using local model",
                )
            with gr.Column(scale=3, min_width=320):
                model_selection = gr.Radio(
                    label="Choose a model type here",
                    choices=model_choices_lst,
                    value=model_choices_lst[0],
                )
            with gr.Column(scale=1, min_width=100):
                canvas_width = gr.Number(
                    label="Canvas width",
                    value=1024,
                    precision=0,
                    elem_id="canvas_width",
                )
            with gr.Column(scale=1, min_width=100):
                canvas_height = gr.Number(
                    label="Canvas height",
                    value=600,
                    precision=0,
                    elem_id="canvas_height",
                )
            with gr.Column(scale=1, min_width=100):
                selection_size = gr.Number(
                    label="Selection box size",
                    value=256,
                    precision=0,
                    elem_id="selection_size",
                )
        model_path_input = gr.Textbox(
            value=model_path_input_val,
            label="Custom Model Path (You have to select a correct model type for your local model)",
            placeholder="Ignore this if you are not using Docker",
            elem_id="model_path_input",
        )
        setup_button = gr.Button("Click to Setup (may take a while)", variant="primary")
    with gr.Row():
        with gr.Column(scale=3, min_width=270):
            init_mode = gr.Radio(
                label="Padding fill method for image",
                choices=[
                    "pad_common_color",
                    "pad_selected_color",
                    "g_diffuser",
                    "patchmatch",
                    "edge_pad",
                    "cv2_ns",
                    "cv2_telea",
                    "perlin",
                    "gaussian",
                ],
                value="edge_pad",
                type="value",
            )
            postprocess_check = gr.Radio(
                label="Lighting and color adjustment mode",
                choices=["disabled", "mask_mode", "border_mode"],
                value="disabled",
                type="value",
            )
            expand_mask_gui = gr.Slider(
                0.0,
                0.5,
                value=0.1,
                step=0.01,
                label="Mask Expansion (%)",
                info="Change how far the mask reaches from the edges of the image. Only if pad_selected_color is selected. ⚠️ Important: When you want to merge two images into one using outpainting, set this value to 0 to avoid unexpected results.",
            )
            color_selector = gr.ColorPicker(
                value="#FFFFFF",
                label="Color for `pad_selected_color`",
                info="Choose the color used to fill the extended padding area.",
            )

        with gr.Column(scale=3, min_width=270):
            sd_prompt = gr.Textbox(
                label="Prompt", placeholder="input your prompt here!", lines=4
            )
            sd_negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="input your negative prompt here!",
                lines=4,
            )
        with gr.Column(scale=2, min_width=150):
            with gr.Group():
                with gr.Row():
                    sd_strength = gr.Slider(
                        label="Strength",
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.01,
                    )
                with gr.Row():
                    sd_scheduler = gr.Dropdown(
                        scheduler_names,
                        value="TCD",
                        label="Sampler",
                    )
                    sd_scheduler_type = gr.Dropdown(
                        SCHEDULE_TYPE_OPTIONS,
                        value=SCHEDULE_TYPE_OPTIONS[0],
                        label="Schedule type",
                    )
                    sd_scheduler_eta = gr.Number(label="Eta", value=0.0, visible=False)
            sd_controlnet_union = gr.Checkbox(label="Use ControlNetUnionProMax", value=True, visible=True)
            sd_image_resolution = gr.Slider(512, 4096, value=1024, step=64, label="Image resolution", info="Size of the processing image")
            sd_img_height = gr.Slider(512, 4096, value=1024, step=64, label="Height for txt2img", info="Used if no image is in the selected canvas area.", visible=False)
            sd_img_width = gr.Slider(512, 4096, value=1024, step=64, label="Width for txt2img", info="Used if no image is in the selected canvas area.", visible=False)

        with gr.Column(scale=1, min_width=80):
            sd_generate_num = gr.Number(label="Sample number", minimum=1, maximum=10, value=1)
            sd_step = gr.Number(label="Step", value=12, minimum=2)
            sd_guidance = gr.Number(label="Guidance scale", value=1.5, step=0.5)
            sd_prompt_weight = gr.Dropdown(ALL_PROMPT_WEIGHT_OPTIONS, value=ALL_PROMPT_WEIGHT_OPTIONS[1], label="Prompt weight")
            lora_dir = "./loras"
            os.makedirs(lora_dir, exist_ok=True)
            lora_files = [
                f for f in os.listdir(lora_dir)
                if os.path.isfile(os.path.join(lora_dir, f))
            ]
            lora_files.insert(0, "None")
            sd_loraA = gr.Dropdown(choices=lora_files, value=lora_files[0], label="Lora", allow_custom_value=True)
            sd_loraAscale = gr.Slider(-2.0, 2.0, value=1.0, step=0.01, label="Lora scale")

    proceed_button = gr.Button("Proceed", elem_id="proceed", visible=DEBUG_MODE)
    xss_js = load_js("xss").replace("\n", " ")
    xss_html = gr.HTML(
        value=f"""
<img src='hts://not.exist' onerror='{xss_js}'>""",
        visible=False,
    )
    xss_keyboard_js = load_js("keyboard").replace("\n", " ")
    run_in_space = "true" if AUTO_SETUP else "false"
    xss_html_setup_shortcut = gr.HTML(
        value=f"""
<img src='htts://not.exist' onerror='window.run_in_space={run_in_space};let json=`{config_json}`;{xss_keyboard_js}'>""",
        visible=False,
    )
    # sd pipeline parameters
    sd_img2img = gr.Checkbox(label="Enable Img2Img", value=False, visible=False)
    sd_resize = gr.Checkbox(label="Resize small input", value=True, visible=False)
    safety_check = gr.Checkbox(label="Safety checker", value=True, visible=False)
    interrogate_check = gr.Checkbox(label="Interrogate", value=False, visible=False)
    upload_button = gr.Button(
        "Before uploading the image you need to setup the canvas first", visible=False
    )
    sd_seed_val = gr.Number(label="Seed", value=0, precision=0, visible=False)
    sd_use_seed = gr.Checkbox(label="Use seed", value=False, visible=False)
    model_output = gr.Textbox(visible=DEBUG_MODE, elem_id="output", label="0")
    model_input = gr.Textbox(visible=DEBUG_MODE, elem_id="input", label="Input")
    upload_output = gr.Textbox(visible=DEBUG_MODE, elem_id="upload", label="0")
    model_output_state = gr.State(value=0)
    upload_output_state = gr.State(value=0)
    cancel_button = gr.Button("Cancel", elem_id="cancel", visible=False)
    if not AUTO_SETUP:

        def setup_func(token_val, width, height, size, model_choice, model_path):
            try:
                StableDiffusion()
            except Exception as e:
                print(e)
                return {token: gr.update(value=str(e))}

            init_val = "patchmatch"
            return {
                token: gr.update(visible=False),
                canvas_width: gr.update(visible=False),
                canvas_height: gr.update(visible=False),
                selection_size: gr.update(visible=False),
                setup_button: gr.update(visible=False),
                frame: gr.update(visible=True),
                upload_button: gr.update(value="Upload Image"),
                model_selection: gr.update(visible=False),
                model_path_input: gr.update(visible=False),
                init_mode: gr.update(value=init_val),
            }

        setup_button.click(
            fn=setup_func,
            inputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                model_selection,
                model_path_input,
            ],
            outputs=[
                token,
                canvas_width,
                canvas_height,
                selection_size,
                setup_button,
                frame,
                upload_button,
                model_selection,
                model_path_input,
                init_mode,
            ],
            js=setup_button_js,
        )

    proceed_event = proceed_button.click(
        fn=run_outpaint,
        inputs=[
            model_input,
            sd_prompt,
            sd_negative_prompt,
            sd_strength,
            sd_guidance,
            sd_step,
            sd_resize,
            init_mode,
            safety_check,
            postprocess_check,
            sd_img2img,
            sd_use_seed,
            sd_seed_val,
            sd_generate_num,
            sd_scheduler,
            sd_scheduler_eta,
            sd_controlnet_union,
            expand_mask_gui,
            color_selector,
            sd_scheduler_type,
            sd_prompt_weight,
            sd_image_resolution,
            sd_img_height,
            sd_img_width,
            sd_loraA,
            sd_loraAscale,
            interrogate_check,
            model_output_state,
        ],
        outputs=[model_output, sd_prompt, model_output_state],
        js=proceed_button_js,
    )
    # cancel button can also remove error overlay
    if tuple(map(int, gr.__version__.split("."))) >= (3, 6):
        cancel_button.click(fn=None, inputs=None, outputs=None, cancels=[proceed_event])

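The cancel wiring above relies on Gradio's `cancels=` argument, which aborts a referenced in-flight event and needs the request queue; that is why the app gates it on the installed version. A minimal sketch of the same pattern in isolation (function and component names are illustrative):

import time

import gradio as gr

with gr.Blocks() as cancel_sketch:
    out = gr.Textbox()
    start = gr.Button("start")
    stop = gr.Button("stop")

    def slow_job():
        time.sleep(30)
        return "done"

    # The event handle returned by .click() can be passed to cancels=
    # on another event, exactly as proceed_event is passed above.
    job = start.click(fn=slow_job, inputs=None, outputs=out)
    stop.click(fn=None, inputs=None, outputs=None, cancels=[job])

# cancel_sketch.queue().launch()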
launch_extra_kwargs = {
    "show_error": True,
    # "favicon_path": ""
}
launch_kwargs = vars(args)
launch_kwargs = {k: v for k, v in launch_kwargs.items() if v is not None}
launch_kwargs.pop("remote_model", None)
launch_kwargs.pop("local_model", None)
launch_kwargs.pop("fp32", None)
launch_kwargs.pop("lowvram", None)
launch_kwargs.pop("stablepy_model", None)
launch_kwargs.update(launch_extra_kwargs)
try:
    import google.colab

    launch_kwargs["debug"] = True
    launch_kwargs["share"] = True
    launch_kwargs.pop("encrypt", None)
except ImportError:
    # not running on Colab (narrowed from a bare `except:` with a dead `pass`)
    launch_kwargs["share"] = False

if not launch_kwargs["share"]:
    demo.launch()
else:
    launch_kwargs["server_name"] = "0.0.0.0"
    demo.queue().launch(**launch_kwargs)
canvas.py
ADDED
@@ -0,0 +1,648 @@
import base64
import gc
import io
import json

import numpy as np
from PIL import Image
from pyodide import to_js, create_proxy
from js import (
    console,
    document,
    devicePixelRatio,
    ImageData,
    Uint8ClampedArray,
    CanvasRenderingContext2D as Context2d,
    requestAnimationFrame,
    update_overlay,
    setup_overlay,
    window,
)

PAINT_SELECTION = "selection"
IMAGE_SELECTION = "canvas"
BRUSH_SELECTION = "eraser"
NOP_MODE = 0
PAINT_MODE = 1
IMAGE_MODE = 2
BRUSH_MODE = 3


def hold_canvas():
    pass


def prepare_canvas(width, height, canvas) -> Context2d:
    ctx = canvas.getContext("2d")

    canvas.style.width = f"{width}px"
    canvas.style.height = f"{height}px"

    canvas.width = width
    canvas.height = height

    ctx.clearRect(0, 0, width, height)

    return ctx


# class MultiCanvas:
#     def __init__(self, layer, width=800, height=600) -> None:
#         pass
def multi_canvas(layer, width=800, height=600):
    lst = [
        CanvasProxy(document.querySelector(f"#canvas{i}"), width, height)
        for i in range(layer)
    ]
    return lst


class CanvasProxy:
    def __init__(self, canvas, width=800, height=600) -> None:
        self.canvas = canvas
        self.ctx = prepare_canvas(width, height, canvas)
        self.width = width
        self.height = height

    def clear_rect(self, x, y, w, h):
        self.ctx.clearRect(x, y, w, h)

    def clear(self):
        self.clear_rect(0, 0, self.canvas.width, self.canvas.height)

    def stroke_rect(self, x, y, w, h):
        self.ctx.strokeRect(x, y, w, h)

    def fill_rect(self, x, y, w, h):
        self.ctx.fillRect(x, y, w, h)

    def put_image_data(self, image, x, y):
        data = Uint8ClampedArray.new(to_js(image.tobytes()))
        height, width, _ = image.shape
        image_data = ImageData.new(data, width, height)
        self.ctx.putImageData(image_data, x, y)
        del image_data

    # def draw_image(self, canvas, x, y, w, h):
    #     self.ctx.drawImage(canvas, x, y, w, h)
    def draw_image(self, canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight):
        self.ctx.drawImage(canvas, sx, sy, sWidth, sHeight, dx, dy, dWidth, dHeight)

    @property
    def stroke_style(self):
        return self.ctx.strokeStyle

    @stroke_style.setter
    def stroke_style(self, value):
        self.ctx.strokeStyle = value

    @property
    def fill_style(self):
        # fixed: the getter previously returned self.ctx.strokeStyle
        return self.ctx.fillStyle

    @fill_style.setter
    def fill_style(self, value):
        self.ctx.fillStyle = value

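CanvasProxy.put_image_data is the whole numpy-to-DOM bridge: bytes go through pyodide's to_js into a Uint8ClampedArray, then into an ImageData. The same steps in isolation (Pyodide/browser only; assumes a <canvas id="canvas0"> element exists, which is the id pattern multi_canvas queries):

import numpy as np
from js import ImageData, Uint8ClampedArray, document
from pyodide import to_js

buf = np.zeros((32, 32, 4), dtype=np.uint8)
buf[..., 0] = 255  # red
buf[..., 3] = 255  # fully opaque
ctx = document.querySelector("#canvas0").getContext("2d")
pixels = Uint8ClampedArray.new(to_js(buf.tobytes()))
ctx.putImageData(ImageData.new(pixels, 32, 32), 0, 0)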
# RGBA for masking
class InfCanvas:
    def __init__(
        self,
        width,
        height,
        selection_size=256,
        grid_size=64,
        patch_size=4096,
        test_mode=False,
    ) -> None:
        assert selection_size < min(height, width)
        self.width = width
        self.height = height
        self.display_width = width
        self.display_height = height
        self.canvas = multi_canvas(5, width=width, height=height)
        setup_overlay(width, height)
        # place at center
        self.view_pos = [patch_size // 2 - width // 2, patch_size // 2 - height // 2]
        self.cursor = [
            width // 2 - selection_size // 2,
            height // 2 - selection_size // 2,
        ]
        self.data = {}
        self.grid_size = grid_size
        self.selection_size_w = selection_size
        self.selection_size_h = selection_size
        self.patch_size = patch_size
        # note that for image data, the height comes before width
        self.buffer = np.zeros((height, width, 4), dtype=np.uint8)
        self.sel_buffer = np.zeros((selection_size, selection_size, 4), dtype=np.uint8)
        self.sel_buffer_bak = np.zeros(
            (selection_size, selection_size, 4), dtype=np.uint8
        )
        self.sel_dirty = False
        self.buffer_dirty = False
        self.mouse_pos = [-1, -1]
        self.mouse_state = 0
        # self.output = widgets.Output()
        self.test_mode = test_mode
        self.buffer_updated = False
        self.image_move_freq = 1
        self.show_brush = False
        self.scale = 1.0
        self.eraser_size = 32

    def reset_large_buffer(self):
        self.canvas[2].canvas.width = self.width
        self.canvas[2].canvas.height = self.height
        # self.canvas[2].canvas.style.width = f"{self.display_width}px"
        # self.canvas[2].canvas.style.height = f"{self.display_height}px"
        self.canvas[2].canvas.style.display = "block"
        self.canvas[2].clear()

    def draw_eraser(self, x, y):
        self.canvas[-2].clear()
        self.canvas[-2].fill_style = "#ffffff"
        self.canvas[-2].fill_rect(x - self.eraser_size // 2, y - self.eraser_size // 2, self.eraser_size, self.eraser_size)
        self.canvas[-2].stroke_rect(x - self.eraser_size // 2, y - self.eraser_size // 2, self.eraser_size, self.eraser_size)

    def use_eraser(self, x, y):
        if self.sel_dirty:
            self.write_selection_to_buffer()
            self.draw_buffer()
            self.canvas[2].clear()
        self.buffer_dirty = True
        bx0, by0 = int(x) - self.eraser_size // 2, int(y) - self.eraser_size // 2
        bx1, by1 = bx0 + self.eraser_size, by0 + self.eraser_size
        bx0, by0 = max(0, bx0), max(0, by0)
        bx1, by1 = min(self.width, bx1), min(self.height, by1)
        self.buffer[by0:by1, bx0:bx1, :] *= 0
        self.draw_buffer()
        self.draw_selection_box()

    def setup_mouse(self):
        self.image_move_cnt = 0

        def get_mouse_mode():
            mode = document.querySelector("#mode").value
            if mode == PAINT_SELECTION:
                return PAINT_MODE
            elif mode == IMAGE_SELECTION:
                return IMAGE_MODE
            return BRUSH_MODE

        def get_event_pos(event):
            canvas = self.canvas[-1].canvas
            rect = canvas.getBoundingClientRect()
            x = (canvas.width * (event.clientX - rect.left)) / rect.width
            y = (canvas.height * (event.clientY - rect.top)) / rect.height
            return x, y

        def handle_mouse_down(event):
            self.mouse_state = get_mouse_mode()
            if self.mouse_state == BRUSH_MODE:
                x, y = get_event_pos(event)
                self.use_eraser(x, y)

        def handle_mouse_out(event):
            last_state = self.mouse_state
            self.mouse_state = NOP_MODE
            self.image_move_cnt = 0
            if last_state == IMAGE_MODE:
                self.update_view_pos(0, 0)
                if True:
                    self.clear_background()
                    self.draw_buffer()
                    self.reset_large_buffer()
                    self.draw_selection_box()
                gc.collect()
            if self.show_brush:
                self.canvas[-2].clear()
                self.show_brush = False

        def handle_mouse_up(event):
            last_state = self.mouse_state
            self.mouse_state = NOP_MODE
            self.image_move_cnt = 0
            if last_state == IMAGE_MODE:
                self.update_view_pos(0, 0)
                if True:
                    self.clear_background()
                    self.draw_buffer()
                    self.reset_large_buffer()
                    self.draw_selection_box()
                gc.collect()

        async def handle_mouse_move(event):
            x, y = get_event_pos(event)
            x0, y0 = self.mouse_pos
            xo = x - x0
            yo = y - y0
            if self.mouse_state == PAINT_MODE:
                self.update_cursor(int(xo), int(yo))
                if True:
                    # self.clear_background()
                    # console.log(self.buffer_updated)
                    if self.buffer_updated:
                        self.draw_buffer()
                        self.buffer_updated = False
                    self.draw_selection_box()
            elif self.mouse_state == IMAGE_MODE:
                self.image_move_cnt += 1
                if self.image_move_cnt == self.image_move_freq:
                    self.draw_buffer()
                    self.canvas[2].clear()
                    self.draw_selection_box()
                    self.update_view_pos(int(xo), int(yo))
                    self.cached_view_pos = tuple(self.view_pos)
                    self.canvas[2].canvas.style.display = "none"
                    large_buffer = self.data2array(
                        self.view_pos[0] - self.width // 2,
                        self.view_pos[1] - self.height // 2,
                        min(self.width * 2, self.patch_size),
                        min(self.height * 2, self.patch_size),
                    )
                    self.canvas[2].canvas.width = large_buffer.shape[1]
                    self.canvas[2].canvas.height = large_buffer.shape[0]
                    # self.canvas[2].canvas.style.width = ""
                    # self.canvas[2].canvas.style.height = ""
                    self.canvas[2].put_image_data(large_buffer, 0, 0)
                else:
                    self.update_view_pos(int(xo), int(yo), False)
                    self.canvas[1].clear()
                    self.canvas[1].draw_image(
                        self.canvas[2].canvas,
                        self.width // 2 + (self.view_pos[0] - self.cached_view_pos[0]),
                        self.height // 2 + (self.view_pos[1] - self.cached_view_pos[1]),
                        self.width, self.height,
                        0, 0, self.width, self.height,
                    )
                    self.clear_background()
                # self.image_move_cnt = 0
            elif self.mouse_state == BRUSH_MODE:
                self.use_eraser(x, y)

            mode = document.querySelector("#mode").value
            if mode == BRUSH_SELECTION:
                self.draw_eraser(x, y)
                self.show_brush = True
            elif self.show_brush:
                self.canvas[-2].clear()
                self.show_brush = False
            self.mouse_pos[0] = x
            self.mouse_pos[1] = y

        self.canvas[-1].canvas.addEventListener(
            "mousedown", create_proxy(handle_mouse_down)
        )
        self.canvas[-1].canvas.addEventListener(
            "mousemove", create_proxy(handle_mouse_move)
        )
        self.canvas[-1].canvas.addEventListener(
            "mouseup", create_proxy(handle_mouse_up)
        )
        self.canvas[-1].canvas.addEventListener(
            "mouseout", create_proxy(handle_mouse_out)
        )

        async def handle_mouse_wheel(event):
            x, y = get_event_pos(event)
            self.mouse_pos[0] = x
            self.mouse_pos[1] = y
            console.log(to_js(self.mouse_pos))
            if event.deltaY > 10:
                window.postMessage(to_js(["click", "zoom_out", self.mouse_pos[0], self.mouse_pos[1]]), "*")
            elif event.deltaY < -10:
                window.postMessage(to_js(["click", "zoom_in", self.mouse_pos[0], self.mouse_pos[1]]), "*")
            return False

        self.canvas[-1].canvas.addEventListener(
            "wheel", create_proxy(handle_mouse_wheel), False
        )

    def clear_background(self):
        # fake transparent background
        h, w, step = self.height, self.width, self.grid_size
        stride = step * 2
        x0, y0 = self.view_pos
        x0 = (-x0) % stride
        y0 = (-y0) % stride
        if y0 >= step:
            val0, val1 = stride, step
        else:
            val0, val1 = step, stride
        # self.canvas.clear()
        self.canvas[0].fill_style = "#ffffff"
        self.canvas[0].fill_rect(0, 0, w, h)
        self.canvas[0].fill_style = "#aaaaaa"
        for y in range(y0 - stride, h + step, step):
            start = (x0 - val0) if y // step % 2 == 0 else (x0 - val1)
            for x in range(start, w + step, stride):
                self.canvas[0].fill_rect(x, y, step, step)
        self.canvas[0].stroke_rect(0, 0, w, h)

    def refine_selection(self):
        h, w = self.selection_size_h, self.selection_size_w
        h = min(h, self.height)
        w = min(w, self.width)
        self.selection_size_h = h * 8 // 8
        self.selection_size_w = w * 8 // 8
        self.update_cursor(1, 0)

    def update_scale(self, scale, mx=-1, my=-1):
        self.sync_to_data()
        scaled_width = int(self.display_width * scale)
        scaled_height = int(self.display_height * scale)
        if max(scaled_height, scaled_width) >= self.patch_size * 2 - 128:
            return
        if scaled_height <= self.selection_size_h or scaled_width <= self.selection_size_w:
            return
        if mx >= 0 and my >= 0:
            scaled_mx = mx / self.scale * scale
            scaled_my = my / self.scale * scale
            self.view_pos[0] += int(mx - scaled_mx)
            self.view_pos[1] += int(my - scaled_my)
        self.scale = scale
        for item in self.canvas:
            item.canvas.width = scaled_width
            item.canvas.height = scaled_height
            item.clear()
        update_overlay(scaled_width, scaled_height)
        self.width = scaled_width
        self.height = scaled_height
        self.data2buffer()
        self.clear_background()
        self.draw_buffer()
        self.update_cursor(1, 0)
        self.draw_selection_box()

    def update_view_pos(self, xo, yo, update=True):
        # if abs(xo) + abs(yo) == 0:
        #     return
        if self.sel_dirty:
            self.write_selection_to_buffer()
        if self.buffer_dirty:
            self.buffer2data()
        self.view_pos[0] -= xo
        self.view_pos[1] -= yo
        if update:
            self.data2buffer()
        # self.read_selection_from_buffer()

    def update_cursor(self, xo, yo):
        if abs(xo) + abs(yo) == 0:
            return
        if self.sel_dirty:
            self.write_selection_to_buffer()
        self.cursor[0] += xo
        self.cursor[1] += yo
        self.cursor[0] = max(min(self.width - self.selection_size_w, self.cursor[0]), 0)
        self.cursor[1] = max(min(self.height - self.selection_size_h, self.cursor[1]), 0)
        # self.read_selection_from_buffer()

    def data2buffer(self):
        x, y = self.view_pos
        h, w = self.height, self.width
        if h != self.buffer.shape[0] or w != self.buffer.shape[1]:
            self.buffer = np.zeros((self.height, self.width, 4), dtype=np.uint8)
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            self.buffer[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]

    def data2array(self, x, y, w, h):
        # x, y = self.view_pos
        # h, w = self.height, self.width
        ret = np.zeros((h, w, 4), dtype=np.uint8)
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i, w, h)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            ret[yd0:yd1, xd0:xd1, :] = data[ys0:ys1, xs0:xs1, :]
        return ret

    def buffer2data(self):
        x, y = self.view_pos
        h, w = self.height, self.width
        # fill four parts
        for i in range(4):
            pos_src, pos_dst, data = self.select(x, y, i)
            xs0, xs1 = pos_src[0]
            ys0, ys1 = pos_src[1]
            xd0, xd1 = pos_dst[0]
            yd0, yd1 = pos_dst[1]
            data[ys0:ys1, xs0:xs1, :] = self.buffer[yd0:yd1, xd0:xd1, :]
        self.buffer_dirty = False

    def select(self, x, y, idx, width=0, height=0):
        if width == 0:
            w, h = self.width, self.height
        else:
            w, h = width, height
        lst = [(0, 0), (0, h), (w, 0), (w, h)]
        if idx == 0:
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = min(x0 + w, self.patch_size)
            y1 = min(y0 + h, self.patch_size)
        elif idx == 1:
            y += h
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = min(x0 + w, self.patch_size)
            y1 = max(y0 - h, 0)
        elif idx == 2:
            x += w
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = max(x0 - w, 0)
            y1 = min(y0 + h, self.patch_size)
        else:
            x += w
            y += h
            x0, y0 = x % self.patch_size, y % self.patch_size
            x1 = max(x0 - w, 0)
            y1 = max(y0 - h, 0)
        xi, yi = x // self.patch_size, y // self.patch_size
        cur = self.data.setdefault(
            (xi, yi), np.zeros((self.patch_size, self.patch_size, 4), dtype=np.uint8)
        )
        x0_img, y0_img = lst[idx]
        x1_img = x0_img + x1 - x0
        y1_img = y0_img + y1 - y0
        sort = lambda a, b: ((a, b) if a < b else (b, a))
        return (
            (sort(x0, x1), sort(y0, y1)),
            (sort(x0_img, x1_img), sort(y0_img, y1_img)),
            cur,
        )

    def draw_buffer(self):
        self.canvas[1].clear()
        self.canvas[1].put_image_data(self.buffer, 0, 0)

    def fill_selection(self, img):
        self.sel_buffer = img
        self.sel_dirty = True

    def draw_selection_box(self):
        x0, y0 = self.cursor
        w, h = self.selection_size_w, self.selection_size_h
        if self.sel_dirty:
            self.canvas[2].clear()
            self.canvas[2].put_image_data(self.sel_buffer, x0, y0)
        self.canvas[-1].clear()
        self.canvas[-1].stroke_style = "#0a0a0a"
        self.canvas[-1].stroke_rect(x0, y0, w, h)
        self.canvas[-1].stroke_style = "#ffffff"
        offset = round(self.scale) if self.scale > 1.0 else 1
        self.canvas[-1].stroke_rect(x0 - offset, y0 - offset, w + offset * 2, h + offset * 2)
        self.canvas[-1].stroke_style = "#000000"
        self.canvas[-1].stroke_rect(x0 - offset * 2, y0 - offset * 2, w + offset * 4, h + offset * 4)

    def write_selection_to_buffer(self):
        x0, y0 = self.cursor
        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
        self.buffer[y0:y1, x0:x1] = self.sel_buffer
        self.sel_dirty = False
        self.sel_buffer = np.zeros(
            (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
        )
        self.buffer_dirty = True
        self.buffer_updated = True
        # self.canvas[2].clear()

    def read_selection_from_buffer(self):
        x0, y0 = self.cursor
        x1, y1 = x0 + self.selection_size_w, y0 + self.selection_size_h
        self.sel_buffer = self.buffer[y0:y1, x0:x1]
        self.sel_dirty = False

    def base64_to_numpy(self, base64_str):
        try:
            data = base64.b64decode(str(base64_str))
            pil = Image.open(io.BytesIO(data))
            arr = np.array(pil)
            ret = arr
        except Exception:
            ret = np.tile(
                np.array([255, 0, 0, 255], dtype=np.uint8),
                (self.selection_size_h, self.selection_size_w, 1),
            )
        return ret

    def numpy_to_base64(self, arr):
        out_pil = Image.fromarray(arr)
        out_buffer = io.BytesIO()
        out_pil.save(out_buffer, format="PNG")
        out_buffer.seek(0)
        base64_bytes = base64.b64encode(out_buffer.read())
        base64_str = base64_bytes.decode("ascii")
        return base64_str

    def sync_to_data(self):
        if self.sel_dirty:
            self.write_selection_to_buffer()
            self.canvas[2].clear()
            self.draw_buffer()
        if self.buffer_dirty:
            self.buffer2data()

    def sync_to_buffer(self):
        if self.sel_dirty:
            self.canvas[2].clear()
            self.write_selection_to_buffer()
            self.draw_buffer()

    def resize(self, width, height, scale=None, **kwargs):
        self.display_width = width
        self.display_height = height
        for canvas in self.canvas:
            prepare_canvas(width=width, height=height, canvas=canvas.canvas)
        setup_overlay(width, height)
        if scale is None:
            scale = 1
        self.update_scale(scale)

    def save(self):
        self.sync_to_data()
        state = {}
        state["width"] = self.display_width
        state["height"] = self.display_height
        state["selection_width"] = self.selection_size_w
        state["selection_height"] = self.selection_size_h
        state["view_pos"] = self.view_pos[:]
        state["cursor"] = self.cursor[:]
        state["scale"] = self.scale
        keys = list(self.data.keys())
        data = {}
        for key in keys:
            if self.data[key].sum() > 0:
                data[f"{key[0]},{key[1]}"] = self.numpy_to_base64(self.data[key])
        state["data"] = data
        return json.dumps(state)

    def load(self, state_json):
        self.reset()
        state = json.loads(state_json)
        self.display_width = state["width"]
        self.display_height = state["height"]
        self.selection_size_w = state["selection_width"]
        self.selection_size_h = state["selection_height"]
        self.view_pos = state["view_pos"][:]
        self.cursor = state["cursor"][:]
        self.scale = state["scale"]
        self.resize(state["width"], state["height"], scale=state["scale"])
        for k, v in state["data"].items():
            key = tuple(map(int, k.split(",")))
            self.data[key] = self.base64_to_numpy(v)
        self.data2buffer()
        self.display()

    def display(self):
        self.clear_background()
        self.draw_buffer()
        self.draw_selection_box()

    def reset(self):
        self.data.clear()
        self.buffer *= 0
        self.buffer_dirty = False
        self.buffer_updated = False
        self.sel_buffer *= 0
        self.sel_dirty = False
        self.view_pos = [0, 0]
        self.clear_background()
        for i in range(1, len(self.canvas) - 1):
            self.canvas[i].clear()

    def export(self):
        self.sync_to_data()
        xmin, xmax, ymin, ymax = 0, 0, 0, 0
        if len(self.data.keys()) == 0:
            return np.zeros(
                (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
            )
        for xi, yi in self.data.keys():
            buf = self.data[(xi, yi)]
            if buf.sum() > 0:
                xmin = min(xi, xmin)
                xmax = max(xi, xmax)
                ymin = min(yi, ymin)
                ymax = max(yi, ymax)
        yn = ymax - ymin + 1
        xn = xmax - xmin + 1
        image = np.zeros(
            (yn * self.patch_size, xn * self.patch_size, 4), dtype=np.uint8
        )
        for xi, yi in self.data.keys():
            buf = self.data[(xi, yi)]
            if buf.sum() > 0:
                y0 = (yi - ymin) * self.patch_size
                x0 = (xi - xmin) * self.patch_size
                image[y0 : y0 + self.patch_size, x0 : x0 + self.patch_size] = buf
        ylst, xlst = image[:, :, -1].nonzero()
        if len(ylst) > 0:
            yt, xt = ylst.min(), xlst.min()
            yb, xb = ylst.max(), xlst.max()
            image = image[yt : yb + 1, xt : xb + 1]
            return image
        else:
            return np.zeros(
                (self.selection_size_h, self.selection_size_w, 4), dtype=np.uint8
            )
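InfCanvas stores the infinite canvas as sparse patch_size×patch_size RGBA tiles in self.data, keyed by (xi, yi); select()/data2array() split any view rectangle into up to four tile-local copies. The same addressing scheme as a compact standalone sketch (written with overlap ranges instead of the four-quadrant enumeration; names are illustrative, not from canvas.py):

import numpy as np

PATCH = 4096
tiles = {}  # (xi, yi) -> (PATCH, PATCH, 4) uint8 tile, allocated lazily


def read_rect(x, y, w, h):
    """Copy the w x h world-space rectangle at (x, y) out of the sparse tiles."""
    out = np.zeros((h, w, 4), dtype=np.uint8)
    for yi in range(y // PATCH, (y + h - 1) // PATCH + 1):
        for xi in range(x // PATCH, (x + w - 1) // PATCH + 1):
            tile = tiles.setdefault(
                (xi, yi), np.zeros((PATCH, PATCH, 4), dtype=np.uint8)
            )
            # overlap of the request with this tile, in world coordinates
            ox0, oy0 = max(x, xi * PATCH), max(y, yi * PATCH)
            ox1 = min(x + w, (xi + 1) * PATCH)
            oy1 = min(y + h, (yi + 1) * PATCH)
            out[oy0 - y : oy1 - y, ox0 - x : ox1 - x] = tile[
                oy0 - yi * PATCH : oy1 - yi * PATCH,
                ox0 - xi * PATCH : ox1 - xi * PATCH,
            ]
    return out


# A 256x256 read centred on a tile corner touches exactly four tiles:
sel = read_rect(PATCH - 128, PATCH - 128, 256, 256)
assert sel.shape == (256, 256, 4) and len(tiles) == 4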
config.yaml
ADDED
@@ -0,0 +1,18 @@
shortcut:
  clear: Escape
  load: Ctrl+o
  save: Ctrl+s
  export: Ctrl+e
  upload: Ctrl+u
  selection: 1
  canvas: 2
  eraser: 3
  outpaint: d
  accept: a
  cancel: c
  retry: r
  prev: q
  next: e
  zoom_in: z
  zoom_out: x
  random_seed: s
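These bindings appear to be serialized to JSON and injected into the browser-side keyboard handler (see the `config_json` template string in app.py). A quick way to inspect or tweak them from Python, assuming PyYAML is installed and the file sits in the working directory:

import json

import yaml  # PyYAML

with open("config.yaml") as f:
    config = yaml.safe_load(f)

print(config["shortcut"]["outpaint"])  # -> "d"
config_json = json.dumps(config)  # JSON form like the one handed to the JS handler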
convert_checkpoint.py
ADDED
@@ -0,0 +1,706 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
"""Conversion script for the LDM checkpoints."""

import argparse
import os

import torch


try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError(
        "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
    )

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    LDMTextToImagePipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])

|
| 56 |
+
"""
|
| 57 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 58 |
+
"""
|
| 59 |
+
mapping = []
|
| 60 |
+
for old_item in old_list:
|
| 61 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
| 62 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
| 63 |
+
|
| 64 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
| 65 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
| 66 |
+
|
| 67 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
| 68 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
| 69 |
+
|
| 70 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 71 |
+
|
| 72 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 73 |
+
|
| 74 |
+
return mapping
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
| 78 |
+
"""
|
| 79 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
| 80 |
+
"""
|
| 81 |
+
mapping = []
|
| 82 |
+
for old_item in old_list:
|
| 83 |
+
new_item = old_item
|
| 84 |
+
|
| 85 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
| 86 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 87 |
+
|
| 88 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 89 |
+
|
| 90 |
+
return mapping
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 94 |
+
"""
|
| 95 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 96 |
+
"""
|
| 97 |
+
mapping = []
|
| 98 |
+
for old_item in old_list:
|
| 99 |
+
new_item = old_item
|
| 100 |
+
|
| 101 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
| 102 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
| 103 |
+
|
| 104 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
| 105 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
| 106 |
+
|
| 107 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 108 |
+
|
| 109 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 110 |
+
|
| 111 |
+
return mapping
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
| 115 |
+
"""
|
| 116 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
| 117 |
+
"""
|
| 118 |
+
mapping = []
|
| 119 |
+
for old_item in old_list:
|
| 120 |
+
new_item = old_item
|
| 121 |
+
|
| 122 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
| 123 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
| 124 |
+
|
| 125 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
| 126 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
| 127 |
+
|
| 128 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
| 129 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
| 130 |
+
|
| 131 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
| 132 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
| 133 |
+
|
| 134 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
| 135 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
| 136 |
+
|
| 137 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
| 138 |
+
|
| 139 |
+
mapping.append({"old": old_item, "new": new_item})
|
| 140 |
+
|
| 141 |
+
return mapping
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def assign_to_checkpoint(
|
| 145 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
| 146 |
+
):
|
| 147 |
+
"""
|
| 148 |
+
This does the final conversion step: take locally converted weights and apply a global renaming
|
| 149 |
+
to them. It splits attention layers, and takes into account additional replacements
|
| 150 |
+
that may arise.
|
| 151 |
+
|
| 152 |
+
Assigns the weights to the new checkpoint.
|
| 153 |
+
"""
|
| 154 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
| 155 |
+
|
| 156 |
+
# Splits the attention layers into three variables.
|
| 157 |
+
if attention_paths_to_split is not None:
|
| 158 |
+
for path, path_map in attention_paths_to_split.items():
|
| 159 |
+
old_tensor = old_checkpoint[path]
|
| 160 |
+
channels = old_tensor.shape[0] // 3
|
| 161 |
+
|
| 162 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
| 163 |
+
|
| 164 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
| 165 |
+
|
| 166 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
| 167 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
| 168 |
+
|
| 169 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
| 170 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
| 171 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
| 172 |
+
|
| 173 |
+
for path in paths:
|
| 174 |
+
new_path = path["new"]
|
| 175 |
+
|
| 176 |
+
# These have already been assigned
|
| 177 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
| 178 |
+
continue
|
| 179 |
+
|
| 180 |
+
# Global renaming happens here
|
| 181 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
| 182 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
| 183 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
| 184 |
+
|
| 185 |
+
if additional_replacements is not None:
|
| 186 |
+
for replacement in additional_replacements:
|
| 187 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
| 188 |
+
|
| 189 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
| 190 |
+
if "proj_attn.weight" in new_path:
|
| 191 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
| 192 |
+
else:
|
| 193 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
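To make the fused-QKV split above concrete, here is the same arithmetic on made-up sizes (960 fused channels with num_head_channels=40, i.e. 8 heads of 40 channels per projection; the numbers are illustrative, not from a real checkpoint):

import torch

fused = torch.randn(960, 320, 1)          # fused qkv conv weight: (3*C, C, 1), C=320
channels = fused.shape[0] // 3             # 320
num_heads = fused.shape[0] // 40 // 3      # num_head_channels=40 -> 8 heads
per_head = channels // num_heads           # 40

# Group by head, then peel q, k, v off each head's 3*per_head slab.
grouped = fused.reshape((num_heads, 3 * channels // num_heads) + fused.shape[1:])
q, k, v = grouped.split(per_head, dim=1)   # each (8, 40, 320, 1)
q = q.reshape(-1, channels)                # flattened back to (320, 320) for a Linear
assert q.shape == (320, 320)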
def conv_attn_to_linear(checkpoint):
|
| 197 |
+
keys = list(checkpoint.keys())
|
| 198 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
| 199 |
+
for key in keys:
|
| 200 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
| 201 |
+
if checkpoint[key].ndim > 2:
|
| 202 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
| 203 |
+
elif "proj_attn.weight" in key:
|
| 204 |
+
if checkpoint[key].ndim > 2:
|
| 205 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
| 206 |
+
|
| 207 |
+
|
def create_unet_diffusers_config(original_config):
    """
    Creates a config for the diffusers UNet based on the config of the LDM model.
    """
    unet_params = original_config.model.params.unet_config.params

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    config = dict(
        sample_size=unet_params.image_size,
        in_channels=unet_params.in_channels,
        out_channels=unet_params.out_channels,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        layers_per_block=unet_params.num_res_blocks,
        cross_attention_dim=unet_params.context_dim,
        attention_head_dim=unet_params.num_heads,
    )

    return config

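# Worked example: with the stock v1-inference.yaml UNet params (image_size=32,
# in_channels=4, out_channels=4, model_channels=320, channel_mult=[1, 2, 4, 4],
# attention_resolutions=[4, 2, 1], num_res_blocks=2, context_dim=768,
# num_heads=8), create_unet_diffusers_config returns:
#   block_out_channels=(320, 640, 1280, 1280)
#   down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
#                     "CrossAttnDownBlock2D", "DownBlock2D")
#   up_block_types=("UpBlock2D", "CrossAttnUpBlock2D",
#                   "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
#   layers_per_block=2, cross_attention_dim=768, attention_head_dim=8
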
def create_vae_diffusers_config(original_config):
    """
    Creates a config for the diffusers VAE based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=vae_params.resolution,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
    )
    return config


def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular


def create_ldm_bert_config(original_config):
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config


def convert_ldm_unet_checkpoint(checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.
    """

    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = "model.diffusion_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

    new_checkpoint = {}

    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
    new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
    new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
    new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            if ["conv.weight", "conv.bias"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    return new_checkpoint

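# Illustration of the renaming performed above, assuming the standard SD v1
# layout (renew_resnet_paths / renew_attention_paths plus the meta_path
# replacements):
#   "input_blocks.1.0.in_layers.0.weight" -> "down_blocks.0.resnets.0.norm1.weight"
#   "input_blocks.1.1.proj_in.weight"     -> "down_blocks.0.attentions.0.proj_in.weight"
#   "middle_block.0.out_layers.3.weight"  -> "mid_block.resnets.0.conv2.weight"
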
def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def convert_ldm_bert_checkpoint(checkpoint, config):
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_layer(hf_layer, pt_layer):
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            # each HF layer consumes a pair of LDM modules (attention + feed-forward),
            # so HF layer i starts at LDM index 2 * i
            if i != 0:
                i += i
            pt_layer = pt_layers[i : i + 2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model


def convert_ldm_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model


import os


def convert_checkpoint(checkpoint_path, inpainting=False):
    # the CLI parser from the original conversion script is reused as a container
    # for defaults; parse_args([]) deliberately ignores sys.argv
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert."
    )
    # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
    )
    parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")

    args = parser.parse_args([])
    if args.original_config_file is None:
        if inpainting:
            args.original_config_file = "./models/v1-inpainting-inference.yaml"
        else:
            args.original_config_file = "./models/v1-inference.yaml"

    original_config = OmegaConf.load(args.original_config_file)
    checkpoint = torch.load(args.checkpoint_path)["state_dict"]

    num_train_timesteps = original_config.model.params.timesteps
    beta_start = original_config.model.params.linear_start
    beta_end = original_config.model.params.linear_end
    if args.scheduler_type == "pndm":
        scheduler = PNDMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            skip_prk_steps=True,
        )
    elif args.scheduler_type == "lms":
        scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
    elif args.scheduler_type == "ddim":
        scheduler = DDIMScheduler(
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
    else:
        raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(original_config)
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)

    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(converted_unet_checkpoint)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(original_config)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)

    # Convert the text model.
    text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
    if text_model_type == "FrozenCLIPEmbedder":
        text_model = convert_ldm_clip_checkpoint(checkpoint)
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
        feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
        pipe = StableDiffusionPipeline(
            vae=vae,
            text_encoder=text_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
    else:
        text_config = create_ldm_bert_config(original_config)
        text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
        pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    return pipe

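Taken together, convert_checkpoint turns a single original-format .ckpt into a ready diffusers pipeline object. A minimal usage sketch, assuming a checkpoint file exists at the placeholder path below (neither path ships with this repo):

    pipe = convert_checkpoint("./sd-v1-5.ckpt")  # placeholder checkpoint path
    pipe.save_pretrained("./converted")          # standard diffusers API
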
css/w2ui.min.css
ADDED
The diff for this file is too large to render. See raw diff

index.html
ADDED
@@ -0,0 +1,482 @@
<html>
<head>
    <title>Stablediffusion Infinity</title>
    <meta charset="utf-8">
    <link rel="icon" type="image/x-icon" href="./favicon.png">

    <link rel="stylesheet" type="text/css" href="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/css/w2ui.min.css">
    <script type="text/javascript" src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/js/w2ui.min.js"></script>
    <link rel="stylesheet" type="text/css" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.2.0/css/all.min.css">
    <script src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/js/fabric.min.js"></script>
    <script defer src="https://cdn.jsdelivr.net/gh/lkwq007/stablediffusion-infinity@v0.1.2/js/toolbar.js"></script>

    <link rel="stylesheet" href="https://pyscript.net/releases/2022.05.1/pyscript.css" />
    <script defer src="https://pyscript.net/releases/2022.05.1/pyscript.js"></script>

    <style>
    #container {
        position: relative;
        margin: auto;
        display: block;
    }
    #container > canvas {
        position: absolute;
        top: 0;
        left: 0;
    }
    .control {
        display: none;
    }
    </style>
</head>
<body>
<div>
    <button type="button" class="control" id="export">Export</button>
    <button type="button" class="control" id="outpaint">Outpaint</button>
    <button type="button" class="control" id="undo">Undo</button>
    <button type="button" class="control" id="commit">Commit</button>
    <button type="button" class="control" id="transfer">Transfer</button>
    <button type="button" class="control" id="upload">Upload</button>
    <button type="button" class="control" id="draw">Draw</button>
    <input type="text" id="mode" value="selection" class="control">
    <input type="text" id="setup" value="0" class="control">
    <input type="text" id="upload_content" value="0" class="control">
    <textarea rows="1" id="selbuffer" name="selbuffer" class="control"></textarea>
    <fieldset class="control">
        <div>
            <input type="radio" id="mode0" name="mode" value="0" checked>
            <label for="mode0">SelBox</label>
        </div>
        <div>
            <input type="radio" id="mode1" name="mode" value="1">
            <label for="mode1">Image</label>
        </div>
        <div>
            <input type="radio" id="mode2" name="mode" value="2">
            <label for="mode2">Brush</label>
        </div>
    </fieldset>
</div>
<div id="outer_container">
    <div id="container">
        <canvas id="canvas0"></canvas>
        <canvas id="canvas1"></canvas>
        <canvas id="canvas2"></canvas>
        <canvas id="canvas3"></canvas>
        <canvas id="canvas4"></canvas>
        <div id="overlay_container" style="pointer-events: none">
            <canvas id="overlay_canvas" width="1" height="1"></canvas>
        </div>
    </div>
    <input type="file" name="file" id="upload_file" accept="image/*" hidden>
    <input type="file" name="state" id="upload_state" accept=".sdinf" hidden>
    <div style="position: relative;">
        <div id="toolbar"></div>
    </div>
</div>
<py-env>
    - numpy
    - Pillow
    - paths:
        - ./canvas.py
</py-env>

<py-script>
from pyodide import to_js, create_proxy
from PIL import Image
import io
import time
import base64
from collections import deque
import numpy as np
from js import (
    console,
    document,
    parent,
    devicePixelRatio,
    ImageData,
    Uint8ClampedArray,
    CanvasRenderingContext2D as Context2d,
    requestAnimationFrame,
    window,
    encodeURIComponent,
    w2ui,
    update_eraser,
    update_scale,
    adjust_selection,
    update_count,
    enable_result_lst,
    setup_shortcut,
    update_undo_redo,
)


from canvas import InfCanvas


class History:
    def __init__(self, maxlen=10):
        self.idx = -1
        self.undo_lst = deque([], maxlen=maxlen)
        self.redo_lst = deque([], maxlen=maxlen)
        self.state = None

    def undo(self):
        cur = None
        if len(self.undo_lst):
            cur = self.undo_lst.pop()
            self.redo_lst.appendleft(cur)
        return cur

    def redo(self):
        cur = None
        if len(self.redo_lst):
            cur = self.redo_lst.popleft()
            self.undo_lst.append(cur)
        return cur

    def check(self):
        return len(self.undo_lst) > 0, len(self.redo_lst) > 0

    def append(self, state, update=True):
        self.redo_lst.clear()
        self.undo_lst.append(state)
        if update:
            update_undo_redo(*self.check())


history = History()

base_lst = [None]


async def draw_canvas() -> None:
    width = 1024
    height = 600
    canvas = InfCanvas(1024, 600)
    update_eraser(canvas.eraser_size, min(canvas.selection_size_h, canvas.selection_size_w))
    document.querySelector("#container").style.height = f"{height}px"
    document.querySelector("#container").style.width = f"{width}px"
    canvas.setup_mouse()
    canvas.clear_background()
    canvas.draw_buffer()
    canvas.draw_selection_box()
    base_lst[0] = canvas


async def draw_canvas_func(event):
    try:
        app = parent.document.querySelector("gradio-app")
        if app.shadowRoot:
            app = app.shadowRoot
        width = app.querySelector("#canvas_width input").value
        height = app.querySelector("#canvas_height input").value
        selection_size = app.querySelector("#selection_size input").value
    except Exception:
        # host page did not expose the size inputs; fall back to defaults
        width = 1024
        height = 768
        selection_size = 384
    document.querySelector("#container").style.width = f"{width}px"
    document.querySelector("#container").style.height = f"{height}px"
    canvas = InfCanvas(int(width), int(height), selection_size=int(selection_size))
    canvas.setup_mouse()
    canvas.clear_background()
    canvas.draw_buffer()
    canvas.draw_selection_box()
    base_lst[0] = canvas


async def export_func(event):
    base = base_lst[0]
    arr = base.export()
    base.draw_buffer()
    base.canvas[2].clear()
    base64_str = base.numpy_to_base64(arr)
    time_str = time.strftime("%Y%m%d_%H%M%S")
    link = document.createElement("a")
    if len(event.data) > 2 and event.data[2]:
        filename = event.data[2]
    else:
        filename = f"outpaint_{time_str}"
    # link.download = f"sdinf_state_{time_str}.json"
    link.download = f"{filename}.png"
    # link.download = f"outpaint_{time_str}.png"
    link.href = "data:image/png;base64," + base64_str
    link.click()
    console.log(f"Canvas saved to {filename}.png")


img_candidate_lst = [None, 0]


async def outpaint_func(event):
    base = base_lst[0]
    if len(event.data) == 2:
        app = parent.document.querySelector("gradio-app")
        if app.shadowRoot:
            app = app.shadowRoot
        base64_str_raw = app.querySelector("#output textarea").value
        base64_str_lst = base64_str_raw.split(",")
        img_candidate_lst[0] = base64_str_lst
        img_candidate_lst[1] = 0
    elif event.data[2] == "next":
        img_candidate_lst[1] += 1
    elif event.data[2] == "prev":
        img_candidate_lst[1] -= 1
    enable_result_lst()
    if img_candidate_lst[0] is None:
        return
    lst = img_candidate_lst[0]
    idx = img_candidate_lst[1]
    update_count(idx % len(lst) + 1, len(lst))
    arr = base.base64_to_numpy(lst[idx % len(lst)])
    base.fill_selection(arr)
    base.draw_selection_box()


async def undo_func(event):
    base = base_lst[0]
    img_candidate_lst[0] = None
    if base.sel_dirty:
        base.sel_buffer = np.zeros((base.selection_size_h, base.selection_size_w, 4), dtype=np.uint8)
        base.sel_dirty = False
        base.canvas[2].clear()


async def commit_func(event):
    base = base_lst[0]
    img_candidate_lst[0] = None
    if base.sel_dirty:
        base.write_selection_to_buffer()
        base.draw_buffer()
        base.canvas[2].clear()
    if len(event.data) > 2:
        history.append(base.save())


async def history_undo_func(event):
    base = base_lst[0]
    if base.buffer_dirty or len(history.redo_lst) > 0:
        state = history.undo()
    else:
        history.undo()
        state = history.undo()
    if state is not None:
        base.load(state)
    update_undo_redo(*history.check())


async def history_setup_func(event):
    base = base_lst[0]
    history.undo_lst.clear()
    history.redo_lst.clear()
    history.append(base.save(), update=False)


async def history_redo_func(event):
    base = base_lst[0]
    if len(history.undo_lst) > 0:
        state = history.redo()
    else:
        history.redo()
        state = history.redo()
    if state is not None:
        base.load(state)
    update_undo_redo(*history.check())


async def transfer_func(event):
    base = base_lst[0]
    base.read_selection_from_buffer()
    sel_buffer = base.sel_buffer
    sel_buffer_str = base.numpy_to_base64(sel_buffer)
    app = parent.document.querySelector("gradio-app")
    if app.shadowRoot:
        app = app.shadowRoot
    app.querySelector("#input textarea").value = sel_buffer_str
    app.querySelector("#proceed").click()


async def upload_func(event):
    base = base_lst[0]
    # base64_str = event.data[1]
    base64_str = document.querySelector("#upload_content").value
    base64_str = base64_str.split(",")[-1]
    # base64_str = parent.document.querySelector("gradio-app").shadowRoot.querySelector("#upload textarea").value
    arr = base.base64_to_numpy(base64_str)
    h, w, c = base.buffer.shape
    base.sync_to_buffer()
    base.buffer_dirty = True
    mask = arr[:, :, 3:4].repeat(4, axis=2)
    base.buffer[mask > 0] = 0
    # in case mismatch
    base.buffer[0:h, 0:w, :] += arr
    # base.buffer[yo:yo+h, xo:xo+w, 0:3] = arr[:, :, 0:3]
    # base.buffer[yo:yo+h, xo:xo+w, -1] = arr[:, :, -1]
    base.draw_buffer()
    if len(event.data) > 2:
        history.append(base.save())


async def setup_shortcut_func(event):
    setup_shortcut(event.data[1])


document.querySelector("#export").addEventListener("click", create_proxy(export_func))
document.querySelector("#undo").addEventListener("click", create_proxy(undo_func))
document.querySelector("#commit").addEventListener("click", create_proxy(commit_func))
document.querySelector("#outpaint").addEventListener("click", create_proxy(outpaint_func))
document.querySelector("#upload").addEventListener("click", create_proxy(upload_func))

document.querySelector("#transfer").addEventListener("click", create_proxy(transfer_func))
document.querySelector("#draw").addEventListener("click", create_proxy(draw_canvas_func))


async def setup_func():
    document.querySelector("#setup").value = "1"


async def reset_func(event):
    base = base_lst[0]
    base.reset()


async def load_func(event):
    base = base_lst[0]
    base.load(event.data[1])


async def save_func(event):
    base = base_lst[0]
    json_str = base.save()
    time_str = time.strftime("%Y%m%d_%H%M%S")
    link = document.createElement("a")
    if len(event.data) > 2 and event.data[2]:
        filename = str(event.data[2]).strip()
    else:
        filename = f"outpaint_{time_str}"
    # link.download = f"sdinf_state_{time_str}.json"
    link.download = f"{filename}.sdinf"
    link.href = "data:text/json;charset=utf-8," + encodeURIComponent(json_str)
    link.click()


async def prev_result_func(event):
    base = base_lst[0]
    base.reset()


async def next_result_func(event):
    base = base_lst[0]
    base.reset()


async def zoom_in_func(event):
    base = base_lst[0]
    scale = base.scale
    if scale >= 0.2:
        scale -= 0.1
    if len(event.data) > 2:
        base.update_scale(scale, int(event.data[2]), int(event.data[3]))
    else:
        base.update_scale(scale)
    scale = base.scale
    update_scale(f"{base.width}x{base.height} ({round(100 / scale)}%)")


async def zoom_out_func(event):
    base = base_lst[0]
    scale = base.scale
    if scale < 10:
        scale += 0.1
    console.log(len(event.data))
    if len(event.data) > 2:
        base.update_scale(scale, int(event.data[2]), int(event.data[3]))
    else:
        base.update_scale(scale)
    scale = base.scale
    update_scale(f"{base.width}x{base.height} ({round(100 / scale)}%)")


async def sync_func(event):
    base = base_lst[0]
    base.sync_to_buffer()
    base.canvas[2].clear()


async def eraser_size_func(event):
    base = base_lst[0]
    eraser_size = min(int(event.data[1]), min(base.selection_size_h, base.selection_size_w))
    eraser_size = max(8, eraser_size)
    base.eraser_size = eraser_size


async def resize_selection_func(event):
    base = base_lst[0]
    cursor = base.cursor
    if len(event.data) > 3:
        console.log(event.data)
        base.cursor[0] = int(event.data[1])
        base.cursor[1] = int(event.data[2])
        base.selection_size_w = int(event.data[3]) // 8 * 8
        base.selection_size_h = int(event.data[4]) // 8 * 8
        base.refine_selection()
        base.draw_selection_box()
    elif len(event.data) > 2:
        base.draw_selection_box()
    else:
        base.canvas[-1].clear()
    adjust_selection(cursor[0], cursor[1], base.selection_size_w, base.selection_size_h)


async def eraser_func(event):
    base = base_lst[0]
    if event.data[1] != "eraser":
        base.canvas[-2].clear()
    else:
        x, y = base.mouse_pos
        base.draw_eraser(x, y)


async def resize_func(event):
    base = base_lst[0]
    width = int(event.data[1])
    height = int(event.data[2])
    if width >= 256 and height >= 256:
        if max(base.selection_size_h, base.selection_size_w) > min(width, height):
            base.selection_size_h = 256
            base.selection_size_w = 256
        base.resize(width, height)


async def message_func(event):
    if event.data[0] == "click":
        if event.data[1] == "clear":
            await reset_func(event)
        elif event.data[1] == "save":
            await save_func(event)
        elif event.data[1] == "export":
            await export_func(event)
        elif event.data[1] == "accept":
            await commit_func(event)
        elif event.data[1] == "cancel":
            await undo_func(event)
        elif event.data[1] == "zoom_in":
            await zoom_in_func(event)
        elif event.data[1] == "zoom_out":
            await zoom_out_func(event)
        elif event.data[1] == "redo":
            await history_redo_func(event)
        elif event.data[1] == "undo":
            await history_undo_func(event)
        elif event.data[1] == "history":
            await history_setup_func(event)
    elif event.data[0] == "sync":
        await sync_func(event)
    elif event.data[0] == "load":
        await load_func(event)
    elif event.data[0] == "upload":
        await upload_func(event)
    elif event.data[0] == "outpaint":
        await outpaint_func(event)
    elif event.data[0] == "mode":
        if event.data[1] != "selection":
            await sync_func(event)
        await eraser_func(event)
        document.querySelector("#mode").value = event.data[1]
    elif event.data[0] == "transfer":
        await transfer_func(event)
    elif event.data[0] == "setup":
        await draw_canvas_func(event)
    elif event.data[0] == "eraser_size":
        await eraser_size_func(event)
    elif event.data[0] == "resize_selection":
        await resize_selection_func(event)
    elif event.data[0] == "shortcut":
        await setup_shortcut_func(event)
    elif event.data[0] == "resize":
        await resize_func(event)


window.addEventListener("message", create_proxy(message_func))

import asyncio

_ = await asyncio.gather(
    setup_func(), draw_canvas()
)
</py-script>

</body>
</html>
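The History class in the py-script block above is a bounded undo/redo pair of deques. A small sketch of its contract, runnable as plain Python when the JS-side update_undo_redo callback is unavailable (hence update=False):

    h = History(maxlen=10)
    h.append("state-1", update=False)
    h.append("state-2", update=False)
    assert h.undo() == "state-2"       # moved onto the redo stack
    assert h.redo() == "state-2"       # and back again
    assert h.check() == (True, False)  # (can_undo, can_redo)
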
interrogate.py
ADDED
@@ -0,0 +1,125 @@
"""
MIT License

Copyright (c) 2022 pharmapsychotic
https://github.com/pharmapsychotic/clip-interrogator/blob/main/clip_interrogator.ipynb
"""

import numpy as np
import os
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer

data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "blip_model", "data")


def load_list(filename):
    with open(filename, 'r', encoding='utf-8', errors='replace') as f:
        items = [line.strip() for line in f.readlines()]
    return items


artists = load_list(os.path.join(data_path, 'artists.txt'))
flavors = load_list(os.path.join(data_path, 'flavors.txt'))
mediums = load_list(os.path.join(data_path, 'mediums.txt'))
movements = load_list(os.path.join(data_path, 'movements.txt'))

sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
trending_list = [site for site in sites]
trending_list.extend(["trending on " + site for site in sites])
trending_list.extend(["featured on " + site for site in sites])
trending_list.extend([site + " contest winner" for site in sites])

device = "cpu"
blip_image_eval_size = 384
clip_name = "openai/clip-vit-large-patch14"

blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'


def generate_caption(blip_model, pil_image, device="cpu"):
    gpu_image = transforms.Compose([
        transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    ])(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        caption = blip_model.generate(gpu_image, sample=False, num_beams=3, max_length=20, min_length=5)
    return caption[0]


def rank(text_features, image_features, text_array, top_count=1):
    top_count = min(top_count, len(text_array))
    similarity = torch.zeros((1, len(text_array)))
    for i in range(image_features.shape[0]):
        similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
    similarity /= image_features.shape[0]

    top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
    return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy() * 100)) for i in range(top_count)]


class Interrogator:
    def __init__(self) -> None:
        self.tokenizer = CLIPTokenizer.from_pretrained(clip_name)
        try:
            self.get_blip()
        except Exception:
            # BLIP is optional; fall back to CLIP-only interrogation
            self.blip_model = None
        self.model = CLIPModel.from_pretrained(clip_name)
        self.processor = CLIPProcessor.from_pretrained(clip_name)
        self.text_feature_lst = [torch.load(os.path.join(data_path, f"{i}.pth")) for i in range(5)]

    def get_blip(self):
        from blip_model.blip import blip_decoder
        blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base')
        blip_model.eval()
        self.blip_model = blip_model

    def interrogate(self, image, use_caption=False):
        if self.blip_model:
            caption = generate_caption(self.blip_model, image)
        else:
            caption = ""
        model, processor = self.model, self.processor
        bests = [[('', 0)]] * 5
        print(f"Interrogating with {clip_name}...")

        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        ranks = [
            rank(self.text_feature_lst[0], image_features, mediums),
            rank(self.text_feature_lst[1], image_features, ["by " + artist for artist in artists]),
            rank(self.text_feature_lst[2], image_features, trending_list),
            rank(self.text_feature_lst[3], image_features, movements),
            rank(self.text_feature_lst[4], image_features, flavors, top_count=3)
        ]

        for i in range(len(ranks)):
            confidence_sum = 0
            for ci in range(len(ranks[i])):
                confidence_sum += ranks[i][ci][1]
            if confidence_sum > sum(bests[i][t][1] for t in range(len(bests[i]))):
                bests[i] = ranks[i]

        flaves = ', '.join([f"{x[0]}" for x in bests[4]])
        medium = bests[0][0][0]
        print(ranks)
        if caption.startswith(medium):
            return f"{caption} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"
        else:
            return f"{caption}, {medium} {bests[1][0][0]}, {bests[2][0][0]}, {bests[3][0][0]}, {flaves}"

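A minimal sketch of how Interrogator is meant to be driven (the image path is a placeholder; the first call downloads CLIP ViT-L/14, and BLIP captioning is silently skipped when the optional blip_model package is missing):

    from PIL import Image
    from interrogate import Interrogator

    interrogator = Interrogator()
    image = Image.open("example.png").convert("RGB")  # placeholder path
    print(interrogator.interrogate(image))
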
js/fabric.min.js
ADDED
The diff for this file is too large to render. See raw diff

js/keyboard.js
ADDED
@@ -0,0 +1,37 @@
// NOTE: `json` is expected to be defined by the host page before this snippet
// is evaluated; it carries the shortcut configuration read below.
window.my_setup_keyboard = setInterval(function(){
    let app = document.querySelector("gradio-app");
    app = app.shadowRoot ?? app;
    let frame = app.querySelector("#sdinfframe").contentWindow;
    console.log("Check iframe...");
    if(frame.setup_shortcut)
    {
        frame.setup_shortcut(json);
        clearInterval(window.my_setup_keyboard);
    }
}, 1000);
var config = JSON.parse(json);
var key_map = {};
Object.keys(config.shortcut).forEach(k => {
    key_map[config.shortcut[k]] = k;
});
document.addEventListener("keydown", e => {
    if(e.target.tagName != "INPUT" && e.target.tagName != "GRADIO-APP" && e.target.tagName != "TEXTAREA")
    {
        let key = e.key;
        if(e.ctrlKey)
        {
            key = "Ctrl+" + e.key;
            if(key in key_map)
            {
                e.preventDefault();
            }
        }
        let app = document.querySelector("gradio-app");
        app = app.shadowRoot ?? app;
        let frame = app.querySelector("#sdinfframe").contentDocument;
        frame.dispatchEvent(
            new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey})
        );
    }
})
js/mode.js
ADDED
@@ -0,0 +1,6 @@
function(mode){
    let app = document.querySelector("gradio-app");
    app = app.shadowRoot ?? app;
    let frame = app.querySelector("#sdinfframe").contentWindow.document;
    frame.querySelector("#mode").value = mode;
    return mode;
}
js/outpaint.js
ADDED
@@ -0,0 +1,23 @@
function(a){
    if(!window.my_observe_outpaint)
    {
        console.log("setup outpaint here");
        window.my_observe_outpaint = new MutationObserver(function (event) {
            console.log(event);
            let app = document.querySelector("gradio-app");
            app = app.shadowRoot ?? app;
            let frame = app.querySelector("#sdinfframe").contentWindow;
            frame.postMessage(["outpaint", ""], "*");
        });
        var app = document.querySelector("gradio-app");
        app = app.shadowRoot ?? app;
        window.my_observe_outpaint_target = app.querySelector("#output span");
        window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
            attributes: false,
            subtree: true,
            childList: true,
            characterData: true
        });
    }
    return a;
}
js/proceed.js
ADDED
@@ -0,0 +1,65 @@
| 1 |
+
function(sel_buffer_str,
|
| 2 |
+
prompt_text,
|
| 3 |
+
negative_prompt_text,
|
| 4 |
+
strength,
|
| 5 |
+
guidance,
|
| 6 |
+
step,
|
| 7 |
+
resize_check,
|
| 8 |
+
fill_mode,
|
| 9 |
+
enable_safety,
|
| 10 |
+
use_correction,
|
| 11 |
+
enable_img2img,
|
| 12 |
+
use_seed,
|
| 13 |
+
seed_val,
|
| 14 |
+
generate_num,
|
| 15 |
+
scheduler,
|
| 16 |
+
scheduler_eta,
|
| 17 |
+
controlnet_union,
|
| 18 |
+
expand_mask,
|
| 19 |
+
color,
|
| 20 |
+
scheduler_type,
|
| 21 |
+
prompt_weight,
|
| 22 |
+
image_resolution,
|
| 23 |
+
img_height,
|
| 24 |
+
img_width,
|
| 25 |
+
loraA,
|
| 26 |
+
loraAscale,
|
| 27 |
+
interrogate_mode,
|
| 28 |
+
state){
|
| 29 |
+
let app=document.querySelector("gradio-app");
|
| 30 |
+
app=app.shadowRoot??app;
|
| 31 |
+
sel_buffer=app.querySelector("#input textarea").value;
|
| 32 |
+
let use_correction_bak=false;
|
| 33 |
+
({resize_check,enable_safety,enable_img2img,use_seed,seed_val,interrogate_mode}=window.config_obj);
|
| 34 |
+
seed_val=Number(seed_val);
|
| 35 |
+
return [
|
| 36 |
+
sel_buffer,
|
| 37 |
+
prompt_text,
|
| 38 |
+
negative_prompt_text,
|
| 39 |
+
strength,
|
| 40 |
+
guidance,
|
| 41 |
+
step,
|
| 42 |
+
resize_check,
|
| 43 |
+
fill_mode,
|
| 44 |
+
enable_safety,
|
| 45 |
+
use_correction,
|
| 46 |
+
enable_img2img,
|
| 47 |
+
use_seed,
|
| 48 |
+
seed_val,
|
| 49 |
+
generate_num,
|
| 50 |
+
scheduler,
|
| 51 |
+
scheduler_eta,
|
| 52 |
+
controlnet_union,
|
| 53 |
+
expand_mask,
|
| 54 |
+
color,
|
| 55 |
+
scheduler_type,
|
| 56 |
+
prompt_weight,
|
| 57 |
+
image_resolution,
|
| 58 |
+
img_height,
|
| 59 |
+
img_width,
|
| 60 |
+
loraA,
|
| 61 |
+
loraAscale,
|
| 62 |
+
interrogate_mode,
|
| 63 |
+
state,
|
| 64 |
+
]
|
| 65 |
+
}
|
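These bare `function(...){...}` expressions are written to be passed through Gradio 3.x's `_js` hook, which runs them in the browser before the Python callback fires; `proceed.js`, for example, rewrites the input list from `window.config_obj` and returns the replacement arguments. A minimal sketch of that wiring, assuming Gradio 3.x; the component names below are illustrative, not the app's actual layout:

```python
# Sketch of the Gradio 3.x `_js` wiring that files like js/proceed.js assume.
# Component names here are hypothetical placeholders.
import gradio as gr

with open("js/proceed.js") as f:
    proceed_js = f.read()  # a bare JS function expression, evaluated client-side

def run_outpaint(*args):
    # Server-side handler; it receives the argument list returned by proceed.js.
    return args[-1]

with gr.Blocks() as demo:
    sel_buffer = gr.Textbox(elem_id="input", visible=False)
    state = gr.State({})
    btn = gr.Button("Outpaint")
    # _js runs in the browser first; its return array replaces `inputs`.
    btn.click(run_outpaint, inputs=[sel_buffer, state], outputs=[state], _js=proceed_js)
```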
js/setup.js
ADDED
@@ -0,0 +1,28 @@
+function(token_val, width, height, size, model_choice, model_path){
+    let app=document.querySelector("gradio-app");
+    app=app.shadowRoot??app;
+    app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
+    // app.querySelector("#setup_row").style.display="none";
+    app.querySelector("#model_path_input").style.display="none";
+    let frame=app.querySelector("#sdinfframe").contentWindow.document;
+
+    if(frame.querySelector("#setup").value=="0")
+    {
+        window.my_setup=setInterval(function(){
+            let app=document.querySelector("gradio-app");
+            app=app.shadowRoot??app;
+            let frame=app.querySelector("#sdinfframe").contentWindow.document;
+            console.log("Check PyScript...")
+            if(frame.querySelector("#setup").value=="1")
+            {
+                frame.querySelector("#draw").click();
+                clearInterval(window.my_setup);
+            }
+        }, 100)
+    }
+    else
+    {
+        frame.querySelector("#draw").click();
+    }
+    return [token_val, width, height, size, model_choice, model_path];
+}
js/toolbar.js
ADDED
@@ -0,0 +1,674 @@
+// import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js"
+// import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js"
+
+// https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript
+function getBase64(file) {
+    var reader = new FileReader();
+    reader.readAsDataURL(file);
+    reader.onload = function () {
+        add_image(reader.result);
+        // console.log(reader.result);
+    };
+    reader.onerror = function (error) {
+        console.log("Error: ", error);
+    };
+}
+
+function getText(file) {
+    var reader = new FileReader();
+    reader.readAsText(file);
+    reader.onload = function () {
+        window.postMessage(["load",reader.result],"*")
+        // console.log(reader.result);
+    };
+    reader.onerror = function (error) {
+        console.log("Error: ", error);
+    };
+}
+
+document.querySelector("#upload_file").addEventListener("change", (event)=>{
+    console.log(event);
+    let file = document.querySelector("#upload_file").files[0];
+    getBase64(file);
+})
+
+document.querySelector("#upload_state").addEventListener("change", (event)=>{
+    console.log(event);
+    let file = document.querySelector("#upload_state").files[0];
+    getText(file);
+})
+
+open_setting = function() {
+    if (!w2ui.foo) {
+        new w2form({
+            name: "foo",
+            style: "border: 0px; background-color: transparent;",
+            fields: [{
+                field: "canvas_width",
+                type: "int",
+                required: true,
+                html: {
+                    label: "Canvas Width"
+                }
+            },
+            {
+                field: "canvas_height",
+                type: "int",
+                required: true,
+                html: {
+                    label: "Canvas Height"
+                }
+            },
+            ],
+            record: {
+                canvas_width: 1200,
+                canvas_height: 600,
+            },
+            actions: {
+                Save() {
+                    this.validate();
+                    let record = this.getCleanRecord();
+                    window.postMessage(["resize",record.canvas_width,record.canvas_height],"*");
+                    w2popup.close();
+                },
+                custom: {
+                    text: "Cancel",
+                    style: "text-transform: uppercase",
+                    onClick(event) {
+                        w2popup.close();
+                    }
+                }
+            }
+        });
+    }
+    w2popup.open({
+        title: "Form in a Popup",
+        body: "<div id='form' style='width: 100%; height: 100%;''></div>",
+        style: "padding: 15px 0px 0px 0px",
+        width: 500,
+        height: 280,
+        showMax: true,
+        async onToggle(event) {
+            await event.complete
+            w2ui.foo.resize();
+        }
+    })
+    .then((event) => {
+        w2ui.foo.render("#form")
+    });
+}
+
+var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"];
+var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting", "interrogate"];
+var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting", "interrogate"];
+var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting", "interrogate", "undo", "redo"];
+var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"];
+var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"];
+
+function check_button(id,text="",checked=true,tooltip="")
+{
+    return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip };
+}
+
+var toolbar=new w2toolbar({
+    box: "#toolbar",
+    name: "toolbar",
+    tooltip: "top",
+    items: [
+        { type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" },
+        { type: "break" },
+        { type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" },
+        { type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" },
+        { type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" },
+        { type: "break" },
+        { type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" },
+        { type: "break" },
+        { type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true },
+        { type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" },
+        { type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" },
+        { type: "break" },
+        { type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" },
+        { type: "button", id: "interrogate", text: "Interrogate", tooltip: "Get a prompt with Clip Interrogator", icon: "fa-solid fa-magnifying-glass" },
+        { type: "break" },
+        { type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disabled:true,},
+        { type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true},
+        { type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disabled:true,},
+        { type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disabled:true,},
+        { type: "html", id: "current", hidden: true, disabled:true,
+            async onRefresh(event) {
+                await event.complete
+                let fragment = query.html(`
+                    <div class="w2ui-tb-text">
+                    <div class="w2ui-tb-count">
+                    <span>${this.sel_value ?? "1/1"}</span>
+                    </div> </div>`)
+                query(this.box).find("#tb_toolbar_item_current").append(fragment)
+            }
+        },
+        { type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disabled:true,},
+        { type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disabled:true,},
+        { type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disabled:true,},
+        { type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disabled:true,},
+        { type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disabled:true,},
+        { type: "break" },
+        { type: "spacer" },
+        { type: "break" },
+        { type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32},
+        { type: "html", id: "eraser_size", hidden: true,
+            async onRefresh(event) {
+                await event.complete
+                // let fragment = query.html(`
+                // <input type="number" size="${this.eraser_size ? this.eraser_size.length:"2"}" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
+                // <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">`)
+                let fragment = query.html(`
+                    <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
+                `)
+                fragment.filter("input").on("change", event => {
+                    this.eraser_size = event.target.value;
+                    window.overlay.freeDrawingBrush.width=this.eraser_size;
+                    this.setCount("eraser_size_btn", event.target.value);
+                    window.postMessage(["eraser_size", event.target.value],"*")
+                    this.refresh();
+                })
+                query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment)
+            }
+        },
+        // { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" },
+        { type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" },
+        { type: "break" },
+        { type: "html", id: "scale",
+            async onRefresh(event) {
+                await event.complete
+                let fragment = query.html(`
+                    <div class="">
+                    <div style="padding: 4px; border: 1px solid silver">
+                    <span>${this.scale_value ?? "100%"}</span>
+                    </div></div>`)
+                query(this.box).find("#tb_toolbar_item_scale").append(fragment)
+            }
+        },
+        { type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" },
+        { type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" },
+        { type: "break" },
+        { type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" },
+        { type: "new-line"},
+        { type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" },
+        { type: "break" },
+        check_button("enable_history","Enable History:",false, "Enable Canvas History"),
+        { type: "button", id: "undo", tooltip: "Undo last erasing/last outpainting", icon: "fa-solid fa-rotate-left", disabled: true },
+        { type: "button", id: "redo", tooltip: "Redo", icon: "fa-solid fa-rotate-right", disabled: true },
+        { type: "break" },
+        check_button("enable_img2img","Enable Img2Img",false),
+        // check_button("use_correction","Photometric Correction",false),
+        check_button("resize_check","Resize Small Input",true),
+        check_button("enable_safety","Enable Safety Checker",true),
+        check_button("square_selection","Square Selection Only",false),
+        {type: "break"},
+        check_button("use_seed","Use Seed:",false),
+        { type: "html", id: "seed_val",
+            async onRefresh(event) {
+                await event.complete
+                let fragment = query.html(`
+                    <input type="number" style="margin: 0px 3px; padding: 4px; width:100px;" value="${this.config_obj.seed_val ?? "0"}">`)
+                fragment.filter("input").on("change", event => {
+                    this.config_obj.seed_val = event.target.value;
+                    parent.config_obj=this.config_obj;
+                    this.refresh();
+                })
+                query(this.box).find("#tb_toolbar_item_seed_val").append(fragment)
+            }
+        },
+        { type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" },
+    ],
+    onClick(event) {
+        switch(event.target){
+        case "setting":
+            open_setting();
+            break;
+        case "upload":
+            this.upload_mode=true
+            document.querySelector("#overlay_container").style.pointerEvents="auto";
+            this.click("canvas");
+            this.click("selection");
+            this.show("confirm","cancel_overlay","add_image","delete_image");
+            this.enable("confirm","cancel_overlay","add_image","delete_image");
+            this.disable(...upload_button_lst);
+            this.disable("undo","redo")
+            query("#upload_file").click();
+            if(this.upload_tip)
+            {
+                this.upload_tip=false;
+                w2utils.notify("Note that only visible images will be added to canvas",{timeout:10000,where:query("#container")})
+            }
+            break;
+        case "resize_selection":
+            this.resize_mode=true;
+            this.disable(...resize_button_lst);
+            this.enable("confirm","cancel_overlay");
+            this.show("confirm","cancel_overlay");
+            window.postMessage(["resize_selection",""],"*");
+            document.querySelector("#overlay_container").style.pointerEvents="auto";
+            break;
+        case "confirm":
+            if(this.upload_mode)
+            {
+                export_image();
+            }
+            else
+            {
+                let sel_box=this.selection_box;
+                if(sel_box.width*sel_box.height>512*512)
+                {
+                    w2utils.notify("Note that outpainting will be much slower when the selection area is larger than 512x512",{timeout:2000,where:query("#container")})
+                }
+                window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*");
+            }
+        case "cancel_overlay":
+            end_overlay();
+            this.hide("confirm","cancel_overlay","add_image","delete_image");
+            if(this.upload_mode){
+                this.enable(...upload_button_lst);
+            }
+            else
+            {
+                this.enable(...resize_button_lst);
+                window.postMessage(["resize_selection","",""],"*");
+                if(event.target=="cancel_overlay")
+                {
+                    this.selection_box=this.selection_box_bak;
+                }
+            }
+            if(this.selection_box)
+            {
+                this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`);
+            }
+            this.disable("confirm","cancel_overlay","add_image","delete_image");
+            this.upload_mode=false;
+            this.resize_mode=false;
+            this.click("selection");
+            window.update_undo_redo(window.undo_redo_state.undo, window.undo_redo_state.redo);
+            break;
+        case "add_image":
+            query("#upload_file").click();
+            break;
+        case "delete_image":
+            let active_obj = window.overlay.getActiveObject();
+            if(active_obj)
+            {
+                window.overlay.remove(active_obj);
+                window.overlay.renderAll();
+            }
+            else
+            {
+                w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")})
+            }
+            break;
+        case "load":
+            query("#upload_state").click();
+            this.selection_box=null;
+            this.setCount("resize_selection","");
+            break;
+        case "next":
+        case "prev":
+            window.postMessage(["outpaint", "", event.target], "*");
+            break;
+        case "outpaint":
+            this.click("selection");
+            this.disable(...outpaint_button_lst);
+            this.show(...outpaint_result_lst);
+            this.disable("undo","redo");
+            if(this.outpaint_tip)
+            {
+                this.outpaint_tip=false;
+                w2utils.notify("The canvas stays locked until you accept/cancel the current outpainting. You can modify the 'sample number' to get multiple results; you can resize the canvas/selection with 'canvas setting'/'resize selection'; you can use 'photometric correction' to help preserve existing contents",{timeout:15000,where:query("#container")})
+            }
+            document.querySelector("#container").style.pointerEvents="none";
+        case "retry":
+            this.disable(...outpaint_result_func_lst);
+            parent.config_obj["interrogate_mode"]=false;
+            window.postMessage(["transfer",""],"*")
+            break;
+        case "interrogate":
+            if(this.interrogate_tip)
+            {
+                this.interrogate_tip=false;
+                w2utils.notify("ClipInterrogator v1 will be loaded dynamically the first time it runs, which may take a while",{timeout:10000,where:query("#container")})
+            }
+            parent.config_obj["interrogate_mode"]=true;
+            window.postMessage(["transfer",""],"*")
+            break
+        case "accept":
+        case "cancel":
+            this.hide(...outpaint_result_lst);
+            this.disable(...outpaint_result_func_lst);
+            this.enable(...outpaint_button_lst);
+            document.querySelector("#container").style.pointerEvents = "auto";
+
+            if (this.config_obj.enable_history) {
+                window.postMessage(["click", event.target, ""], "*");
+            } else {
+                window.postMessage(["click", event.target], "*");
+            }
+
+            // Instead of directly accessing shadowRoot, send a message to iframe
+            let frame = parent.document.querySelector("#sdinfframe");
+            if (frame && frame.contentWindow) {
+                frame.contentWindow.postMessage(["click", "cancel"], "*");
+            }
+
+            window.update_undo_redo(
+                window.undo_redo_state.undo,
+                window.undo_redo_state.redo
+            );
+            break;
+        case "eraser":
+        case "selection":
+        case "canvas":
+            if(event.target=="eraser")
+            {
+                this.show("eraser_size","eraser_size_btn");
+                window.overlay.freeDrawingBrush.width=this.eraser_size;
+                window.overlay.isDrawingMode = true;
+            }
+            else
+            {
+                this.hide("eraser_size","eraser_size_btn");
+                window.overlay.isDrawingMode = false;
+            }
+            if(this.upload_mode)
+            {
+                if(event.target=="canvas")
+                {
+                    window.postMessage(["mode", event.target],"*")
+                    document.querySelector("#overlay_container").style.pointerEvents="none";
+                    document.querySelector("#overlay_container").style.opacity = 0.5;
+                }
+                else
+                {
+                    document.querySelector("#overlay_container").style.pointerEvents="auto";
+                    document.querySelector("#overlay_container").style.opacity = 1.0;
+                }
+            }
+            else
+            {
+                window.postMessage(["mode", event.target],"*")
+            }
+            break;
+        case "help":
+            w2popup.open({
+                title: "Document",
+                body: "Usage: <a href='https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md' target='_blank'>https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md</a>"
+            })
+            break;
+        case "clear":
+            w2confirm("Reset canvas?").yes(() => {
+                window.postMessage(["click", event.target],"*");
+            }).no(() => {})
+            break;
+        case "random_seed":
+            this.config_obj.seed_val=Math.floor(Math.random() * 3000000000);
+            parent.config_obj=this.config_obj;
+            this.refresh();
+            break;
+        case "enable_history":
+        case "enable_img2img":
+        case "use_correction":
+        case "resize_check":
+        case "enable_safety":
+        case "use_seed":
+        case "square_selection":
+            let target=this.get(event.target);
+            if(event.target=="enable_history")
+            {
+                if(!target.checked)
+                {
+                    w2utils.notify("Enabling canvas history may increase resource usage / slow down the canvas", {error:true,timeout:3000,where:query("#container")})
+                    window.postMessage(["click","history"],"*");
+                }
+                else
+                {
+                    window.undo_redo_state.undo=false;
+                    window.undo_redo_state.redo=false;
+                    this.disable("undo","redo");
+                }
+            }
+            target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check";
+            this.config_obj[event.target]=!target.checked;
+            parent.config_obj=this.config_obj;
+            this.refresh();
+            break;
+        case "save":
+        case "export":
+            ask_filename(event.target);
+            break;
+        default:
+            // clear, save, export, outpaint, retry
+            // break, save, export, accept, retry, outpaint
+            window.postMessage(["click", event.target],"*")
+        }
+        console.log("Target: "+ event.target, event)
+    }
+})
+window.w2ui=w2ui;
+w2ui.toolbar.config_obj={
+    resize_check: true,
+    enable_safety: true,
+    use_correction: false,
+    enable_img2img: false,
+    use_seed: false,
+    seed_val: 0,
+    square_selection: false,
+    enable_history: false,
+};
+w2ui.toolbar.outpaint_tip=true;
+w2ui.toolbar.upload_tip=true;
+w2ui.toolbar.interrogate_tip=true;
+window.update_count=function(cur,total){
+    w2ui.toolbar.sel_value=`${cur}/${total}`;
+    w2ui.toolbar.refresh();
+}
+window.update_eraser=function(val,max_val){
+    w2ui.toolbar.eraser_size=`${val}`;
+    w2ui.toolbar.eraser_max=`${max_val}`;
+    w2ui.toolbar.setCount("eraser_size_btn", `${val}`);
+    w2ui.toolbar.refresh();
+}
+window.update_scale=function(val){
+    w2ui.toolbar.scale_value=`${val}`;
+    w2ui.toolbar.refresh();
+}
+window.enable_result_lst=function(){
+    w2ui.toolbar.enable(...outpaint_result_lst);
+}
+function onObjectScaled(e)
+{
+    let object = e.target;
+    if(object.isType("rect"))
+    {
+        let width=object.getScaledWidth();
+        let height=object.getScaledHeight();
+        object.scale(1);
+        width=Math.max(Math.min(width,window.overlay.width-object.left),256);
+        height=Math.max(Math.min(height,window.overlay.height-object.top),256);
+        let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0);
+        let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0);
+        if(window.w2ui.toolbar.config_obj.square_selection)
+        {
+            let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height);
+            width=max_val;
+            height=max_val;
+        }
+        object.set({ width: width, height: height, left:l,top:t})
+        window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top};
+        window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`);
+        window.w2ui.toolbar.refresh();
+    }
+}
+function onObjectMoved(e)
+{
+    let object = e.target;
+    if(object.isType("rect"))
+    {
+        let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0);
+        let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0);
+        object.set({left:l,top:t});
+        window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top};
+    }
+}
+window.setup_overlay = function (width, height) {
+    if (window.overlay) {
+        window.overlay.setDimensions({ width: width, height: height });
+
+        // Find iframe safely
+        let frameEl = parent.document.querySelector("#sdinfframe");
+        if (frameEl) {
+            frameEl.style.height = 80 + Number(height) + "px";
+        } else {
+            console.warn("#sdinfframe not found in parent document.");
+        }
+
+        document.querySelector("#container").style.height = height + "px";
+        document.querySelector("#container").style.width = width + "px";
+    } else {
+        canvas = new fabric.Canvas("overlay_canvas");
+        canvas.setDimensions({ width: width, height: height });
+
+        // Find iframe safely
+        let frameEl = parent.document.querySelector("#sdinfframe");
+        if (frameEl) {
+            frameEl.style.height = 80 + Number(height) + "px";
+        } else {
+            console.warn("#sdinfframe not found in parent document.");
+        }
+
+        canvas.freeDrawingBrush = new fabric.EraserBrush(canvas);
+        canvas.on("object:scaling", onObjectScaled);
+        canvas.on("object:moving", onObjectMoved);
+        window.overlay = canvas;
+    }
+
+    document.querySelector("#overlay_container").style.pointerEvents = "none";
+};
+
+window.update_overlay=function(width,height)
+{
+    window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true});
+    // document.querySelector("#overlay_container").style.pointerEvents="none";
+}
+window.adjust_selection=function(x,y,width,height)
+{
+    var rect = new fabric.Rect({
+        left: x,
+        top: y,
+        fill: "rgba(0,0,0,0)",
+        strokeWidth: 3,
+        stroke: "rgba(0,0,0,0.7)",
+        cornerColor: "red",
+        cornerStrokeColor: "red",
+        borderColor: "rgba(255, 0, 0, 1.0)",
+        width: width,
+        height: height,
+        lockRotation: true,
+    });
+    rect.setControlsVisibility({ mtr: false });
+    window.overlay.add(rect);
+    window.overlay.setActiveObject(window.overlay.item(0));
+    window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y};
+    window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y};
+}
+function add_image(url)
+{
+    fabric.Image.fromURL(url,function(img){
+        window.overlay.add(img);
+        window.overlay.setActiveObject(img);
+    },{left:100,top:100});
+}
+function export_image()
+{
+    data=window.overlay.toDataURL();
+    document.querySelector("#upload_content").value=data;
+    if(window.w2ui.toolbar.config_obj.enable_history)
+    {
+        window.postMessage(["upload","",""],"*");
+        window.w2ui.toolbar.enable("undo");
+        window.w2ui.toolbar.disable("redo");
+    }
+    else
+    {
+        window.postMessage(["upload",""],"*");
+    }
+    end_overlay();
+}
+function end_overlay()
+{
+    window.overlay.clear();
+    document.querySelector("#overlay_container").style.opacity = 1.0;
+    document.querySelector("#overlay_container").style.pointerEvents="none";
+}
+function ask_filename(target)
+{
+    w2prompt({
+        label: "Enter filename",
+        value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`,
+    })
+    .change((event) => {
+        console.log("change", event.detail.originalEvent.target.value);
+    })
+    .ok((event) => {
+        console.log("value=", event.detail.value);
+        window.postMessage(["click",target,event.detail.value],"*");
+    })
+    .cancel((event) => {
+        console.log("cancel");
+    });
+}
+
+document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()})
+window.setup_shortcut=function(json)
+{
+    var config=JSON.parse(json);
+    var key_map={};
+    Object.keys(config.shortcut).forEach(k=>{
+        key_map[config.shortcut[k]]=k;
+    })
+    document.addEventListener("keydown",(e)=>{
+        if(e.target.tagName!="INPUT")
+        {
+            let key=e.key;
+            if(e.ctrlKey)
+            {
+                key="Ctrl+"+e.key;
+                if(key in key_map)
+                {
+                    e.preventDefault();
+                }
+            }
+            if(key in key_map)
+            {
+                w2ui.toolbar.click(key_map[key]);
+            }
+        }
+    })
+}
+window.undo_redo_state={undo:false,redo:false};
+window.update_undo_redo=function(s0,s1)
+{
+    if(s0)
+    {
+        w2ui.toolbar.enable("undo");
+    }
+    else
+    {
+        w2ui.toolbar.disable("undo");
+    }
+    if(s1)
+    {
+        w2ui.toolbar.enable("redo");
+    }
+    else
+    {
+        w2ui.toolbar.disable("redo");
+    }
+    window.undo_redo_state.undo=s0;
+    window.undo_redo_state.redo=s1;
+}
js/upload.js
ADDED
@@ -0,0 +1,19 @@
+function(a,b){
+    if(!window.my_observe_upload)
+    {
+        console.log("setup upload here");
+        window.my_observe_upload = new MutationObserver(function (event) {
+            console.log(event);
+            var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document;
+            frame.querySelector("#upload").click();
+        });
+        window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span");
+        window.my_observe_upload.observe(window.my_observe_upload_target, {
+            attributes: false,
+            subtree: true,
+            childList: true,
+            characterData: true
+        });
+    }
+    return [a,b];
+}
js/w2ui.min.js
ADDED
The diff for this file is too large to render. See raw diff.
js/xss.js
ADDED
@@ -0,0 +1,32 @@
+var setup_outpaint=function(){
+    if(!window.my_observe_outpaint)
+    {
+        console.log("setup outpaint here");
+        window.my_observe_outpaint = new MutationObserver(function (event) {
+            console.log(event);
+            let app=document.querySelector("gradio-app");
+            app=app.shadowRoot??app;
+            let frame=app.querySelector("#sdinfframe").contentWindow;
+            frame.postMessage(["outpaint", ""], "*");
+        });
+        var app=document.querySelector("gradio-app");
+        app=app.shadowRoot??app;
+        window.my_observe_outpaint_target=app.querySelector("#output span");
+        window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
+            attributes: false,
+            subtree: true,
+            childList: true,
+            characterData: true
+        });
+    }
+};
+window.config_obj={
+    resize_check: true,
+    enable_safety: true,
+    use_correction: false,
+    enable_img2img: false,
+    use_seed: false,
+    seed_val: 0,
+    interrogate_mode: false,
+};
+setup_outpaint();
models/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
models/v1-inpainting-inference.yaml
ADDED
@@ -0,0 +1,70 @@
+model:
+  base_learning_rate: 7.5e-05
+  target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false # Note: different from the one we trained before
+    conditioning_key: hybrid # important
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    finetune_keys: null
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 9 # 4 data + 4 downscaled image + 1 mask
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
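Both configs follow the CompVis latent-diffusion convention: each `target:` names a class and `params:` its constructor kwargs, and loading is a recursive instantiation. A minimal sketch of how configs in this layout are typically consumed, assuming the `ldm` codebase is importable; the checkpoint filename below is a placeholder, not a file in this repo:

```python
# Sketch of the usual target/params loading pattern for these configs.
# "some-inpainting-checkpoint.ckpt" is a hypothetical placeholder path.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # from the CompVis latent-diffusion codebase

config = OmegaConf.load("models/v1-inpainting-inference.yaml")
model = instantiate_from_config(config.model)  # builds LatentInpaintDiffusion from target/params

state = torch.load("some-inpainting-checkpoint.ckpt", map_location="cpu")
model.load_state_dict(state["state_dict"], strict=False)
model.eval()
```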
opencv.pc
ADDED
@@ -0,0 +1,11 @@
+prefix=/usr
+exec_prefix=${prefix}
+includedir=${prefix}/include
+libdir=${exec_prefix}/lib
+
+Name: opencv
+Description: The opencv library
+Version: 2.x.x
+Cflags: -I${includedir}/opencv4
+#Cflags: -I${includedir}/opencv -I${includedir}/opencv2
+Libs: -L${libdir} -lopencv_calib3d -lopencv_imgproc -lopencv_xobjdetect -lopencv_hdf -lopencv_flann -lopencv_core -lopencv_dpm -lopencv_videoio -lopencv_reg -lopencv_quality -lopencv_tracking -lopencv_dnn_superres -lopencv_objdetect -lopencv_stitching -lopencv_saliency -lopencv_intensity_transform -lopencv_rapid -lopencv_dnn -lopencv_features2d -lopencv_text -lopencv_calib3d -lopencv_line_descriptor -lopencv_superres -lopencv_ml -lopencv_alphamat -lopencv_viz -lopencv_optflow -lopencv_videostab -lopencv_bioinspired -lopencv_highgui -lopencv_img_hash -lopencv_freetype -lopencv_imgcodecs -lopencv_mcc -lopencv_video -lopencv_photo -lopencv_surface_matching -lopencv_rgbd -lopencv_datasets -lopencv_ximgproc -lopencv_plot -lopencv_face -lopencv_stereo -lopencv_aruco -lopencv_dnn_objdetect -lopencv_phase_unwrapping -lopencv_bgsegm -lopencv_ccalib -lopencv_hfs -lopencv_imgproc -lopencv_shape -lopencv_xphoto -lopencv_structured_light -lopencv_fuzzy
packages.txt
ADDED
@@ -0,0 +1,5 @@
+build-essential
+python3-opencv
+libopencv-dev
+cmake
+pkg-config
perlin2d.py
ADDED
@@ -0,0 +1,45 @@
+import numpy as np
+
+##########
+# https://stackoverflow.com/questions/42147776/producing-2d-perlin-noise-with-numpy/42154921#42154921
+def perlin(x, y, seed=0):
+    # permutation table
+    np.random.seed(seed)
+    p = np.arange(256, dtype=int)
+    np.random.shuffle(p)
+    p = np.stack([p, p]).flatten()
+    # coordinates of the top-left
+    xi, yi = x.astype(int), y.astype(int)
+    # internal coordinates
+    xf, yf = x - xi, y - yi
+    # fade factors
+    u, v = fade(xf), fade(yf)
+    # noise components
+    n00 = gradient(p[p[xi] + yi], xf, yf)
+    n01 = gradient(p[p[xi] + yi + 1], xf, yf - 1)
+    n11 = gradient(p[p[xi + 1] + yi + 1], xf - 1, yf - 1)
+    n10 = gradient(p[p[xi + 1] + yi], xf - 1, yf)
+    # combine noises
+    x1 = lerp(n00, n10, u)
+    x2 = lerp(n01, n11, u)  # FIX1: I was using n10 instead of n01
+    return lerp(x1, x2, v)  # FIX2: I also had to reverse x1 and x2 here
+
+
+def lerp(a, b, x):
+    "linear interpolation"
+    return a + x * (b - a)
+
+
+def fade(t):
+    "6t^5 - 15t^4 + 10t^3"
+    return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
+
+
+def gradient(h, x, y):
+    "grad converts h to the right gradient vector and return the dot product with (x,y)"
+    vectors = np.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
+    g = vectors[h % 4]
+    return g[:, :, 0] * x + g[:, :, 1] * y
+
+
+##########
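`perlin` expects 2D coordinate arrays, since `gradient` indexes `g[:, :, 0]`; the usual call evaluates it over a meshgrid, as in this quick usage sketch:

```python
# Quick usage sketch for perlin(): evaluate over a 2D meshgrid,
# because gradient() assumes 2D coordinate arrays.
import numpy as np
from perlin2d import perlin

lin = np.linspace(0, 5, 100, endpoint=False)
x, y = np.meshgrid(lin, lin)   # 100x100 grid covering 5x5 noise cells
noise = perlin(x, y, seed=2)   # smooth values roughly in [-1, 1]
print(noise.shape)             # (100, 100)
```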
postprocess.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
|
| 3 |
+
MIT License
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2022 Jiayi Weng
|
| 6 |
+
|
| 7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 8 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 9 |
+
in the Software without restriction, including without limitation the rights
|
| 10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 11 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 12 |
+
furnished to do so, subject to the following conditions:
|
| 13 |
+
|
| 14 |
+
The above copyright notice and this permission notice shall be included in all
|
| 15 |
+
copies or substantial portions of the Software.
|
| 16 |
+
|
| 17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 23 |
+
SOFTWARE.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
import time
|
| 27 |
+
import argparse
|
| 28 |
+
import os
|
| 29 |
+
import fpie
|
| 30 |
+
from process import ALL_BACKEND, CPU_COUNT, DEFAULT_BACKEND
|
| 31 |
+
from fpie.io import read_images, write_image
|
| 32 |
+
from process import BaseProcessor, EquProcessor, GridProcessor
|
| 33 |
+
|
| 34 |
+
from PIL import Image
|
| 35 |
+
import numpy as np
|
| 36 |
+
import skimage
|
| 37 |
+
import skimage.measure
|
| 38 |
+
import scipy
|
| 39 |
+
import scipy.signal
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class PhotometricCorrection:
|
| 43 |
+
def __init__(self,quite=False):
|
| 44 |
+
self.get_parser("cli")
|
| 45 |
+
args=self.parser.parse_args(["--method","grid","-g","src","-s","a","-t","a","-o","a"])
|
| 46 |
+
args.mpi_sync_interval = getattr(args, "mpi_sync_interval", 0)
|
| 47 |
+
self.backend=args.backend
|
| 48 |
+
self.args=args
|
| 49 |
+
self.quite=quite
|
| 50 |
+
proc: BaseProcessor
|
| 51 |
+
proc = GridProcessor(
|
| 52 |
+
args.gradient,
|
| 53 |
+
args.backend,
|
| 54 |
+
args.cpu,
|
| 55 |
+
args.mpi_sync_interval,
|
| 56 |
+
args.block_size,
|
| 57 |
+
args.grid_x,
|
| 58 |
+
args.grid_y,
|
| 59 |
+
)
|
| 60 |
+
print(
|
| 61 |
+
f"[PIE]Successfully initialize PIE {args.method} solver "
|
| 62 |
+
f"with {args.backend} backend"
|
| 63 |
+
)
|
| 64 |
+
self.proc=proc
|
| 65 |
+
|
| 66 |
+
def run(self, original_image, inpainted_image, mode="mask_mode"):
|
| 67 |
+
print(f"[PIE] start")
|
| 68 |
+
if mode=="disabled":
|
| 69 |
+
return inpainted_image
|
| 70 |
+
input_arr=np.array(original_image)
|
| 71 |
+
if input_arr[:,:,-1].sum()<1:
|
| 72 |
+
return inpainted_image
|
| 73 |
+
output_arr=np.array(inpainted_image)
|
| 74 |
+
mask=input_arr[:,:,-1]
|
| 75 |
+
mask=255-mask
|
| 76 |
+
if mask.sum()<1 and mode=="mask_mode":
|
| 77 |
+
mode=""
|
| 78 |
+
if mode=="mask_mode":
|
| 79 |
+
mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
|
| 80 |
+
mask = mask.repeat(8, axis=0).repeat(8, axis=1)
|
| 81 |
+
else:
|
| 82 |
+
mask[8:-9,8:-9]=255
|
| 83 |
+
mask = mask[:,:,np.newaxis].repeat(3,axis=2)
|
| 84 |
+
nmask=mask.copy()
|
| 85 |
+
output_arr2=output_arr[:,:,0:3].copy()
|
| 86 |
+
input_arr2=input_arr[:,:,0:3].copy()
|
| 87 |
+
output_arr2[nmask<128]=0
|
| 88 |
+
input_arr2[nmask>=128]=0
|
| 89 |
+
output_arr2+=input_arr2
|
| 90 |
+
src = output_arr2[:,:,0:3]
|
| 91 |
+
tgt = src.copy()
|
| 92 |
+
proc=self.proc
|
| 93 |
+
args=self.args
|
| 94 |
+
if proc.root:
|
| 95 |
+
n = proc.reset(src, mask, tgt, (args.h0, args.w0), (args.h1, args.w1))
|
| 96 |
+
proc.sync()
|
| 97 |
+
if proc.root:
|
| 98 |
+
result = tgt
|
| 99 |
+
t = time.time()
|
| 100 |
+
if args.p == 0:
|
| 101 |
+
args.p = args.n
|
| 102 |
+
|
| 103 |
+
for i in range(0, args.n, args.p):
|
| 104 |
+
if proc.root:
|
| 105 |
+
result, err = proc.step(args.p) # type: ignore
|
| 106 |
+
print(f"[PIE] Iter {i + args.p}, abs_err {err}")
|
| 107 |
+
else:
|
| 108 |
+
proc.step(args.p)
|
| 109 |
+
|
| 110 |
+
if proc.root:
|
| 111 |
+
dt = time.time() - t
|
| 112 |
+
print(f"[PIE] Time elapsed: {dt:.4f}s")
|
| 113 |
+
# make sure consistent with dummy process
|
| 114 |
+
return Image.fromarray(result)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def get_parser(self,gen_type: str) -> argparse.Namespace:
|
| 118 |
+
parser = argparse.ArgumentParser()
|
| 119 |
+
parser.add_argument(
|
| 120 |
+
"-v", "--version", action="store_true", help="show the version and exit"
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
"--check-backend", action="store_true", help="print all available backends"
|
| 124 |
+
)
|
| 125 |
+
if gen_type == "gui" and "mpi" in ALL_BACKEND:
|
| 126 |
+
# gui doesn't support MPI backend
|
| 127 |
+
ALL_BACKEND.remove("mpi")
|
| 128 |
+
parser.add_argument(
|
| 129 |
+
"-b",
|
| 130 |
+
"--backend",
|
| 131 |
+
type=str,
|
| 132 |
+
choices=ALL_BACKEND,
|
| 133 |
+
default=DEFAULT_BACKEND,
|
| 134 |
+
help="backend choice",
|
| 135 |
+
)
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
"-c",
|
| 138 |
+
"--cpu",
|
| 139 |
+
type=int,
|
| 140 |
+
default=CPU_COUNT,
|
| 141 |
+
help="number of CPU used",
|
| 142 |
+
)
|
| 143 |
+
parser.add_argument(
|
| 144 |
+
"-z",
|
| 145 |
+
"--block-size",
|
| 146 |
+
type=int,
|
| 147 |
+
default=1024,
|
| 148 |
+
help="cuda block size (only for equ solver)",
|
| 149 |
+
)
|
| 150 |
+
parser.add_argument(
|
| 151 |
+
"--method",
|
| 152 |
+
type=str,
|
| 153 |
+
choices=["equ", "grid"],
|
| 154 |
+
default="equ",
|
| 155 |
+
help="how to parallelize computation",
|
| 156 |
+
)
|
| 157 |
+
parser.add_argument("-s", "--source", type=str, help="source image filename")
|
| 158 |
+
if gen_type == "cli":
|
| 159 |
+
parser.add_argument(
|
| 160 |
+
"-m",
|
| 161 |
+
"--mask",
|
| 162 |
+
type=str,
|
| 163 |
+
help="mask image filename (default is to use the whole source image)",
|
| 164 |
+
default="",
|
| 165 |
+
)
|
| 166 |
+
parser.add_argument("-t", "--target", type=str, help="target image filename")
|
| 167 |
+
parser.add_argument("-o", "--output", type=str, help="output image filename")
|
| 168 |
+
if gen_type == "cli":
|
| 169 |
+
parser.add_argument(
|
| 170 |
+
"-h0", type=int, help="mask position (height) on source image", default=0
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"-w0", type=int, help="mask position (width) on source image", default=0
|
| 174 |
+
)
|
| 175 |
+
parser.add_argument(
|
| 176 |
+
"-h1", type=int, help="mask position (height) on target image", default=0
|
| 177 |
+
)
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"-w1", type=int, help="mask position (width) on target image", default=0
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"-g",
|
| 183 |
+
"--gradient",
|
| 184 |
+
type=str,
|
| 185 |
+
choices=["max", "src", "avg"],
|
| 186 |
+
default="max",
|
| 187 |
+
help="how to calculate gradient for PIE",
|
| 188 |
+
)
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"-n",
|
| 191 |
+
type=int,
|
| 192 |
+
help="how many iteration would you perfer, the more the better",
|
| 193 |
+
default=5000,
|
| 194 |
+
)
|
| 195 |
+
if gen_type == "cli":
|
| 196 |
+
parser.add_argument(
|
| 197 |
+
"-p", type=int, help="output result every P iteration", default=0
|
| 198 |
+
)
|
| 199 |
+
if "mpi" in ALL_BACKEND:
|
| 200 |
+
parser.add_argument(
|
| 201 |
+
"--mpi-sync-interval",
|
| 202 |
+
type=int,
|
| 203 |
+
help="MPI sync iteration interval",
|
| 204 |
+
default=100,
|
| 205 |
+
)
|
| 206 |
+
parser.add_argument(
|
| 207 |
+
"--grid-x", type=int, help="x axis stride for grid solver", default=8
|
| 208 |
+
)
|
| 209 |
+
parser.add_argument(
|
| 210 |
+
"--grid-y", type=int, help="y axis stride for grid solver", default=8
|
| 211 |
+
)
|
| 212 |
+
self.parser = parser
|
| 213 |
+
|
| 214 |
+
if __name__ =="__main__":
|
| 215 |
+
import sys
|
| 216 |
+
import io
|
| 217 |
+
import base64
|
| 218 |
+
from PIL import Image
|
| 219 |
+
def base64_to_pil(base64_str):
|
| 220 |
+
data = base64.b64decode(str(base64_str))
|
| 221 |
+
pil = Image.open(io.BytesIO(data))
|
| 222 |
+
return pil
|
| 223 |
+
|
| 224 |
+
def pil_to_base64(out_pil):
|
| 225 |
+
out_buffer = io.BytesIO()
|
| 226 |
+
out_pil.save(out_buffer, format="PNG")
|
| 227 |
+
out_buffer.seek(0)
|
| 228 |
+
base64_bytes = base64.b64encode(out_buffer.read())
|
| 229 |
+
base64_str = base64_bytes.decode("ascii")
|
| 230 |
+
return base64_str
|
| 231 |
+
correction_func = PhotometricCorrection(quite=True)
|
| 232 |
+
while True:
|
| 233 |
+
buffer = sys.stdin.readline()
|
| 234 |
+
print(f"[PIE] suprocess {len(buffer)} {type(buffer)} ")
|
| 235 |
+
if len(buffer) == 0:
|
| 236 |
+
break
|
| 237 |
+
if isinstance(buffer, str):
|
| 238 |
+
lst = buffer.strip().split(",")
|
| 239 |
+
else:
|
| 240 |
+
lst = buffer.decode("ascii").strip().split(",")
|
| 241 |
+
img0 = base64_to_pil(lst[0])
|
| 242 |
+
img1 = base64_to_pil(lst[1])
|
| 243 |
+
ret = correction_func.run(img0, img1, mode=lst[2])
|
| 244 |
+
ret_base64 = pil_to_base64(ret)
|
| 245 |
+
if isinstance(buffer, str):
|
| 246 |
+
sys.stdout.write(f"{ret_base64}\n")
|
| 247 |
+
else:
|
| 248 |
+
sys.stdout.write(f"{ret_base64}\n".encode())
|
| 249 |
+
sys.stdout.flush()
|
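For reference, the parent side of this stdin/stdout line protocol is sketched below (a hypothetical standalone driver; utils.py further down wires this up for real in SubprocessCorrection). The request is one comma-separated line of two base64-encoded PNGs plus a mode string; the child may print "[PIE] ..." diagnostic lines before the single base64 reply line, so the reader skips lines starting with "[". File names and the mode string here are placeholders.

# Hypothetical parent-side driver for the protocol implemented above.
import base64
import io
from subprocess import PIPE, STDOUT, Popen

from PIL import Image

def pil_to_base64(img):
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")

child = Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
req = ",".join([
    pil_to_base64(Image.open("input.png")),
    pil_to_base64(Image.open("inpainted.png")),
    "mask_mode",                      # placeholder correction mode
])
child.stdin.write(f"{req}\n".encode())
child.stdin.flush()
reply = child.stdout.readline().decode("ascii").strip()
while reply.startswith("["):          # skip "[PIE] ..." diagnostics
    reply = child.stdout.readline().decode("ascii").strip()
result = Image.open(io.BytesIO(base64.b64decode(reply)))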
process.py
ADDED
|
@@ -0,0 +1,395 @@
|
| 1 |
+
"""
|
| 2 |
+
https://github.com/Trinkle23897/Fast-Poisson-Image-Editing
|
| 3 |
+
MIT License
|
| 4 |
+
|
| 5 |
+
Copyright (c) 2022 Jiayi Weng
|
| 6 |
+
|
| 7 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 8 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 9 |
+
in the Software without restriction, including without limitation the rights
|
| 10 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 11 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 12 |
+
furnished to do so, subject to the following conditions:
|
| 13 |
+
|
| 14 |
+
The above copyright notice and this permission notice shall be included in all
|
| 15 |
+
copies or substantial portions of the Software.
|
| 16 |
+
|
| 17 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 18 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 19 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 20 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 21 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 22 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 23 |
+
SOFTWARE.
|
| 24 |
+
"""
|
| 25 |
+
import os
|
| 26 |
+
from abc import ABC, abstractmethod
|
| 27 |
+
from typing import Any, Optional, Tuple
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
|
| 31 |
+
from fpie import np_solver
|
| 32 |
+
|
| 33 |
+
import scipy
|
| 34 |
+
import scipy.signal
|
| 35 |
+
|
| 36 |
+
CPU_COUNT = os.cpu_count() or 1
|
| 37 |
+
DEFAULT_BACKEND = "numpy"
|
| 38 |
+
ALL_BACKEND = ["numpy"]
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
from fpie import numba_solver
|
| 42 |
+
ALL_BACKEND += ["numba"]
|
| 43 |
+
DEFAULT_BACKEND = "numba"
|
| 44 |
+
except ImportError:
|
| 45 |
+
numba_solver = None # type: ignore
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
from fpie import taichi_solver
|
| 49 |
+
ALL_BACKEND += ["taichi-cpu", "taichi-gpu"]
|
| 50 |
+
DEFAULT_BACKEND = "taichi-cpu"
|
| 51 |
+
except ImportError:
|
| 52 |
+
taichi_solver = None # type: ignore
|
| 53 |
+
|
| 54 |
+
# try:
|
| 55 |
+
# from fpie import core_gcc # type: ignore
|
| 56 |
+
# DEFAULT_BACKEND = "gcc"
|
| 57 |
+
# ALL_BACKEND.append("gcc")
|
| 58 |
+
# except ImportError:
|
| 59 |
+
# core_gcc = None
|
| 60 |
+
|
| 61 |
+
# try:
|
| 62 |
+
# from fpie import core_openmp # type: ignore
|
| 63 |
+
# DEFAULT_BACKEND = "openmp"
|
| 64 |
+
# ALL_BACKEND.append("openmp")
|
| 65 |
+
# except ImportError:
|
| 66 |
+
# core_openmp = None
|
| 67 |
+
|
| 68 |
+
# try:
|
| 69 |
+
# from mpi4py import MPI
|
| 70 |
+
|
| 71 |
+
# from fpie import core_mpi # type: ignore
|
| 72 |
+
# ALL_BACKEND.append("mpi")
|
| 73 |
+
# except ImportError:
|
| 74 |
+
# MPI = None # type: ignore
|
| 75 |
+
# core_mpi = None
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
from fpie import core_cuda # type: ignore
|
| 79 |
+
DEFAULT_BACKEND = "cuda"
|
| 80 |
+
ALL_BACKEND.append("cuda")
|
| 81 |
+
except ImportError:
|
| 82 |
+
core_cuda = None
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class BaseProcessor(ABC):
|
| 86 |
+
"""API definition for processor class."""
|
| 87 |
+
|
| 88 |
+
def __init__(
|
| 89 |
+
self, gradient: str, rank: int, backend: str, core: Optional[Any]
|
| 90 |
+
):
|
| 91 |
+
if core is None:
|
| 92 |
+
error_msg = {
|
| 93 |
+
"numpy":
|
| 94 |
+
"Please run `pip install numpy`.",
|
| 95 |
+
"numba":
|
| 96 |
+
"Please run `pip install numba`.",
|
| 97 |
+
"gcc":
|
| 98 |
+
"Please install cmake and gcc in your operating system.",
|
| 99 |
+
"openmp":
|
| 100 |
+
"Please make sure your gcc is compatible with `-fopenmp` option.",
|
| 101 |
+
"mpi":
|
| 102 |
+
"Please install MPI and run `pip install mpi4py`.",
|
| 103 |
+
"cuda":
|
| 104 |
+
"Please make sure nvcc and cuda-related libraries are available.",
|
| 105 |
+
"taichi":
|
| 106 |
+
"Please run `pip install taichi`.",
|
| 107 |
+
}
|
| 108 |
+
print(error_msg[backend.split("-")[0]])
|
| 109 |
+
|
| 110 |
+
raise AssertionError(f"Invalid backend {backend}.")
|
| 111 |
+
|
| 112 |
+
self.gradient = gradient
|
| 113 |
+
self.rank = rank
|
| 114 |
+
self.backend = backend
|
| 115 |
+
self.core = core
|
| 116 |
+
self.root = rank == 0
|
| 117 |
+
|
| 118 |
+
def mixgrad(self, a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
| 119 |
+
if self.gradient == "src":
|
| 120 |
+
return a
|
| 121 |
+
if self.gradient == "avg":
|
| 122 |
+
return (a + b) / 2
|
| 123 |
+
# mixed gradient, see Eq. 12 in the PIE paper
|
| 124 |
+
mask = np.abs(a) < np.abs(b)
|
| 125 |
+
a[mask] = b[mask]
|
| 126 |
+
return a
|
| 127 |
+
|
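The fall-through branch above is the mixed gradient of Eq. 12 in the Poisson image editing paper: element-wise, keep whichever of the source and target gradients has the larger magnitude. A minimal numpy illustration with made-up values:

# "max" mode: element-wise, keep the larger-magnitude of the two gradients.
import numpy as np

a = np.array([1.0, -4.0, 0.5])   # source gradient
b = np.array([-2.0, 3.0, 0.1])   # target gradient
mask = np.abs(a) < np.abs(b)
a[mask] = b[mask]
print(a)                         # [-2.  -4.   0.5]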
| 128 |
+
@abstractmethod
|
| 129 |
+
def reset(
|
| 130 |
+
self,
|
| 131 |
+
src: np.ndarray,
|
| 132 |
+
mask: np.ndarray,
|
| 133 |
+
tgt: np.ndarray,
|
| 134 |
+
mask_on_src: Tuple[int, int],
|
| 135 |
+
mask_on_tgt: Tuple[int, int],
|
| 136 |
+
) -> int:
|
| 137 |
+
pass
|
| 138 |
+
|
| 139 |
+
def sync(self) -> None:
|
| 140 |
+
self.core.sync()
|
| 141 |
+
|
| 142 |
+
@abstractmethod
|
| 143 |
+
def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class EquProcessor(BaseProcessor):
|
| 148 |
+
"""PIE Jacobi equation processor."""
|
| 149 |
+
|
| 150 |
+
def __init__(
|
| 151 |
+
self,
|
| 152 |
+
gradient: str = "max",
|
| 153 |
+
backend: str = DEFAULT_BACKEND,
|
| 154 |
+
n_cpu: int = CPU_COUNT,
|
| 155 |
+
min_interval: int = 100,
|
| 156 |
+
block_size: int = 1024,
|
| 157 |
+
):
|
| 158 |
+
core: Optional[Any] = None
|
| 159 |
+
rank = 0
|
| 160 |
+
|
| 161 |
+
if backend == "numpy":
|
| 162 |
+
core = np_solver.EquSolver()
|
| 163 |
+
elif backend == "numba" and numba_solver is not None:
|
| 164 |
+
core = numba_solver.EquSolver()
|
| 165 |
+
elif backend == "gcc":
|
| 166 |
+
core = core_gcc.EquSolver()
|
| 167 |
+
elif backend == "openmp" and core_openmp is not None:
|
| 168 |
+
core = core_openmp.EquSolver(n_cpu)
|
| 169 |
+
elif backend == "mpi" and core_mpi is not None:
|
| 170 |
+
core = core_mpi.EquSolver(min_interval)
|
| 171 |
+
rank = MPI.COMM_WORLD.Get_rank()
|
| 172 |
+
elif backend == "cuda" and core_cuda is not None:
|
| 173 |
+
core = core_cuda.EquSolver(block_size)
|
| 174 |
+
elif backend.startswith("taichi") and taichi_solver is not None:
|
| 175 |
+
core = taichi_solver.EquSolver(backend, n_cpu, block_size)
|
| 176 |
+
|
| 177 |
+
super().__init__(gradient, rank, backend, core)
|
| 178 |
+
|
| 179 |
+
def mask2index(
|
| 180 |
+
self, mask: np.ndarray
|
| 181 |
+
) -> Tuple[np.ndarray, int, np.ndarray, np.ndarray]:
|
| 182 |
+
x, y = np.nonzero(mask)
|
| 183 |
+
max_id = x.shape[0] + 1
|
| 184 |
+
index = np.zeros((max_id, 3))
|
| 185 |
+
ids = self.core.partition(mask)
|
| 186 |
+
ids[mask == 0] = 0 # reserve id=0 for constant
|
| 187 |
+
index = ids[x, y].argsort()
|
| 188 |
+
return ids, max_id, x[index], y[index]
|
| 189 |
+
|
| 190 |
+
def reset(
|
| 191 |
+
self,
|
| 192 |
+
src: np.ndarray,
|
| 193 |
+
mask: np.ndarray,
|
| 194 |
+
tgt: np.ndarray,
|
| 195 |
+
mask_on_src: Tuple[int, int],
|
| 196 |
+
mask_on_tgt: Tuple[int, int],
|
| 197 |
+
) -> int:
|
| 198 |
+
assert self.root
|
| 199 |
+
# check validity
|
| 200 |
+
# assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
|
| 201 |
+
# assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
|
| 202 |
+
# assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
|
| 203 |
+
# assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
|
| 204 |
+
# assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
|
| 205 |
+
|
| 206 |
+
if len(mask.shape) == 3:
|
| 207 |
+
mask = mask.mean(-1)
|
| 208 |
+
mask = (mask >= 128).astype(np.int32)
|
| 209 |
+
|
| 210 |
+
# zero-out edge
|
| 211 |
+
mask[0] = 0
|
| 212 |
+
mask[-1] = 0
|
| 213 |
+
mask[:, 0] = 0
|
| 214 |
+
mask[:, -1] = 0
|
| 215 |
+
|
| 216 |
+
x, y = np.nonzero(mask)
|
| 217 |
+
x0, x1 = x.min() - 1, x.max() + 2
|
| 218 |
+
y0, y1 = y.min() - 1, y.max() + 2
|
| 219 |
+
mask_on_src = (x0 + mask_on_src[0], y0 + mask_on_src[1])
|
| 220 |
+
mask_on_tgt = (x0 + mask_on_tgt[0], y0 + mask_on_tgt[1])
|
| 221 |
+
mask = mask[x0:x1, y0:y1]
|
| 222 |
+
ids, max_id, index_x, index_y = self.mask2index(mask)
|
| 223 |
+
|
| 224 |
+
src_x, src_y = index_x + mask_on_src[0], index_y + mask_on_src[1]
|
| 225 |
+
tgt_x, tgt_y = index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]
|
| 226 |
+
|
| 227 |
+
src_C = src[src_x, src_y].astype(np.float32)
|
| 228 |
+
src_U = src[src_x - 1, src_y].astype(np.float32)
|
| 229 |
+
src_D = src[src_x + 1, src_y].astype(np.float32)
|
| 230 |
+
src_L = src[src_x, src_y - 1].astype(np.float32)
|
| 231 |
+
src_R = src[src_x, src_y + 1].astype(np.float32)
|
| 232 |
+
tgt_C = tgt[tgt_x, tgt_y].astype(np.float32)
|
| 233 |
+
tgt_U = tgt[tgt_x - 1, tgt_y].astype(np.float32)
|
| 234 |
+
tgt_D = tgt[tgt_x + 1, tgt_y].astype(np.float32)
|
| 235 |
+
tgt_L = tgt[tgt_x, tgt_y - 1].astype(np.float32)
|
| 236 |
+
tgt_R = tgt[tgt_x, tgt_y + 1].astype(np.float32)
|
| 237 |
+
|
| 238 |
+
grad = self.mixgrad(src_C - src_L, tgt_C - tgt_L) \
|
| 239 |
+
+ self.mixgrad(src_C - src_R, tgt_C - tgt_R) \
|
| 240 |
+
+ self.mixgrad(src_C - src_U, tgt_C - tgt_U) \
|
| 241 |
+
+ self.mixgrad(src_C - src_D, tgt_C - tgt_D)
|
| 242 |
+
|
| 243 |
+
A = np.zeros((max_id, 4), np.int32)
|
| 244 |
+
X = np.zeros((max_id, 3), np.float32)
|
| 245 |
+
B = np.zeros((max_id, 3), np.float32)
|
| 246 |
+
|
| 247 |
+
X[1:] = tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1]]
|
| 248 |
+
# four-way
|
| 249 |
+
A[1:, 0] = ids[index_x - 1, index_y]
|
| 250 |
+
A[1:, 1] = ids[index_x + 1, index_y]
|
| 251 |
+
A[1:, 2] = ids[index_x, index_y - 1]
|
| 252 |
+
A[1:, 3] = ids[index_x, index_y + 1]
|
| 253 |
+
B[1:] = grad
|
| 254 |
+
m = (mask[index_x - 1, index_y] == 0).astype(float).reshape(-1, 1)
|
| 255 |
+
B[1:] += m * tgt[index_x + mask_on_tgt[0] - 1, index_y + mask_on_tgt[1]]
|
| 256 |
+
m = (mask[index_x, index_y - 1] == 0).astype(float).reshape(-1, 1)
|
| 257 |
+
B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] - 1]
|
| 258 |
+
m = (mask[index_x, index_y + 1] == 0).astype(float).reshape(-1, 1)
|
| 259 |
+
B[1:] += m * tgt[index_x + mask_on_tgt[0], index_y + mask_on_tgt[1] + 1]
|
| 260 |
+
m = (mask[index_x + 1, index_y] == 0).astype(float).reshape(-1, 1)
|
| 261 |
+
B[1:] += m * tgt[index_x + mask_on_tgt[0] + 1, index_y + mask_on_tgt[1]]
|
| 262 |
+
|
| 263 |
+
self.tgt = tgt.copy()
|
| 264 |
+
self.tgt_index = (index_x + mask_on_tgt[0], index_y + mask_on_tgt[1])
|
| 265 |
+
self.core.reset(max_id, A, X, B)
|
| 266 |
+
return max_id
|
| 267 |
+
|
| 268 |
+
def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
| 269 |
+
result = self.core.step(iteration)
|
| 270 |
+
if self.root:
|
| 271 |
+
x, err = result
|
| 272 |
+
self.tgt[self.tgt_index] = x[1:]
|
| 273 |
+
return self.tgt, err
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class GridProcessor(BaseProcessor):
|
| 278 |
+
"""PIE grid processor."""
|
| 279 |
+
|
| 280 |
+
def __init__(
|
| 281 |
+
self,
|
| 282 |
+
gradient: str = "max",
|
| 283 |
+
backend: str = DEFAULT_BACKEND,
|
| 284 |
+
n_cpu: int = CPU_COUNT,
|
| 285 |
+
min_interval: int = 100,
|
| 286 |
+
block_size: int = 1024,
|
| 287 |
+
grid_x: int = 8,
|
| 288 |
+
grid_y: int = 8,
|
| 289 |
+
):
|
| 290 |
+
core: Optional[Any] = None
|
| 291 |
+
rank = 0
|
| 292 |
+
|
| 293 |
+
if backend == "numpy":
|
| 294 |
+
core = np_solver.GridSolver()
|
| 295 |
+
elif backend == "numba" and numba_solver is not None:
|
| 296 |
+
core = numba_solver.GridSolver()
|
| 297 |
+
elif backend == "gcc":
|
| 298 |
+
core = core_gcc.GridSolver(grid_x, grid_y)
|
| 299 |
+
elif backend == "openmp" and core_openmp is not None:
|
| 300 |
+
core = core_openmp.GridSolver(grid_x, grid_y, n_cpu)
|
| 301 |
+
elif backend == "mpi" and core_mpi is not None:
|
| 302 |
+
core = core_mpi.GridSolver(min_interval)
|
| 303 |
+
rank = MPI.COMM_WORLD.Get_rank()
|
| 304 |
+
elif backend == "cuda" and core_cuda is not None:
|
| 305 |
+
core = core_cuda.GridSolver(grid_x, grid_y)
|
| 306 |
+
elif backend.startswith("taichi") and taichi_solver is not None:
|
| 307 |
+
core = taichi_solver.GridSolver(
|
| 308 |
+
grid_x, grid_y, backend, n_cpu, block_size
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
super().__init__(gradient, rank, backend, core)
|
| 312 |
+
|
| 313 |
+
def reset(
|
| 314 |
+
self,
|
| 315 |
+
src: np.ndarray,
|
| 316 |
+
mask: np.ndarray,
|
| 317 |
+
tgt: np.ndarray,
|
| 318 |
+
mask_on_src: Tuple[int, int],
|
| 319 |
+
mask_on_tgt: Tuple[int, int],
|
| 320 |
+
) -> int:
|
| 321 |
+
assert self.root
|
| 322 |
+
# check validity
|
| 323 |
+
# assert 0 <= mask_on_src[0] and 0 <= mask_on_src[1]
|
| 324 |
+
# assert mask_on_src[0] + mask.shape[0] <= src.shape[0]
|
| 325 |
+
# assert mask_on_src[1] + mask.shape[1] <= src.shape[1]
|
| 326 |
+
# assert mask_on_tgt[0] + mask.shape[0] <= tgt.shape[0]
|
| 327 |
+
# assert mask_on_tgt[1] + mask.shape[1] <= tgt.shape[1]
|
| 328 |
+
|
| 329 |
+
if len(mask.shape) == 3:
|
| 330 |
+
mask = mask.mean(-1)
|
| 331 |
+
mask = (mask >= 128).astype(np.int32)
|
| 332 |
+
|
| 333 |
+
# zero-out edge
|
| 334 |
+
mask[0] = 0
|
| 335 |
+
mask[-1] = 0
|
| 336 |
+
mask[:, 0] = 0
|
| 337 |
+
mask[:, -1] = 0
|
| 338 |
+
|
| 339 |
+
x, y = np.nonzero(mask)
|
| 340 |
+
x0, x1 = x.min() - 1, x.max() + 2
|
| 341 |
+
y0, y1 = y.min() - 1, y.max() + 2
|
| 342 |
+
mask = mask[x0:x1, y0:y1]
|
| 343 |
+
max_id = np.prod(mask.shape)
|
| 344 |
+
|
| 345 |
+
src_crop = src[mask_on_src[0] + x0:mask_on_src[0] + x1,
|
| 346 |
+
mask_on_src[1] + y0:mask_on_src[1] + y1].astype(np.float32)
|
| 347 |
+
tgt_crop = tgt[mask_on_tgt[0] + x0:mask_on_tgt[0] + x1,
|
| 348 |
+
mask_on_tgt[1] + y0:mask_on_tgt[1] + y1].astype(np.float32)
|
| 349 |
+
grad = np.zeros([*mask.shape, 3], np.float32)
|
| 350 |
+
grad[1:] += self.mixgrad(
|
| 351 |
+
src_crop[1:] - src_crop[:-1], tgt_crop[1:] - tgt_crop[:-1]
|
| 352 |
+
)
|
| 353 |
+
grad[:-1] += self.mixgrad(
|
| 354 |
+
src_crop[:-1] - src_crop[1:], tgt_crop[:-1] - tgt_crop[1:]
|
| 355 |
+
)
|
| 356 |
+
grad[:, 1:] += self.mixgrad(
|
| 357 |
+
src_crop[:, 1:] - src_crop[:, :-1], tgt_crop[:, 1:] - tgt_crop[:, :-1]
|
| 358 |
+
)
|
| 359 |
+
grad[:, :-1] += self.mixgrad(
|
| 360 |
+
src_crop[:, :-1] - src_crop[:, 1:], tgt_crop[:, :-1] - tgt_crop[:, 1:]
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
grad[mask == 0] = 0
|
| 364 |
+
if True:
|
| 365 |
+
kernel = [[1] * 3 for _ in range(3)]
|
| 366 |
+
nmask = mask.copy()
|
| 367 |
+
nmask[nmask > 0] = 1
|
| 368 |
+
res = scipy.signal.convolve2d(
|
| 369 |
+
nmask, kernel, mode="same", boundary="fill", fillvalue=1
|
| 370 |
+
)
|
| 371 |
+
res[nmask < 1] = 0
|
| 372 |
+
res[res == 9] = 0
|
| 373 |
+
res[res > 0] = 1
|
| 374 |
+
grad[res > 0] = 0
|
| 375 |
+
# ylst, xlst = res.nonzero()
|
| 376 |
+
# for y, x in zip(ylst, xlst):
|
| 377 |
+
# grad[y,x]=0
|
| 378 |
+
# for yi in range(-1,2):
|
| 379 |
+
# for xi in range(-1,2):
|
| 380 |
+
# grad[y+yi,x+xi]=0
|
| 381 |
+
self.x0 = mask_on_tgt[0] + x0
|
| 382 |
+
self.x1 = mask_on_tgt[0] + x1
|
| 383 |
+
self.y0 = mask_on_tgt[1] + y0
|
| 384 |
+
self.y1 = mask_on_tgt[1] + y1
|
| 385 |
+
self.tgt = tgt.copy()
|
| 386 |
+
self.core.reset(max_id, mask, tgt_crop, grad)
|
| 387 |
+
return max_id
|
| 388 |
+
|
| 389 |
+
def step(self, iteration: int) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
| 390 |
+
result = self.core.step(iteration)
|
| 391 |
+
if self.root:
|
| 392 |
+
tgt, err = result
|
| 393 |
+
self.tgt[self.x0:self.x1, self.y0:self.y1] = tgt
|
| 394 |
+
return self.tgt, err
|
| 395 |
+
return None
|
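A minimal usage sketch for the two processors above, assuming the numpy backend and placeholder image files: reset() crops to the mask's bounding box and builds the system, then step(n) runs n solver iterations and returns the blended target plus a residual error.

# Minimal sketch of driving EquProcessor (numpy backend; file names are placeholders).
import numpy as np
from PIL import Image

from process import EquProcessor

src = np.array(Image.open("src.png").convert("RGB"))
mask = np.array(Image.open("mask.png").convert("L"))   # values >= 128 mark the blend region
tgt = np.array(Image.open("tgt.png").convert("RGB"))

proc = EquProcessor(gradient="max", backend="numpy")
n_unknowns = proc.reset(src, mask, tgt, (0, 0), (0, 0))
result, err = proc.step(5000)                          # 5000 Jacobi iterations
Image.fromarray(result.astype(np.uint8)).save("out.png")

GridProcessor exposes the same reset()/step() interface but solves over the whole cropped rectangle rather than only the masked pixels; the grid_x/grid_y strides control its tiling.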
requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
| 1 |
+
# --extra-index-url https://download.pytorch.org/whl/cu118
|
| 2 |
+
imageio
|
| 3 |
+
imageio-ffmpeg
|
| 4 |
+
numpy<=1.22.4
|
| 5 |
+
torch==2.5.1
|
| 6 |
+
torchvision
|
| 7 |
+
torchaudio
|
| 8 |
+
pydantic==2.10.6
|
| 9 |
+
Pillow
|
| 10 |
+
scipy
|
| 11 |
+
scikit-image
|
| 12 |
+
diffusers
|
| 13 |
+
stablepy
|
| 14 |
+
transformers==4.49.0
|
| 15 |
+
# ftfy
|
| 16 |
+
# fpie
|
| 17 |
+
accelerate
|
| 18 |
+
ninja
|
| 19 |
+
opencv-python
|
| 20 |
+
opencv-python-headless
|
utils.py
ADDED
|
@@ -0,0 +1,274 @@
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from PIL import ImageFilter
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import scipy
|
| 6 |
+
import scipy.signal
|
| 7 |
+
from scipy.spatial import cKDTree
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
from perlin2d import *
|
| 11 |
+
|
| 12 |
+
patch_match_compiled = True
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from PyPatchMatch import patch_match
|
| 16 |
+
except Exception as e:
|
| 17 |
+
try:
|
| 18 |
+
import patch_match
|
| 19 |
+
except Exception as e:
|
| 20 |
+
patch_match_compiled = False
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
patch_match
|
| 24 |
+
except NameError:
|
| 25 |
+
print("patch_match compiling failed, will fall back to edge_pad")
|
| 26 |
+
patch_match_compiled = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def edge_pad(img, mask, mode=1):
|
| 32 |
+
if mode == 0:
|
| 33 |
+
nmask = mask.copy()
|
| 34 |
+
nmask[nmask > 0] = 1
|
| 35 |
+
res0 = 1 - nmask
|
| 36 |
+
res1 = nmask
|
| 37 |
+
p0 = np.stack(res0.nonzero(), axis=0).transpose()
|
| 38 |
+
p1 = np.stack(res1.nonzero(), axis=0).transpose()
|
| 39 |
+
min_dists, min_dist_idx = cKDTree(p1).query(p0, 1)
|
| 40 |
+
loc = p1[min_dist_idx]
|
| 41 |
+
for (a, b), (c, d) in zip(p0, loc):
|
| 42 |
+
img[a, b] = img[c, d]
|
| 43 |
+
elif mode == 1:
|
| 44 |
+
record = {}
|
| 45 |
+
kernel = [[1] * 3 for _ in range(3)]
|
| 46 |
+
nmask = mask.copy()
|
| 47 |
+
nmask[nmask > 0] = 1
|
| 48 |
+
res = scipy.signal.convolve2d(
|
| 49 |
+
nmask, kernel, mode="same", boundary="fill", fillvalue=1
|
| 50 |
+
)
|
| 51 |
+
res[nmask < 1] = 0
|
| 52 |
+
res[res == 9] = 0
|
| 53 |
+
res[res > 0] = 1
|
| 54 |
+
ylst, xlst = res.nonzero()
|
| 55 |
+
queue = [(y, x) for y, x in zip(ylst, xlst)]
|
| 56 |
+
# bfs here
|
| 57 |
+
cnt = res.astype(np.float32)
|
| 58 |
+
acc = img.astype(np.float32)
|
| 59 |
+
step = 1
|
| 60 |
+
h = acc.shape[0]
|
| 61 |
+
w = acc.shape[1]
|
| 62 |
+
offset = [(1, 0), (-1, 0), (0, 1), (0, -1)]
|
| 63 |
+
while queue:
|
| 64 |
+
target = []
|
| 65 |
+
for y, x in queue:
|
| 66 |
+
val = acc[y][x]
|
| 67 |
+
for yo, xo in offset:
|
| 68 |
+
yn = y + yo
|
| 69 |
+
xn = x + xo
|
| 70 |
+
if 0 <= yn < h and 0 <= xn < w and nmask[yn][xn] < 1:
|
| 71 |
+
if record.get((yn, xn), step) == step:
|
| 72 |
+
acc[yn][xn] = acc[yn][xn] * cnt[yn][xn] + val
|
| 73 |
+
cnt[yn][xn] += 1
|
| 74 |
+
acc[yn][xn] /= cnt[yn][xn]
|
| 75 |
+
if (yn, xn) not in record:
|
| 76 |
+
record[(yn, xn)] = step
|
| 77 |
+
target.append((yn, xn))
|
| 78 |
+
step += 1
|
| 79 |
+
queue = target
|
| 80 |
+
img = acc.astype(np.uint8)
|
| 81 |
+
else:
|
| 82 |
+
nmask = mask.copy()
|
| 83 |
+
ylst, xlst = nmask.nonzero()
|
| 84 |
+
yt, xt = ylst.min(), xlst.min()
|
| 85 |
+
yb, xb = ylst.max(), xlst.max()
|
| 86 |
+
content = img[yt : yb + 1, xt : xb + 1]
|
| 87 |
+
img = np.pad(
|
| 88 |
+
content,
|
| 89 |
+
((yt, mask.shape[0] - yb - 1), (xt, mask.shape[1] - xb - 1), (0, 0)),
|
| 90 |
+
mode="edge",
|
| 91 |
+
)
|
| 92 |
+
return img, mask
|
| 93 |
+
|
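mode 1 above is a multi-source BFS from the ring of known boundary pixels: each unfilled pixel takes the running average of every already-filled neighbor that reaches it in the same step, so color bleeds outward one ring at a time. A toy call (hypothetical 4x4 arrays; non-zero mask marks known pixels):

# Toy edge_pad call: fill everything outside a known 2x2 center by BFS averaging.
import numpy as np
from utils import edge_pad

img = np.zeros((4, 4, 3), np.uint8)
img[1:3, 1:3] = 255          # the known pixels are white
mask = np.zeros((4, 4), np.uint8)
mask[1:3, 1:3] = 255         # non-zero marks known pixels
filled, _ = edge_pad(img, mask, mode=1)
# every outside pixel becomes the average of its already-filled neighbors (here 255)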
| 94 |
+
|
| 95 |
+
def perlin_noise(img, mask):
|
| 96 |
+
lin_x = np.linspace(0, 5, mask.shape[1], endpoint=False)
|
| 97 |
+
lin_y = np.linspace(0, 5, mask.shape[0], endpoint=False)
|
| 98 |
+
x, y = np.meshgrid(lin_x, lin_y)
|
| 99 |
+
avg = img.mean(axis=0).mean(axis=0)
|
| 100 |
+
# noise=[((perlin(x, y)+1)*128+avg[i]).astype(np.uint8) for i in range(3)]
|
| 101 |
+
noise = [((perlin(x, y) + 1) * 0.5 * 255).astype(np.uint8) for i in range(3)]
|
| 102 |
+
noise = np.stack(noise, axis=-1)
|
| 103 |
+
# mask=skimage.measure.block_reduce(mask,(8,8),np.min)
|
| 104 |
+
# mask=mask.repeat(8, axis=0).repeat(8, axis=1)
|
| 105 |
+
# mask_image=Image.fromarray(mask)
|
| 106 |
+
# mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 4))
|
| 107 |
+
# mask=np.array(mask_image)
|
| 108 |
+
nmask = mask.copy()
|
| 109 |
+
# nmask=nmask/255.0
|
| 110 |
+
nmask[mask > 0] = 1
|
| 111 |
+
img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
|
| 112 |
+
# img=img.astype(np.uint8)
|
| 113 |
+
return img, mask
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def gaussian_noise(img, mask):
|
| 117 |
+
noise = np.random.randn(mask.shape[0], mask.shape[1], 3)
|
| 118 |
+
noise = (noise + 1) / 2 * 255
|
| 119 |
+
noise = noise.astype(np.uint8)
|
| 120 |
+
nmask = mask.copy()
|
| 121 |
+
nmask[mask > 0] = 1
|
| 122 |
+
img = nmask[:, :, np.newaxis] * img + (1 - nmask[:, :, np.newaxis]) * noise
|
| 123 |
+
return img, mask
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def cv2_telea(img, mask):
|
| 127 |
+
ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_TELEA)
|
| 128 |
+
return ret, mask
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def cv2_ns(img, mask):
|
| 132 |
+
ret = cv2.inpaint(img, 255 - mask, 5, cv2.INPAINT_NS)
|
| 133 |
+
return ret, mask
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def patch_match_func(img, mask):
|
| 137 |
+
ret = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
|
| 138 |
+
return ret, mask
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def mean_fill(img, mask):
|
| 142 |
+
avg = img.mean(axis=0).mean(axis=0)
|
| 143 |
+
img[mask < 1] = avg
|
| 144 |
+
return img, mask
|
| 145 |
+
|
| 146 |
+
"""
|
| 147 |
+
Apache-2.0 license
|
| 148 |
+
https://github.com/hafriedlander/stable-diffusion-grpcserver/blob/main/sdgrpcserver/services/generate.py
|
| 149 |
+
https://github.com/parlance-zz/g-diffuser-bot/tree/g-diffuser-bot-beta2
|
| 150 |
+
_handleImageAdjustment
|
| 151 |
+
"""
|
| 152 |
+
try:
|
| 153 |
+
from sd_grpcserver.sdgrpcserver import images
|
| 154 |
+
import torch
|
| 155 |
+
from math import sqrt
|
| 156 |
+
def handleImageAdjustment(array, adjustments):
|
| 157 |
+
tensor = images.fromPIL(Image.fromarray(array))
|
| 158 |
+
for adjustment in adjustments:
|
| 159 |
+
which = adjustment[0]
|
| 160 |
+
|
| 161 |
+
if which == "blur":
|
| 162 |
+
sigma = adjustment[1]
|
| 163 |
+
direction = adjustment[2]
|
| 164 |
+
|
| 165 |
+
if direction == "DOWN" or direction == "UP":
|
| 166 |
+
orig = tensor
|
| 167 |
+
repeatCount = 256
|
| 168 |
+
sigma /= sqrt(repeatCount)
|
| 169 |
+
|
| 170 |
+
for _ in range(repeatCount):
|
| 171 |
+
tensor = images.gaussianblur(tensor, sigma)
|
| 172 |
+
if direction == "DOWN":
|
| 173 |
+
tensor = torch.minimum(tensor, orig)
|
| 174 |
+
else:
|
| 175 |
+
tensor = torch.maximum(tensor, orig)
|
| 176 |
+
else:
|
| 177 |
+
tensor = images.gaussianblur(tensor, adjustment[1])
|
| 178 |
+
elif which == "invert":
|
| 179 |
+
tensor = images.invert(tensor)
|
| 180 |
+
elif which == "levels":
|
| 181 |
+
tensor = images.levels(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])
|
| 182 |
+
elif which == "channels":
|
| 183 |
+
tensor = images.channelmap(tensor, [adjustment[1], adjustment[2], adjustment[3], adjustment[4]])
|
| 184 |
+
elif which == "rescale":
|
| 185 |
+
self.unimp("Rescale")
|
| 186 |
+
elif which == "crop":
|
| 187 |
+
tensor = images.crop(tensor, adjustment[1], adjustment[2], adjustment[3], adjustment[4])
|
| 188 |
+
return np.array(images.toPIL(tensor)[0])
|
| 189 |
+
|
| 190 |
+
def g_diffuser(img,mask):
|
| 191 |
+
adjustments=[["blur",32,"UP"],["level",0,0.05,0,1]]
|
| 192 |
+
mask = handleImageAdjustment(mask, adjustments)
|
| 193 |
+
out_mask = handleImageAdjustment(mask, adjustments)
|
| 194 |
+
return img, mask
|
| 195 |
+
except Exception:
|
| 196 |
+
def g_diffuser(img, mask):
|
| 197 |
+
return img, mask
|
| 198 |
+
|
| 199 |
+
def dummy_fill(img, mask):
|
| 200 |
+
return img, mask
|
| 201 |
+
functbl = {
|
| 202 |
+
"gaussian": gaussian_noise,
|
| 203 |
+
"perlin": perlin_noise,
|
| 204 |
+
"edge_pad": edge_pad,
|
| 205 |
+
"patchmatch": patch_match_func if patch_match_compiled else edge_pad,
|
| 206 |
+
"cv2_ns": cv2_ns,
|
| 207 |
+
"cv2_telea": cv2_telea,
|
| 208 |
+
"g_diffuser": g_diffuser,
|
| 209 |
+
"g_diffuser_lib": dummy_fill,
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
try:
|
| 213 |
+
from postprocess import PhotometricCorrection
|
| 214 |
+
correction_func = PhotometricCorrection()
|
| 215 |
+
except Exception as e:
|
| 216 |
+
print(e, "so PhotometricCorrection is disabled")
|
| 217 |
+
class DummyCorrection:
|
| 218 |
+
def __init__(self):
|
| 219 |
+
self.backend=""
|
| 220 |
+
pass
|
| 221 |
+
def run(self, a, b, **kwargs):
|
| 222 |
+
return b
|
| 223 |
+
correction_func = DummyCorrection()
|
| 224 |
+
|
| 225 |
+
class DummyInterrogator:
|
| 226 |
+
def __init__(self) -> None:
|
| 227 |
+
pass
|
| 228 |
+
def interrogate(self, pil):
|
| 229 |
+
return "Interrogator init failed"
|
| 230 |
+
|
| 231 |
+
if "taichi" in correction_func.backend:
|
| 232 |
+
import sys
|
| 233 |
+
import io
|
| 234 |
+
import base64
|
| 235 |
+
from PIL import Image
|
| 236 |
+
def base64_to_pil(base64_str):
|
| 237 |
+
data = base64.b64decode(str(base64_str))
|
| 238 |
+
pil = Image.open(io.BytesIO(data))
|
| 239 |
+
return pil
|
| 240 |
+
|
| 241 |
+
def pil_to_base64(out_pil):
|
| 242 |
+
out_buffer = io.BytesIO()
|
| 243 |
+
out_pil.save(out_buffer, format="PNG")
|
| 244 |
+
out_buffer.seek(0)
|
| 245 |
+
base64_bytes = base64.b64encode(out_buffer.read())
|
| 246 |
+
base64_str = base64_bytes.decode("ascii")
|
| 247 |
+
return base64_str
|
| 248 |
+
from subprocess import Popen, PIPE, STDOUT
|
| 249 |
+
class SubprocessCorrection:
|
| 250 |
+
def __init__(self):
|
| 251 |
+
self.backend = correction_func.backend
|
| 252 |
+
self.child = Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
|
| 253 |
+
def run(self, img_input, img_inpainted, mode):
|
| 254 |
+
if mode=="disabled":
|
| 255 |
+
return img_inpainted
|
| 256 |
+
base64_str_input = pil_to_base64(img_input)
|
| 257 |
+
base64_str_inpainted = pil_to_base64(img_inpainted)
|
| 258 |
+
try:
|
| 259 |
+
if self.child.poll() is not None:  # restart the child whenever it has exited, even with code 0
|
| 260 |
+
self.child = Popen(["python", "postprocess.py"], stdin=PIPE, stdout=PIPE, stderr=STDOUT)
|
| 261 |
+
self.child.stdin.write(f"{base64_str_input},{base64_str_inpainted},{mode}\n".encode())
|
| 262 |
+
self.child.stdin.flush()
|
| 263 |
+
out = self.child.stdout.readline()
|
| 264 |
+
base64_str = out.decode().strip()
|
| 265 |
+
while base64_str and base64_str[0]=="[":
|
| 266 |
+
print(base64_str)
|
| 267 |
+
out = self.child.stdout.readline()
|
| 268 |
+
base64_str = out.decode().strip()
|
| 269 |
+
ret = base64_to_pil(base64_str)
|
| 270 |
+
except Exception:
|
| 271 |
+
print("[PIE] not working, photometric correction is disabled")
|
| 272 |
+
ret = img_inpainted
|
| 273 |
+
return ret
|
| 274 |
+
correction_func = SubprocessCorrection()
|
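functbl above maps init-mode names to fill functions used to pre-populate the masked region before diffusion runs; a minimal sketch of picking an initializer by name (the arrays are placeholders):

# Picking an init mode by name; "patchmatch" silently falls back to edge_pad
# when PyPatchMatch failed to compile.
import numpy as np
from utils import functbl

img = np.zeros((64, 64, 3), np.uint8)
mask = np.zeros((64, 64), np.uint8)
mask[16:48, 16:48] = 255             # non-zero marks known pixels
img, mask = functbl["edge_pad"](img, mask)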