01. Pytorch 筛选部分网络层权重参数加载
1. 从权重文件中筛选并加载
1 2 3 4 5 6
| pretrained_weight = torch.load('weight.pth') new_net = My_Net() new_net_dict = new_net.state_dict() pretrained_dict = {k: v for k, v in pretrained_weight.items() if k in net_dict} new_net_dict.update(pretrained_dict) new_net.load_state_dict(new_net_dict)
|
上面代码筛选
步骤需要两个网络的网络层命名一样才能筛选成功,如果不一样参考下面网络层名映射的方法。
2. 从Pytorch官方模型中筛选部分层并加载
1. 加载预训练模型
1 2 3 4 5 6 7 8 9
| import torch import torchvision as tv
pretrained_net = tv.models.alexnet(pretrained=True) pretrained_weight = pretrained_net.state_dict()
print(pretrained_net) print(pretrained_weight.keys())
|
部分输出截取如下:
选取红框部分的网络参数加载
2. 自定义网络
根据已有模型的输入输出设计网络
1 2 3 4 5 6 7 8 9 10 11
| class My_net(torch.nn.Module): def __init__(self): super(My_net, self).__init__() self.layer = torch.nn.Sequential( torch.nn.Linear(in_features=4096, out_features=4096, bias=True), torch.nn.ReLU(inplace=True), torch.nn.Linear(in_features=4096, out_features=1000, bias=True) ) def forward(self, x): return self.layer(x)
|
3. 实例化网络并确定映射关系
因为权重参数是通过字典存储的,当你重新定义一样的网络(即时输入输出相同)但是每层的名称会不一样。导致无法加载。
1 2 3 4 5 6 7 8 9 10 11 12 13
| net = My_net() net_dict = net.state_dict() net_dict.keys() need_weights = ['classifier.4.weight', 'classifier.4.bias', 'classifier.6.weight', 'classifier.6.bias'] layer_name_map = { a: b for a,b in zip(need_weights, net_dict.keys())} print(layer_name_map) """ outs: {'classifier.4.weight': 'layer.0.weight', 'classifier.4.bias': 'layer.0.bias', 'classifier.6.weight': 'layer.2.weight', 'classifier.6.bias': 'layer.2.bias'} """
|
4. 筛选并加载参数
1 2 3 4 5 6 7
| pretrained_dict = {} for k, v in pretrained_weight.items(): if k in need_weights: pretrained_dict[layer_name_map[k]] = v
net.load_state_dict(pretrained_dict)
|
END