Pytorch载入部分参数并冻结
参考资料
- pytorch 模型部分参数的加载
- Pytorch中,只导入部分模型参数的做法
- Correct way to freeze layers
- Pytorch自由载入部分模型参数并冻结
- pytorch冻结部分参数训练另一部分
- PyTorch更新部分网络,其他不更新
- Pytorch固定部分参数(只训练部分层)
加载部分参数
如果加载现有模型的所有参数,我们常使用的是代码如下:
1 | torch.load(model.state_dict()) |
在训练过程中,我们常常会使用预训练模型,有时我们是在自己的模型中加入别人的某些模块,或者对别人的模型进行局部修改,这个时候再使用torch.load(model.state_dict())
,就会出现类似这些的错误:RuntimeError: Error(s) in loading state_dict for Net:Missing key(s) in state_dict:xxx
。出现这个错误就是某些参数缺失或者不匹配。
保持原来网络层的名称和结构不变
现有模型中引入的那部分网络结构的网络层的名称和结构保持不变,这时候加载参数的代码很简单。
1 | # 加载引入的网络模型 |
引入的网络层名称发生修改
这个时候再直接使用上面的加载方法,会导致部分key的value无法实现更新。 我就曾在这个位置犯过很严重的错误。首先我定义了AttentionResNet
,这是一个UNet来实现图像分割,然后在另一个模型中我使用了这个模型self.attention_map = AttentionResNet(XXX)
。因为我在引用的过程中并没有对AttentionResNet
那部分代码进行修改,所以本能的觉得这部分网络层的名称是相同的,所以加载这部分参数时,我直接使用了上面的方法。这个错误隐藏了差不多一个星期。直到我开始冻结这部分参数进行训练时,发现情况不对。因为我在输出attention_map
的特征图时,我发现它是一张全黑图(像素全为0),这表示加载的参数不对,然后我尝试输出pretrained_dict
时,它是一个空字典。然后继续输出pretrained_dict.keys()
(未修改之前的pretrained_dict
)和model_dict.keys()
发现预期相同的那部分key中都多了一部分attention_map.
。问题主要出在self.attention_map = AttentionResNet(XXX)
这一句,它使原有的网络层名称都加了个前缀attention_map.
,知道了错误,修改起来很简单。
1 | # 加载引入的网络模型 |
其实我这个位置的修改有点投机,更加常规的方法是: 引用自Pytorch自由载入部分模型参数并冻结
我们看出只要构建一个字典,使得字典的keys和我们自己创建的网络相同,我们在从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,这是最普适的方法适用于所有的网络变化。
先输出两个模型的参数字典,观察需要加载的那部分参数所处的位置,然后利用for循环构建新的字典。
冻结参数
- 将需要固定的那部分参数的
requires_grad
置为False. - 在优化器中加入filter根据
requires_grad
进行过滤.
ps: 解决AttributeError: ‘NoneType’ object has no attribute ‘data’
问题的一种思路就是冻结参数,参考博客 代码如下:
1 | # requires_grad置为False |
当需要冻结的那部分参数的网络层名称不太明确时,可以采用pytorch冻结部分参数训练另一部分的思路,打印出所有网络层,通过参数名称进行冻结。