Commit 2cd2753
Parent(s): f33899f

interpolated frames get generated for video
.gitignore CHANGED
@@ -2,4 +2,5 @@
 output
 SuperSloMo.ckpt
 Test.mp4
-Result_Test
+Result_Test
+interpolated_frames
frames.py CHANGED
@@ -1,9 +1,5 @@
 import cv2
 import os
-from PIL import Image
-from torchvision.transforms import transforms, ToTensor
-from torch import tensor
-from torchvision.transforms import ToPILImage, Resize
 import torch
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -35,13 +31,6 @@ def downsample(video_path, output_dir, target_fps):
     pass


-def load_frames(path, size=(128, 128)) -> tensor:  # converts a PIL image to a tensor on the GPU
-    image = Image.open(path).convert('RGB')
-    tensor = ToTensor()
-    resized_image = Resize(size)(image)
-    return tensor(resized_image).unsqueeze(0).to(device)
-
-

 if __name__ == "__main__":  # sets the __name__ variable to "__main__" for this script

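Note: the `load_frames` helper removed here (with its 128×128 default size) is superseded by the `load_frames` in main.py below, which resizes frames to 720×1280 before moving them to the GPU; the now-unused PIL and torchvision imports are dropped along with it.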
main.py CHANGED
@@ -1,9 +1,13 @@
+import cv2
 import torch
 from model import UNet
 from PIL import Image
 from torchvision.transforms import transforms, ToTensor
 import torch.nn.functional as F
-
+from torch.cuda.amp import autocast
+import os
+import subprocess
+from torchvision.transforms import Resize
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 def save_frames(tensor, out_path) -> None:
@@ -15,95 +19,117 @@ def normalize_frames(tensor):
     tensor = tensor.squeeze(0).detach().cpu()
     tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
     tensor = (tensor * 255).byte()  # Scale to [0, 255]
-    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C]
+    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C] (height, width, channels)
     return tensor
-
+def laod_allframes(frame_dir):
+    frames_path = sorted(
+        [os.path.join(frame_dir, f) for f in os.listdir(frame_dir) if f.endswith('.png')]
+    )
+    for frame_path in frames_path:
+        yield load_frames(frame_path)
 def load_frames(image_path) -> torch.Tensor:
+    '''
+    Converts the PIL image (RGB) to a PyTorch tensor and loads it onto the GPU
+    :param image_path: path to the frame on disk
+    :return: PyTorch tensor
+    '''
     transform = transforms.Compose([
-
+        Resize((720, 1280)),
+        ToTensor()
     ])
     img = Image.open(image_path).convert("RGB")
-    tensor = transform(img).unsqueeze(0).to(device)
+    tensor = transform(img).unsqueeze(0).to(device)
     return tensor

 def time_steps(input_fps, output_fps) -> list[float]:
+    '''
+    Generates the time intervals at which to interpolate between frames A and B
+    :param input_fps: original video FPS
+    :param output_fps: target output FPS
+    :return: list of intermediate time steps required between two frames A and B
+    '''
     if output_fps <= input_fps:
         return []
     k = output_fps // input_fps
     n = k - 1
     return [i / (n + 1) for i in range(1, n + 1)]
-
-
-
-
-
-
-
-
-
-
+def interpolate_video(frames_dir, model_fc, input_fps, ouput_fps, output_dir):
+    os.makedirs(output_dir, exist_ok=True)
+    count = 0
+    iterator = laod_allframes(frames_dir)
+    try:
+        prev_frame = next(iterator)
+        for curr_frame in iterator:
+            interpolated_frames = interpolate(model_fc, prev_frame, curr_frame, input_fps, ouput_fps)
+            save_frames(prev_frame, os.path.join(output_dir, "frame_{}.png".format(count)))
+            count += 1
+            for frame in interpolated_frames:
+                save_frames(frame[:, :3, :, :], os.path.join(output_dir, "frame_{}.png".format(count)))
+                count += 1
+            prev_frame = curr_frame
+        save_frames(prev_frame, os.path.join(output_dir, "frame_{}.png".format(count)))
+    except StopIteration:
+        print("no more frames")
+
+
+def interpolate(model_FC, A, B, input_fps, output_fps) -> list[torch.Tensor]:
     interval = time_steps(input_fps, output_fps)
-    input_tensor = torch.cat((A, B), dim=1)
-
+    input_tensor = torch.cat((A, B), dim=1)  # Concatenate frames A and B to compare their difference
     with torch.no_grad():
         flow_output = model_FC(input_tensor)
         flow_forward = flow_output[:, :2, :, :]  # Forward flow
         flow_backward = flow_output[:, 2:4, :, :]  # Backward flow
-
     generated_frames = []
     with torch.no_grad():
         for t in interval:
             t_tensor = torch.tensor([t], dtype=torch.float32).view(1, 1, 1, 1).to(device)
-
-
-
-
-            interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
+            with autocast():
+                warped_A = warp_frames(A, flow_forward * t_tensor)
+                warped_B = warp_frames(B, flow_backward * (1 - t_tensor))
+                interpolated_frame = warped_A * (1 - t_tensor) + warped_B * t_tensor
             generated_frames.append(interpolated_frame)
-
     return generated_frames


 def warp_frames(frame, flow):
     b, c, h, w = frame.size()
-
-
+    i, j, flow_h, flow_w = flow.size()
     if h != flow_h or w != flow_w:
         frame = F.interpolate(frame, size=(flow_h, flow_w), mode='bilinear', align_corners=True)
-
     grid_y, grid_x = torch.meshgrid(torch.arange(0, flow_h), torch.arange(0, flow_w), indexing="ij")
     grid_x = grid_x.float().to(device)
     grid_y = grid_y.float().to(device)
-
     flow_x = flow[:, 0, :, :]
     flow_y = flow[:, 1, :, :]
     x = grid_x.unsqueeze(0) + flow_x
     y = grid_y.unsqueeze(0) + flow_y
-
     x = 2.0 * x / (flow_w - 1) - 1.0
     y = 2.0 * y / (flow_h - 1) - 1.0
     grid = torch.stack((x, y), dim=-1)

-    warped_frame = F.grid_sample(frame, grid, align_corners=True)
+    warped_frame = F.grid_sample(frame, grid, align_corners=True, mode='bilinear', padding_mode='border')
     return warped_frame
-
-
+def frames_to_video(frame_dir, output_video, fps):
+    frame_pattern = os.path.join(frame_dir, "frame_.png")
+    subprocess.run([
+        "ffmpeg", "-framerate", str(fps), "-i", frame_pattern,
+        "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video
+    ])
 def solve():
     checkpoint = torch.load("SuperSloMo.ckpt")
     model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
     model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
     model_FC.eval()
-
     model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
     model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
     model_AT.eval()
-
-
-
-    interpolated_frames
-
-
-
+    frames_dir = "output"
+    input_fps = 60
+    output_fps = 120
+    output_dir = "interpolated_frames"
+    interpolate_video(frames_dir, model_FC, input_fps, output_fps, output_dir)
+    final_video = "result.mp4"
+    frames_to_video(output_dir, final_video, output_fps)

 def main():
     solve()
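One thing worth flagging in the new `frames_to_video`: `"frame_.png"` is passed to ffmpeg as a literal filename rather than a numbered pattern, so it will match none of the `frame_0.png`, `frame_1.png`, ... files that `interpolate_video` writes. A minimal corrected sketch (a suggestion, not part of the commit):

import os
import subprocess

def frames_to_video(frame_dir, output_video, fps):
    # "%d" is ffmpeg's printf-style image-sequence pattern; it matches
    # frame_0.png, frame_1.png, ... as written by interpolate_video
    frame_pattern = os.path.join(frame_dir, "frame_%d.png")
    subprocess.run([
        "ffmpeg", "-framerate", str(fps), "-i", frame_pattern,
        "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video
    ], check=True)  # check=True raises CalledProcessError if ffmpeg fails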
model.py CHANGED
@@ -109,14 +109,11 @@ class up(nn.Module):
         self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1)

     def forward(self, x, skpCn):
-        # Upsample x
         x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
-        # Match dimensions by cropping the skip connection (skpCn) to match x
         if x.size(-1) != skpCn.size(-1):
             skpCn = skpCn[:, :, :, :x.size(-1)]
         if x.size(-2) != skpCn.size(-2):
             skpCn = skpCn[:, :, :x.size(-2), :]
-        # Concatenate and apply convolutions
         x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
         x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope=0.1)
         return x
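The cropping retained in `up.forward` guards against off-by-one shape mismatches: downsampling an odd-sized feature map floors its spatial size, so upsampling by a factor of 2 can come back one pixel short of the skip connection. A small standalone shape check, with toy dimensions assumed for illustration:

import torch
import torch.nn.functional as F

skpCn = torch.randn(1, 8, 45, 45)  # encoder skip connection at an odd 45x45
x = torch.randn(1, 8, 22, 22)      # decoder feature: 45 // 2 = 22 after downsampling
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # -> 44x44
# Crop the skip connection the same way up.forward does, so torch.cat can align them
skpCn = skpCn[:, :, :x.size(-2), :x.size(-1)]
print(x.shape, skpCn.shape)  # both torch.Size([1, 8, 44, 44])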