Commit f90ddf2
Parent(s): 7917eea

image is getting generated but it's very ass

Files changed:
- __pycache__/model.cpython-310.pyc +0 -0
- main.py +58 -46
- model.py +11 -25
__pycache__/model.cpython-310.pyc
CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
main.py
CHANGED
@@ -1,68 +1,80 @@
 import torch
 from model import UNet
-from frames import load_frames
 from PIL import Image
-from torchvision.transforms import transforms,ToTensor
+from torchvision.transforms import transforms, ToTensor
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-def save_frames(tensor,out_path):
-    image=normalize_frames(tensor)
-    image=Image.fromarray(image)
+def save_frames(tensor, out_path) -> None:
+    image = normalize_frames(tensor)
+    image = Image.fromarray(image)
     image.save(out_path)
 
 def normalize_frames(tensor):
-    tensor=tensor.squeeze(0).detach().cpu()
-    tensor=(
-        tensor
+    tensor = tensor.squeeze(0).detach().cpu()
+    tensor = torch.clamp(tensor, 0.0, 1.0)  # Ensure values are in [0, 1]
+    tensor = (tensor * 255).byte()  # Scale to [0, 255]
+    tensor = tensor.permute(1, 2, 0).numpy()  # Convert to [H, W, C]
     return tensor
 
+def load_frames(image_path):
+    transform = transforms.Compose([
+        ToTensor()  # Converts to [0, 1] range and [C, H, W]
+    ])
+    img = Image.open(image_path).convert("RGB")
+    tensor = transform(img).unsqueeze(0).to(device)  # Add batch dimension
+    return tensor
+
-def time_steps(input_fps,output_fps):
-    if output_fps<=input_fps:
+def time_steps(input_fps, output_fps) -> list[float]:
+    if output_fps <= input_fps:
         return []
-    k=output_fps//input_fps
-    n=k-1
-    return [i/n+1 for i in range(1,n+1)]
+    k = output_fps // input_fps
+    n = k - 1
+    return [i / (n + 1) for i in range(1, n + 1)]
 
-def expand_channels(tensor,target):
-    batch_size,current_channels,height,width=tensor.shape
-    if current_channels>=target:
+def expand_channels(tensor, target):
+    batch_size, current_channels, height, width = tensor.shape
+    if current_channels >= target:
         return tensor
-    required=target-current_channels
-    extra=torch.zeros(batch_size,required,height,width,device=tensor.device,dtype=tensor.dtype)
-    return torch.cat((tensor,extra),dim=1)
+    required = target - current_channels
+    extra = torch.zeros(batch_size, required, height, width, device=tensor.device, dtype=tensor.dtype)
+    return torch.cat((tensor, extra), dim=1)
 
-def interpolate(model_FC,model_AT,A,B,input_fps,output_fps):
-    interval=time_steps(input_fps,output_fps)
-    input_tensor=torch.cat((A,B),dim=1)
+def interpolate(model_FC, model_AT, A, B, input_fps, output_fps):
+    interval = time_steps(input_fps, output_fps)
+    input_tensor = torch.cat((A, B), dim=1)
+    print(f"Time intervals: {interval}")
     with torch.no_grad():
-        flow_output=model_FC(input_tensor)
-        flow_output=expand_channels(flow_output,20)
-    generated_frames=[]
+        flow_output = model_FC(input_tensor)  # Output shape: [1, 4, H, W]
+        flow_output = expand_channels(flow_output, 20)  # Expand to 20 channels
+
+    generated_frames = []
     with torch.no_grad():
         for i in interval:
-            inter_tensor=torch.tensor([i],dtype=torch.float32).unsqueeze(0).to(device)
-            interpolated_frame=model_AT(flow_output,inter_tensor)
+            inter_tensor = torch.tensor([i], dtype=torch.float32).unsqueeze(0).to(device)
+            interpolated_frame = model_AT(flow_output, inter_tensor)
             generated_frames.append(interpolated_frame)
     return generated_frames
 
 def solve():
-    checkpoint=torch.load("SuperSloMo.ckpt")
-    model_FC=UNet(6,4)
-    model_FC=model_FC.to(device)
-    model_FC.load_state_dict(checkpoint["state_dictFC"]) # loading all weights from model
-    model_AT=UNet(20,5)
-    model_AT.load_state_dict(checkpoint["state_dictAT"],strict=False)
-    model_AT=model_AT.to(device)
-    model_AT.eval()
+    checkpoint = torch.load("SuperSloMo.ckpt")
+    model_FC = UNet(6, 4).to(device)  # Initialize flow computation model
+    model_FC.load_state_dict(checkpoint["state_dictFC"])  # Load weights
     model_FC.eval()
-    A=load_frames("output/1.png")
-    B=load_frames("output/69.png")
-    interpolated_frames=interpolate(model_FC,model_AT,A,B,30,60)
-    for index,value in enumerate(interpolated_frames):
-        save_frames(value[:,:3,:,:],"Result_Test/image{}.png".format(index+1))
+
+    model_AT = UNet(20, 5).to(device)  # Initialize auxiliary task model
+    model_AT.load_state_dict(checkpoint["state_dictAT"], strict=False)  # Load weights
+    model_AT.eval()
+
+    A = load_frames("output/1.png")
+    B = load_frames("output/69.png")
+    interpolated_frames = interpolate(model_FC, model_AT, A, B, 30, 60)
+
+    for index, value in enumerate(interpolated_frames):
+        save_frames(value[:, :3, :, :], f"Result_Test/image{index + 1}.png")  # Save only RGB channels
 
 def main():
     solve()
+
+if __name__ == "__main__":
+    main()
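A note on the `time_steps` change above: the old return expression `[i/n+1 for i in range(1,n+1)]` parses as `(i/n) + 1`, so every time point came out at or above 1.0, outside the (0, 1) range an intermediate frame time must lie in. A minimal standalone sanity check of the fixed version (the function is copied from the new main.py; the expected values are worked out by hand, not repo output):

def time_steps(input_fps, output_fps) -> list[float]:
    # Copied from the new main.py: evenly spaced time points in (0, 1)
    # between two consecutive input frames.
    if output_fps <= input_fps:
        return []
    k = output_fps // input_fps
    n = k - 1
    return [i / (n + 1) for i in range(1, n + 1)]

assert time_steps(30, 60) == [0.5]               # one midpoint per frame pair
assert time_steps(30, 120) == [0.25, 0.5, 0.75]  # three intermediate points
# Old formula at 30 -> 60 fps: 1/1 + 1 == 2.0, i.e. a time outside (0, 1).

Separately, `expand_channels` zero-fills 16 of the 20 channels fed to model_AT, which in the original Super SloMo pipeline receives warped frames and flow maps rather than zeros; that padding, more than the time steps, may explain why the generated image quality is still poor.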
model.py
CHANGED
@@ -107,35 +107,21 @@ class up(nn.Module):
         self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
         # (2 * outChannels) is used for accommodating skip connection.
         self.conv2 = nn.Conv2d(2 * outChannels, outChannels, 3, stride=1, padding=1)
 
-    def forward(self, x, skpCn):
-        """
-        Returns output tensor after passing input `x` to the neural network
-        block.
-
-        Parameters
-        ----------
-            x : tensor
-                input to the NN block.
-            skpCn : tensor
-                skip connection input to the NN block.
-
-        Returns
-        -------
-            tensor
-                output of the NN block.
-        """
-
-        # Bilinear interpolation with scaling 2.
-        x = F.interpolate(x, scale_factor=2, mode='bilinear')
-        # Convolution + Leaky ReLU
-        x = F.leaky_relu(self.conv1(x), negative_slope = 0.1)
-        # Convolution + Leaky ReLU on (`x`, `skpCn`)
-        x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope = 0.1)
+    def forward(self, x, skpCn):
+        # Upsample x
+        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
+        # Match dimensions by cropping the skip connection (skpCn) to match x
+        if x.size(-1) != skpCn.size(-1):
+            skpCn = skpCn[:, :, :, :x.size(-1)]
+        if x.size(-2) != skpCn.size(-2):
+            skpCn = skpCn[:, :, :x.size(-2), :]
+        # Concatenate and apply convolutions
+        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
+        x = F.leaky_relu(self.conv2(torch.cat((x, skpCn), 1)), negative_slope=0.1)
         return x
 
 
 class UNet(nn.Module):
     """
     A class for creating UNet like architecture as specified by the
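For context on the rewritten `up.forward` in this diff: when an encoder stage has an odd spatial size, the downsampling path floors that dimension, so after the 2x bilinear upsample the decoder tensor can end up one pixel smaller than its skip connection, and `torch.cat((x, skpCn), 1)` would fail on mismatched heights or widths. A minimal standalone sketch of the new cropping logic (dummy shapes chosen for illustration, not taken from the repo):

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 10, 10)      # decoder features before upsampling
skpCn = torch.randn(1, 8, 21, 21)  # skip connection from an odd-sized stage

x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # -> [1, 8, 20, 20]
# Crop skpCn to x's spatial size, exactly as in the new forward above.
if x.size(-1) != skpCn.size(-1):
    skpCn = skpCn[:, :, :, :x.size(-1)]
if x.size(-2) != skpCn.size(-2):
    skpCn = skpCn[:, :, :x.size(-2), :]
assert skpCn.shape[-2:] == x.shape[-2:]  # torch.cat((x, skpCn), 1) is now safe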