vfontech committed
Commit 1e61b7e · verified · 1 Parent(s): 5fa4f12

Update README.md

Files changed (1)
  1. README.md +4 -3
README.md CHANGED

@@ -49,8 +49,8 @@ from torchvision.transforms import Compose, ToTensor, Resize, Normalize
 from utils.utils import denorm
 from model.hub import MultiInputResShiftHub
 
-model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI")
-model.requires_grad_(False).cuda().eval()
+model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI").cuda()
+model.eval()
 
 img0_path = r"_data\example_images\frame1.png"
 img2_path = r"_data\example_images\frame3.png"
@@ -66,7 +66,8 @@ img0 = transforms(Image.open(img0_path).convert("RGB")).unsqueeze(0).cuda()
 img2 = transforms(Image.open(img2_path).convert("RGB")).unsqueeze(0).cuda()
 tau = 0.5
 
-img1 = model.reverse_process([img0, img2], tau)
+with torch.no_grad():
+    img1 = model.reverse_process([img0, img2], tau)
 
 plt.figure(figsize=(10, 5))
 plt.subplot(1, 3, 1)
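For reference, the two hunks assemble into the end-to-end snippet below. This is a minimal sketch of the updated README usage, not a verbatim copy: the `import` lines for `torch`, `matplotlib`, and `PIL`, and the `Resize`/`Normalize` parameters of the `transforms` pipeline are assumptions, since they fall outside the visible diff. Only the imports shown in the hunk context, the model loading and `eval()` lines, the image paths, `tau`, the `torch.no_grad()` call, and the first two `plt` calls are confirmed by the hunks above.

```python
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Resize, Normalize

from utils.utils import denorm
from model.hub import MultiInputResShiftHub

# Load the pretrained weights, move the model to the GPU, and switch to eval mode.
model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI").cuda()
model.eval()

img0_path = r"_data\example_images\frame1.png"
img2_path = r"_data\example_images\frame3.png"

# Preprocessing sketch: the exact Resize target and Normalize statistics are
# assumptions, since they are not shown in the diff hunks.
transforms = Compose([
    Resize((256, 448)),                        # assumed size
    ToTensor(),
    Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # assumed statistics
])

img0 = transforms(Image.open(img0_path).convert("RGB")).unsqueeze(0).cuda()
img2 = transforms(Image.open(img2_path).convert("RGB")).unsqueeze(0).cuda()
tau = 0.5  # intermediate time step used in the README (0.5 = midpoint)

# Inference only: autograd is disabled around the reverse (sampling) process.
with torch.no_grad():
    img1 = model.reverse_process([img0, img2], tau)

plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
# ... the remaining subplots and the denorm-based display of img0, img1, img2
# follow the README and are not shown in the diff.
```

Design-wise, `model.eval()` plus a `torch.no_grad()` context replaces the earlier `requires_grad_(False)` call: the parameters are left untouched and gradient tracking is simply skipped for the inference call, which is the conventional PyTorch pattern for running a pretrained model.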