01. Pytorch 筛选部分网络层权重参数加载

01. Pytorch 筛选部分网络层权重参数加载

1. 从权重文件中筛选并加载

1
2
3
4
5
6
pretrained_weight = torch.load('weight.pth') # 加载预训练的全部权重参数
new_net = My_Net() # 定义自己的部分网络
new_net_dict = new_net.state_dict() # 获取自己网络的权重参数
pretrained_dict = {k: v for k, v in pretrained_weight.items() if k in net_dict} # 筛选
new_net_dict.update(pretrained_dict) # 更新自己网络的权重参数
new_net.load_state_dict(new_net_dict) # 重新加载进自己的网络中

上面代码筛选步骤需要两个网络的网络层命名一样才能筛选成功,如果不一样参考下面网络层名映射的方法。

2. 从Pytorch官方模型中筛选部分层并加载

1. 加载预训练模型

1
2
3
4
5
6
7
8
9
import torch
import torchvision as tv

pretrained_net = tv.models.alexnet(pretrained=True)
pretrained_weight = pretrained_net.state_dict()

print(pretrained_net)
print(pretrained_weight.keys())

部分输出截取如下:

image-20220220163646781

选取红框部分的网络参数加载

2. 自定义网络

根据已有模型的输入输出设计网络

1
2
3
4
5
6
7
8
9
10
11
class My_net(torch.nn.Module):
def __init__(self):
super(My_net, self).__init__()
self.layer = torch.nn.Sequential(
torch.nn.Linear(in_features=4096, out_features=4096, bias=True),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(in_features=4096, out_features=1000, bias=True)
)

def forward(self, x):
return self.layer(x)

3. 实例化网络并确定映射关系

因为权重参数是通过字典存储的,当你重新定义一样的网络(即时输入输出相同)但是每层的名称会不一样。导致无法加载。

1
2
3
4
5
6
7
8
9
10
11
12
13
net = My_net()
net_dict = net.state_dict()
net_dict.keys()
need_weights = ['classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias']
layer_name_map = { a: b for a,b in zip(need_weights, net_dict.keys())} # 网络层名称映射
print(layer_name_map)
"""
outs:
{'classifier.4.weight': 'layer.0.weight',
'classifier.4.bias': 'layer.0.bias',
'classifier.6.weight': 'layer.2.weight',
'classifier.6.bias': 'layer.2.bias'}
"""

4. 筛选并加载参数

1
2
3
4
5
6
7
pretrained_dict = {}
for k, v in pretrained_weight.items():
if k in need_weights:
# 通过映射获取自己网络的网络名
pretrained_dict[layer_name_map[k]] = v

net.load_state_dict(pretrained_dict)

END

Post Author: jasonyang