"""Print the parameter names stored in the 'DeMoE.pt' checkpoint, then their count."""
import torch

# Load onto the CPU so the checkpoint can be inspected on a machine without the
# device it was saved from.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# trusted sources. Consider passing weights_only=True if the file holds only
# tensors/containers.
pt_dict = torch.load('DeMoE.pt', map_location='cpu')

# Assumes the checkpoint is a dict whose 'params' entry is a mapping (e.g. a
# state dict); inferred from the indexing and .keys() calls below — confirm.
params = pt_dict['params']
print(params.keys())
print(len(params.keys()))