falseu committed
Commit · 4f6c34a · Parent(s): 0dfef01

update comments

- AdaIN.py +20 -1
- test.py +11 -3
- test_interpolate.py +6 -2
- test_video.py +6 -1
- train.py +11 -4
AdaIN.py
CHANGED
@@ -13,7 +13,11 @@ class AdaINNet(nn.Module):
     def __init__(self, vgg_weight):
         super().__init__()
         self.encoder = vgg19(vgg_weight)
-        self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
+
+        # drop layers after 4_1
+        self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
+
+        # No optimization for encoder
         for parameter in self.encoder.parameters():
             parameter.requires_grad = False
 
@@ -21,15 +25,29 @@ class AdaINNet(nn.Module):
 
         self.mseloss = nn.MSELoss()
 
+    """
+    Computes style loss of two images
+
+    Args:
+        x (torch.FloatTensor): content image tensor
+        y (torch.FloatTensor): style image tensor
+
+    Return:
+        Mean Squared Error between x.mean, y.mean and MSE between x.std, y.std
+    """
     def _style_loss(self, x, y):
         return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
             self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
 
     def forward(self, content, style, alpha=1.0):
+        # Generate image features
         content_enc = self.encoder(content)
         style_enc = self.encoder(style)
+
+        # Perform style transfer on feature space
         transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
 
+        # Generate output image
         out = self.decoder(transfer_enc)
 
         # vgg19 layer relu1_1
@@ -47,6 +65,7 @@ class AdaINNet(nn.Module):
         # vgg19 layer relu4_1
         out_enc = self.encoder[13:](out_relu31)
 
+        # Calculate loss
         content_loss = self.mseloss(out_enc, transfer_enc)
         style_loss = self._style_loss(out_relu11, style_relu11) + self._style_loss(out_relu21, style_relu21) + \
             self._style_loss(out_relu31, style_relu31) + self._style_loss(out_enc, style_enc)
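
The forward() pass above calls adaptive_instance_normalization, which is defined elsewhere in the repo and not shown in this diff. For reference, a minimal sketch of the standard AdaIN operation (Huang & Belongie, 2017) that the call is expected to implement; the eps term is an assumption added for numerical stability:

import torch

def adaptive_instance_normalization(content_enc, style_enc, eps=1e-5):
    # Per-channel statistics over the spatial dims of (N, C, H, W) feature maps
    c_mean = content_enc.mean(dim=[2, 3], keepdim=True)
    c_std = content_enc.std(dim=[2, 3], keepdim=True) + eps  # avoid divide-by-zero
    s_mean = style_enc.mean(dim=[2, 3], keepdim=True)
    s_std = style_enc.std(dim=[2, 3], keepdim=True)
    # Normalize the content features, then re-scale and shift with style statistics
    return s_std * (content_enc - c_mean) / c_std + s_mean

This is also what the style loss above measures: the decoder output is pushed to match the style image's per-channel feature mean and std at relu1_1 through relu4_1.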
test.py
CHANGED
@@ -69,6 +69,7 @@ def main():
     assert len(content_pths) > 0, 'Failed to load content image'
     assert len(style_pths) > 0, 'Failed to load style image'
 
+    # Prepare directory for saving results
     out_dir = './results/'
     os.makedirs(out_dir, exist_ok=True)
 
@@ -81,8 +82,9 @@ def main():
     # Prepare image transform
     t = transform(512)
 
-    # Prepare grid image
+    # Prepare grid image, add style images to the first row
     if args.grid_pth:
+        # Add empty image
         imgs = [np.ones((1, 1, 3), np.uint8) * 255]
         for style_pth in style_pths:
             imgs.append(Image.open(style_pth))
@@ -101,15 +103,20 @@ def main():
 
         style_tensor = t(Image.open(style_pth)).unsqueeze(0).to(device)
 
-        tic = time.perf_counter()
+        # Start time
+        tic = time.perf_counter()
+
+        # Execute style transfer
         with torch.no_grad():
             out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha).cpu()
 
-        toc = time.perf_counter()
+        # End time
+        toc = time.perf_counter()
         print("Content: " + content_pth.stem + ". Style: " \
             + style_pth.stem + '. Alpha: ' + str(args.alpha) + '. Style Transfer time: %.4f seconds' % (toc-tic))
         times.append(toc-tic)
 
+        # Save image
         out_pth = out_dir + content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix
         save_image(out_tensor, out_pth)
 
@@ -122,6 +129,7 @@ def main():
     avg = sum(times)/len(times)
     print("Average style transfer time: %.4f seconds" % (avg))
 
+    # Generate grid image
     if args.grid_pth:
         print("Generating grid image")
         grid_image(len(content_pths) + 1, len(style_pths) + 1, imgs, save_pth=args.grid_pth)
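
style_transfer is imported from the repo's utility code and its body is not part of this diff. A minimal sketch, assuming the usual AdaIN inference path in which alpha blends the stylized features with the original content features (alpha=1.0 gives full stylization, 0.0 reconstructs the content image); the function and argument names follow the call site above, and adaptive_instance_normalization is the helper used in AdaIN.py:

def style_transfer(content, style, encoder, decoder, alpha=1.0):
    # Encode both images into relu4_1 feature space
    content_enc = encoder(content)
    style_enc = encoder(style)
    # Transfer style statistics onto the content features, then blend by alpha
    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
    blended = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(blended)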
test_interpolate.py
CHANGED
@@ -102,24 +102,28 @@ def main():
     for content_pth in content_pths:
         content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)
 
+        # Prepare multiple style images
         style_tensor = []
         for style_pth in style_pths:
             img = Image.open(style_pth)
-            style_tensor.append(transform([512, 512])(img))
+            style_tensor.append(transform([512, 512])(img))
         style_tensor = torch.stack(style_tensor, dim=0).to(device)
 
-        for inter_weight in inter_weights:
+        for inter_weight in inter_weights:
+            # Execute Interpolate style transfer
             with torch.no_grad():
                 out_tensor = interpolate_style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha, inter_weight).cpu()
 
             print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: ", str(inter_weight))
 
+            # Save results
             out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight) + content_pth.suffix
             save_image(out_tensor, out_pth)
 
             if args.grid_pth:
                 imgs.append(Image.open(out_pth))
 
+    # Generate grid image
     if args.grid_pth:
         print("Generating grid image")
         grid_image(5, 5, imgs, save_pth=args.grid_pth)
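
interpolate_style_transfer is likewise defined outside this diff. A sketch under the assumption that it implements the style interpolation from the AdaIN paper, decoding a weighted combination of per-style AdaIN features; the signature and the semantics of inter_weight (here a per-style weight list) are assumptions based on the call site:

import torch

def interpolate_style_transfer(content, styles, encoder, decoder, alpha=1.0, weights=None):
    # content: (1, C, H, W); styles: (S, C, H, W) stack of style images
    if weights is None:
        weights = [1.0 / len(styles)] * len(styles)  # equal weighting by default
    content_enc = encoder(content)
    mixed_enc = torch.zeros_like(content_enc)
    for style, w in zip(styles, weights):
        style_enc = encoder(style.unsqueeze(0))
        mixed_enc = mixed_enc + w * adaptive_instance_normalization(content_enc, style_enc)
    # Blend with the content features by alpha, as in single-style transfer
    blended = alpha * mixed_enc + (1 - alpha) * content_enc
    return decoder(blended)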
test_video.py
CHANGED
@@ -55,13 +55,16 @@ def main():
     style_image_pth = Path(args.style_image)
     style_image = Image.open(style_image_pth)
 
+    # Read video info
     fps = int(content_video.get(cv2.CAP_PROP_FPS))
     frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
     video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
     video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))
 
+    # Prepare loop
     video_tqdm = tqdm(total=frame_count)
 
+    # Prepare output video writer
     out_dir = './results_video/'
     os.makedirs(out_dir, exist_ok=True)
     out_pth = Path(out_dir + content_video_pth.stem + '_style_' \
@@ -81,7 +84,8 @@ def main():
 
     while content_video.isOpened():
         ret, content_image = content_video.read()
-        if not ret:
+        # Failed to read a frame
+        if not ret:
             break
 
         content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
@@ -96,6 +100,7 @@ def main():
         out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
         out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)
 
+        # Write output frame to video
         writer.append_data(np.array(out_tensor))
         video_tqdm.update(1)
 
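
The writer receiving append_data is created outside the shown hunks; append_data is the imageio writer API, so the setup presumably resembles the sketch below, reusing the fps and out_pth computed above (the exact call is an assumption):

import imageio

# Assumed setup: an ffmpeg-backed writer matching the source video's frame rate;
# writer.close() would follow the frame loop.
writer = imageio.get_writer(str(out_pth), fps=fps)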
train.py
CHANGED
@@ -17,21 +17,24 @@ def main():
     args = parser.parse_args()
 
     device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
-
+
     check_point_dir = './check_point/'
     weights_dir = './weights/'
+
+    # Prepare Training dataset
     train_set = TrainSet(args.content_dir, args.style_dir)
     train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True)
-
+
+    # load vgg19 weights
     vgg_model = torch.load('vgg_normalized.pth')
     model = AdaINNet(vgg_model).to(device)
-
     decoder_optimizer = torch.optim.Adam(model.decoder.parameters(), lr=1e-6)
+
     total_loss, content_loss, style_loss = 0.0, 0.0, 0.0
     losses = []
     iteration = 0
 
-    # If resume
+    # If resume training, load states
     if args.resume > 0:
         states = torch.load(check_point_dir + "epoch_" + str(args.resume)+'.pth')
         model.decoder.load_state_dict(states['decoder'])
@@ -54,10 +57,14 @@ def main():
         content_batch = content_batch.to(device)
         style_batch = style_batch.to(device)
 
+        # Feed forward and compute loss
         loss_content, loss_style = model(content_batch, style_batch)
         loss_scaled = loss_content + 10 * loss_style
+
+        # Gradient descent
         loss_scaled.backward()
         decoder_optimizer.step()
+
         total_loss = loss_scaled.item()
         content_loss = loss_content.item()
         style_loss = loss_style.item()
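
The resume branch above loads states['decoder'] from check_point/epoch_<N>.pth; the matching save call sits outside the shown hunks. A minimal sketch of what it presumably looks like at the end of an epoch; the 'optimizer' key and the epoch variable are hypothetical:

import torch

# Hypothetical checkpoint save mirroring the resume logic above
torch.save({
    'decoder': model.decoder.state_dict(),
    'optimizer': decoder_optimizer.state_dict(),  # assumed key
}, check_point_dir + 'epoch_' + str(epoch) + '.pth')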