Modular torch💌1: load checkpoint

"load weights into state dict"

Posted by fuhao7i on April 17, 2021

Modular checkpoint loading

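Put your checkpoint path and model name into "checkpoint_path" and "model_name" respectively. The snippet copies only the checkpoint weights whose shapes match the model's, so layers that differ keep their current values.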

# ------------------------------------------------- #
#   Put your checkpoint path and model name into "checkpoint_path" and "model_name" respectively.
# ------------------------------------------------- #
import torch

checkpoint_path = " checkpoint path "  # path to the saved state dict
model_name = " model name "            # name of an already-constructed model variable

print('Loading weights into state dict...')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = eval(model_name)  # look up the model object by its variable name
model_dict = model.state_dict()
pretrained_dict = torch.load(checkpoint_path, map_location=device)
# Keep only entries whose key exists in the model and whose shape matches.
pretrained_dict = {k: v for k, v in pretrained_dict.items()
                   if k in model_dict and model_dict[k].shape == v.shape}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
print('Finished!')
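The snippet above silently drops any checkpoint entry that does not fit the model. Below is a minimal sketch of how the dropped entries could be reported instead, written in plain Python so it runs without torch. The function name `split_by_compatibility` and the example shapes are hypothetical and only for illustration; they are not part of the original snippet.

```python
def split_by_compatibility(model_shapes, checkpoint_shapes):
    """Split checkpoint keys into loadable and skipped ones.

    Both arguments map parameter names to shape tuples.
    Returns (loadable, skipped, missing):
      loadable - checkpoint keys present in the model with the same shape
      skipped  - checkpoint keys that cannot be loaded, with a reason
      missing  - model keys the checkpoint does not provide
    """
    loadable, skipped = [], {}
    for name, shape in checkpoint_shapes.items():
        if name not in model_shapes:
            skipped[name] = "not in model"
        elif tuple(model_shapes[name]) != tuple(shape):
            skipped[name] = f"shape {tuple(shape)} != {tuple(model_shapes[name])}"
        else:
            loadable.append(name)
    missing = [name for name in model_shapes if name not in checkpoint_shapes]
    return loadable, skipped, missing


# Hypothetical shapes, for illustration only.
model_shapes = {"conv.weight": (16, 3, 3, 3), "fc.weight": (10, 64), "fc.bias": (10,)}
checkpoint_shapes = {"conv.weight": (16, 3, 3, 3), "fc.weight": (1000, 64), "head.bias": (5,)}

loadable, skipped, missing = split_by_compatibility(model_shapes, checkpoint_shapes)
print("loadable:", loadable)
print("skipped:", skipped)
print("missing from checkpoint:", missing)
```

With a real model, the two shape dictionaries could be built from the state dicts, for example `{k: tuple(v.shape) for k, v in model_dict.items()}`, and the `loadable` list then names the entries that are safe to copy into `model_dict`.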