kuldeep0204 commited on
Commit
c13dd19
·
verified ·
1 Parent(s): 84caf82

Create state_dict

Browse files
Files changed (1) hide show
  1. state_dict +17 -0
state_dict ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ # Try loading as state_dict first
3
+ state_dict = torch.load(model_file, map_location=device)
4
+ if isinstance(state_dict, dict):
5
+ print("Loaded state_dict, initializing model...")
6
+ from my_model import MyModel # import your model class
7
+ model = MyModel(...) # init with same architecture
8
+ model.load_state_dict(state_dict)
9
+ else:
10
+ print("Loaded full model object.")
11
+ model = state_dict
12
+ except Exception as e:
13
+ print("state_dict load failed, retrying with weights_only=False:", e)
14
+ model = torch.load(model_file, map_location=device, weights_only=False)
15
+
16
+ model.to(device)
17
+ model.eval()