1. 保存模型
1
| torch.save(model.state_dict(),"./Double.pth")
|
2. 加载训练好的模型
1
2
3
4
| # 创建模型
# Build a new network instance, then restore the saved parameters into it.
model = Net()
saved_state = torch.load("./Double.pth")
model.load_state_dict(saved_state)
|
3. 查看训练好的模型
1
2
3
4
5
| import torch
# Load the checkpoint file and print every entry's key next to the size of
# the value stored under it.
content = torch.load('/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth')
for key in content:
    # The loop body must be indented; without it Python raises IndentationError.
    print(key, content[key].size(), sep=' ')
|
或
1
2
3
4
5
| import torch # 命令行是逐行立即执行的
# Load the checkpoint file and print every entry's key next to the size of
# its value, this time unpacking (key, value) pairs directly.
content = torch.load('/root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth')
# Iterating the mapping itself yields only its keys, so unpacking into
# (key, value) requires .items().
for key, value in content.items():
    print(key, value.size(), sep=' ')
|
result:
conv1.weight torch.Size([64, 3, 7, 7])
bn1.running_mean torch.Size([64])
bn1.running_var torch.Size([64])
bn1.weight torch.Size([64])
bn1.bias torch.Size([64])
layer1.0.conv1.weight torch.Size([64, 64, 1, 1])
layer1.0.bn1.running_mean torch.Size([64])
layer1.0.bn1.running_var torch.Size([64])
layer1.0.bn1.weight torch.Size([64])
layer1.0.bn1.bias torch.Size([64])
layer1.0.conv2.weight torch.Size([64, 64, 3, 3])
layer1.0.bn2.running_mean torch.Size([64])
layer1.0.bn2.running_var torch.Size([64])
layer1.0.bn2.weight torch.Size([64])
...
4. 将对应层的预训练权重导入
1
2
3
4
5
| pretrained_dict = torch.load('/content/drive/MyDrive/search/mmdetection/data/resneXt_imagenet_338x600.pth')
# Fetch the current model's state dict.
model_dict = self.model.state_dict()
# Keep only the pretrained entries whose key also exists in the model.
matching = {}
for name, weight in pretrained_dict.items():
    if name in model_dict:
        matching[name] = weight
pretrained_dict = matching
# Overlay the matching pretrained entries, then load the merged dict back.
model_dict.update(pretrained_dict)
self.model.load_state_dict(model_dict)
|
5. 引用
- PyTorch使用预训练模型(保存,加载,加载部分,冻结某些参数,修改网络某些层等…)