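Modular checkpoint loading
Put your checkpoint path and model name into `checkpoint_path` and `model_name`, respectively. Only checkpoint tensors whose shapes match the corresponding model entries are loaded; the remaining model weights keep their current values.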
# ------------------------------------------------- #
# put your checkpoint path and model name into "checkpoint_path" and "model_name" respectively.
# ------------------------------------------------- #
# Partially load a checkpoint into an existing model. Every checkpoint entry that
# meets both conditions below is copied over:
#   - its key exists in the model's state dict, and
#   - its shape matches the model's entry for that key.
# All other model parameters/buffers keep their current values.
#
# checkpoint_path: file readable by ``torch.load``. Presumably a plain state dict
#     (key -> tensor); TODO confirm. A training checkpoint that nests the weights
#     under another key (e.g. "state_dict") would need unwrapping first.
# model_name: Python expression, evaluated with ``eval``, that names an object
#     exposing ``state_dict()`` / ``load_state_dict()`` (typically a
#     ``torch.nn.Module``, e.g. "model").
#     NOTE(review): ``eval`` executes arbitrary code; only put trusted text here.
checkpoint_path = " checkpoint path "
model_name = " model name "
print('Loading weights into state dict...')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Resolve the model object once and reuse it for reading and writing weights.
_target_model = eval(model_name)
model_dict = _target_model.state_dict()
# map_location remaps the stored tensors onto ``device``, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine.
_checkpoint_state = torch.load(checkpoint_path, map_location=device)
# Check key membership before indexing. Checkpoint keys the model does not have
# are then skipped instead of raising KeyError.
pretrained_dict = {
    k: v
    for k, v in _checkpoint_state.items()
    if k in model_dict and np.shape(model_dict[k]) == np.shape(v)
}
# Report what was dropped so an architecture mismatch is not silently ignored.
_skipped_keys = [k for k in _checkpoint_state if k not in pretrained_dict]
if _skipped_keys:
    print(f'Skipped {len(_skipped_keys)} checkpoint key(s) missing from the '
          f'model or with mismatched shapes.')
# Merge the matching entries over the model's own state dict. Every key the
# model expects is therefore still present when it is loaded back.
model_dict.update(pretrained_dict)
_target_model.load_state_dict(model_dict)
print('Finished!')