revliter commited on
Commit
be3bca7
·
verified ·
1 Parent(s): 997d42e

Upload test_load_large.py

Browse files
Files changed (1) hide show
  1. test_load_large.py +18 -0
test_load_large.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import load_file
2
+ from models.InternVideo_next import internvideo_next_large_patch14_224
3
+
4
+ # test loading
5
+ test_model = internvideo_next_large_patch14_224()
6
+ model_ckpt = load_file('internvideo_next_large_ps14_res224.safetensors')
7
+
8
+ msg = test_model.load_state_dict(model_ckpt, strict=False)
9
+ print(msg)
10
+
11
+ # test input
12
+ import torch
13
+
14
+ test_model = test_model.cuda().half()
15
+ input_data = torch.randn(16, 3, 16, 224, 224).cuda().half() # B, C, T, H, W
16
+ output_embedding = test_model(input_data)
17
+
18
+ print(output_embedding.shape) # [16, 4096, 1024]