Dali杂货铺🐰8——PyTorch 保存,加载和查看训练好的模型

"save, load, weights, 载入预训练权重"

Posted by fuhao7i on March 16, 2021

1. 保存模型

1
torch.save(model.state_dict(),"./Double.pth")

2. 加载训练好的模型

1
2
3
4
# 创建模型
model=Net()
# 加载预训练模型的参数
model.load_state_dict(torch.load("./Double.pth"))

3. 查看训练好的模型

1
2
3
4
5
import torch 
content = torch.load('/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth')

for key in content:
    print(key, content[key].size(),sep='      ')

1
2
3
4
5
import torch  # 命令行是逐行立即执行的
content = torch.load('/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth')

for key, value in content:
    print(key, value.size(),sep='      ')

result:

conv1.weight      torch.Size([64, 3, 7, 7])
bn1.running_mean      torch.Size([64])
bn1.running_var      torch.Size([64])
bn1.weight      torch.Size([64])
bn1.bias      torch.Size([64])
layer1.0.conv1.weight      torch.Size([64, 64, 1, 1])
layer1.0.bn1.running_mean      torch.Size([64])
layer1.0.bn1.running_var      torch.Size([64])
layer1.0.bn1.weight      torch.Size([64])
layer1.0.bn1.bias      torch.Size([64])
layer1.0.conv2.weight      torch.Size([64, 64, 3, 3])
layer1.0.bn2.running_mean      torch.Size([64])
layer1.0.bn2.running_var      torch.Size([64])
layer1.0.bn2.weight      torch.Size([64])
   ...

4. 将对应层的预训练权重导入

1
2
3
4
5
pretrained_dict = torch.load('/content/drive/MyDrive/search/mmdetection/data/resneXt_imagenet_338x600.pth')
model_dict = self.model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
self.model.load_state_dict(model_dict)

#. 引用

  1. PyTorch使用预训练模型(保存,加载,加载部分,冻结某些参数,修改网络某些层等…)