Pytorch载入部分参数并冻结

参考资料

  1. pytorch 模型部分参数的加载
  2. Pytorch中,只导入部分模型参数的做法
  3. Correct way to freeze layers
  4. Pytorch自由载入部分模型参数并冻结
  5. pytorch冻结部分参数训练另一部分
  6. PyTorch更新部分网络,其他不更新
  7. Pytorch固定部分参数(只训练部分层)

加载部分参数

如果加载现有模型的所有参数,我们常使用的是代码如下:

torch.load(model.state_dict())

在训练过程中,我们常常会使用预训练模型,有时我们是在自己的模型中加入别人的某些模块,或者对别人的模型进行局部修改,这个时候再使用torch.load(model.state_dict()),就会出现类似这些的错误:RuntimeError: Error(s) in loading state_dict for Net:Missing key(s) in state_dict:xxx。出现这个错误就是某些参数缺失或者不匹配。

保持原来网络层的名称和结构不变

现有模型中引入的那部分网络结构的网络层的名称和结构保持不变,这时候加载参数的代码很简单。

# 加载引入的网络模型
model_path = "xxx"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 获取现有模型的参数字典
model_dict =  model.state_dict()
# 获取两个模型相同网络层的参数字典
state_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict.keys()}
# update必不可少,实现相同key的value同步
model_dict.update(state_dict)
# 加载模型部分参数
model.load_state_dict(model_dict)

引入的网络层名称发生修改

这个时候再直接使用上面的加载方法,会导致部分key的value无法实现更新。

我就曾在这个位置犯过很严重的错误。首先我定义了AttentionResNet,这是一个UNet来实现图像分割,然后在另一个模型中我使用了这个模型self.attention_map = AttentionResNet(XXX)。因为我在引用的过程中并没有对AttentionResNet那部分代码进行修改,所以本能的觉得这部分网络层的名称是相同的,所以加载这部分参数时,我直接使用了上面的方法。这个错误隐藏了差不多一个星期。直到我开始冻结这部分参数进行训练时,发现情况不对。因为我在输出attention_map的特征图时,我发现它是一张全黑图(像素全为0),这表示加载的参数不对,然后我尝试输出pretrained_dict时,它是一个空字典。然后继续输出pretrained_dict.keys()(未修改之前的pretrained_dict)和model_dict.keys()发现预期相同的那部分key中都多了一部分attention_map.。问题主要出在self.attention_map = AttentionResNet(XXX)这一句,它使原有的网络层名称都加了个前缀attention_map.,知道了错误,修改起来很简单。

# 加载引入的网络模型
model_path = "xxx"
checkpoint = torch.load(os.path.join(model_path, map_location=torch.device('cpu'))
pretrained_dict = checkpoint['net']
# 获取现有模型的参数字典
model_dict =  model.state_dict()
# 获取两个模型相同网络层的参数字典
state_dict = {'attention_map.' + k:v for k,v in pretrained_dict.items() if 'attention_map.' + k in model_dict.keys()}
# update必不可少,实现相同key的value同步
model_dict.update(state_dict)
# 加载模型部分参数
model.load_state_dict(model_dict)

其实我这个位置的修改有点投机,更加常规的方法是:

引用自Pytorch自由载入部分模型参数并冻结

我们看出只要构建一个字典,使得字典的keys和我们自己创建的网络相同,我们在从各种预训练网络把想要的参数对着新的keys填进去就可以有一个新的state_dict了,这样我们就可以load这个新的state_dict,这是最普适的方法适用于所有的网络变化。

先输出两个模型的参数字典,观察需要加载的那部分参数所处的位置,然后利用for循环构建新的字典。

冻结参数

  1. 将需要固定的那部分参数的requires_grad置为False.
  2. 在优化器中加入filter根据requires_grad进行过滤.

ps: 解决AttributeError: ‘NoneType’ object has no attribute ‘data’问题的一种思路就是冻结参数,参考博客

代码如下:

# requires_grad置为False
for p in net.XXX.parameters():
    p.requires_grad = False

# filter
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

当需要冻结的那部分参数的网络层名称不太明确时,可以采用pytorch冻结部分参数训练另一部分的思路,打印出所有网络层,通过参数名称进行冻结。

转载请注明:Onwaier‘s Blog » Pytorch载入部分参数并冻结