Traceback (most recent call last): File "train.py", line 41, in <module> net.load_state_dict(state_dict) File "/home/cgq/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.c1.layer.0.weight", "module.c1.layer.1.weight", "module.c1.layer.1.bias", "module.c1.layer.1.running_mean", "module.c1.layer.1.running_var", "module.c1.layer.4.weight", "module.c1.layer.5.weight", "module.c1.layer.5.bias", "module.c1.layer.5.running_mean", "module.c1.layer.5.running_var", "module.d1.layer.0.weight", "module.d1.layer.1.weight", "module.d1.layer.1.bias", "module.d1.layer.1.running_mean", "module.d1.layer.1.running_var", "module.c2.layer.0.weight", "module.c2.layer.1.weight", "module.c2.layer.1.bias", "module.c2.layer.1.running_mean", "module.c2.layer.1.running_var", "module.c2.layer.4.weight", "module.c2.layer.5.weight", "module.c2.layer.5.bias", "module.c2.layer.5.running_mean", "module.c2.layer.5.running_var", "module.d2.layer.0.weight", "module.d2.layer.1.weight", "module.d2.layer.1.bias", "module.d2.layer.1.running_mean", "module.d2.layer.1.running_var", "module.c3.layer.0.weight", "module.c3.layer.1.weight", "module.c3.layer.1.bias", "module.c3.layer.1.running_mean", "module.c3.layer.1.running_var", "module.c3.layer.4.weight", "module.c3.layer.5.weight", "module.c3.layer.5.bias", "module.c3.layer.5.running_mean", "module.c3.layer.5.running_var", "module.d3.layer.0.weight", "module.d3.layer.1.weight", "module.d3.layer.1.bias", "module.d3.layer.1.running_mean", "module.d3.layer.1.running_var", "module.c4.layer.0.weight", "module.c4.layer.1.weight", "module.c4.layer.1.bias", "module.c4.layer.1.running_mean", "module.c4.layer.1.running_var", "module.c4.layer.4.weight", "module.c4.layer.5.weight", "module.c4.layer.5.bias", "module.c4.layer.5.running_mean", 
"module.c4.layer.5.running_var", "module.d4.layer.0.weight", "module.d4.layer.1.weight", "module.d4.layer.1.bias", "module.d4.layer.1.running_mean", "module.d4.layer.1.running_var", "module.c5.layer.0.weight", "module.c5.layer.1.weight", "module.c5.layer.1.bias", "module.c5.layer.1.running_mean", "module.c5.layer.1.running_var", "module.c5.layer.4.weight", "module.c5.layer.5.weight", "module.c5.layer.5.bias", "module.c5.layer.5.running_mean", "module.c5.layer.5.running_var", "module.u1.layer.weight", "module.u1.layer.bias", "module.c6.layer.0.weight", "module.c6.layer.1.weight", "module.c6.layer.1.bias", "module.c6.layer.1.running_mean", "module.c6.layer.1.running_var", "module.c6.layer.4.weight", "module.c6.layer.5.weight", "module.c6.layer.5.bias", "module.c6.layer.5.running_mean", "module.c6.layer.5.running_var", "module.u2.layer.weight", "module.u2.layer.bias", "module.c7.layer.0.weight", "module.c7.layer.1.weight", "module.c7.layer.1.bias", "module.c7.layer.1.running_mean", "module.c7.layer.1.running_var", "module.c7.layer.4.weight", "module.c7.layer.5.weight", "module.c7.layer.5.bias", "module.c7.layer.5.running_mean", "module.c7.layer.5.running_var", "module.u3.layer.weight", "module.u3.layer.bias", "module.c8.layer.0.weight", "module.c8.layer.1.weight", "module.c8.layer.1.bias", "module.c8.layer.1.running_mean", "module.c8.layer.1.running_var", "module.c8.layer.4.weight", "module.c8.layer.5.weight", "module.c8.layer.5.bias", "module.c8.layer.5.running_mean", "module.c8.layer.5.running_var", "module.u4.layer.weight", "module.u4.layer.bias", "module.c9.layer.0.weight", "module.c9.layer.1.weight", "module.c9.layer.1.bias", "module.c9.layer.1.running_mean", "module.c9.layer.1.running_var", "module.c9.layer.4.weight", "module.c9.layer.5.weight", "module.c9.layer.5.bias", "module.c9.layer.5.running_mean", "module.c9.layer.5.running_var", "module.out.weight", "module.out.bias". 
Unexpected key(s) in state_dict: "c1.layer.0.weight", "c1.layer.1.weight", "c1.layer.1.bias", "c1.layer.1.running_mean", "c1.layer.1.running_var", "c1.layer.1.num_batches_tracked", "c1.layer.4.weight", "c1.layer.5.weight", "c1.layer.5.bias", "c1.layer.5.running_mean", "c1.layer.5.running_var", "c1.layer.5.num_batches_tracked", "d1.layer.0.weight", "d1.layer.1.weight", "d1.layer.1.bias", "d1.layer.1.running_mean", "d1.layer.1.running_var", "d1.layer.1.num_batches_tracked", "c2.layer.0.weight", "c2.layer.1.weight", "c2.layer.1.bias", "c2.layer.1.running_mean", "c2.layer.1.running_var", "c2.layer.1.num_batches_tracked", "c2.layer.4.weight", "c2.layer.5.weight", "c2.layer.5.bias", "c2.layer.5.running_mean", "c2.layer.5.running_var", "c2.layer.5.num_batches_tracked", "d2.layer.0.weight", "d2.layer.1.weight", "d2.layer.1.bias", "d2.layer.1.running_mean", "d2.layer.1.running_var", "d2.layer.1.num_batches_tracked", "c3.layer.0.weight", "c3.layer.1.weight", "c3.layer.1.bias", "c3.layer.1.running_mean", "c3.layer.1.running_var", "c3.layer.1.num_batches_tracked", "c3.layer.4.weight", "c3.layer.5.weight", "c3.layer.5.bias", "c3.layer.5.running_mean", "c3.layer.5.running_var", "c3.layer.5.num_batches_tracked", "d3.layer.0.weight", "d3.layer.1.weight", "d3.layer.1.bias", "d3.layer.1.running_mean", "d3.layer.1.running_var", "d3.layer.1.num_batches_tracked", "c4.layer.0.weight", "c4.layer.1.weight", "c4.layer.1.bias", "c4.layer.1.running_mean", "c4.layer.1.running_var", "c4.layer.1.num_batches_tracked", "c4.layer.4.weight", "c4.layer.5.weight", "c4.layer.5.bias", "c4.layer.5.running_mean", "c4.layer.5.running_var", "c4.layer.5.num_batches_tracked", "d4.layer.0.weight", "d4.layer.1.weight", "d4.layer.1.bias", "d4.layer.1.running_mean", "d4.layer.1.running_var", "d4.layer.1.num_batches_tracked", "c5.layer.0.weight", "c5.layer.1.weight", "c5.layer.1.bias", "c5.layer.1.running_mean", "c5.layer.1.running_var", "c5.layer.1.num_batches_tracked", "c5.layer.4.weight", "c5.layer.5.weight", 
"c5.layer.5.bias", "c5.layer.5.running_mean", "c5.layer.5.running_var", "c5.layer.5.num_batches_tracked", "u1.layer.weight", "u1.layer.bias", "c6.layer.0.weight", "c6.layer.1.weight", "c6.layer.1.bias", "c6.layer.1.running_mean", "c6.layer.1.running_var", "c6.layer.1.num_batches_tracked", "c6.layer.4.weight", "c6.layer.5.weight", "c6.layer.5.bias", "c6.layer.5.running_mean", "c6.layer.5.running_var", "c6.layer.5.num_batches_tracked", "u2.layer.weight", "u2.layer.bias", "c7.layer.0.weight", "c7.layer.1.weight", "c7.layer.1.bias", "c7.layer.1.running_mean", "c7.layer.1.running_var", "c7.layer.1.num_batches_tracked", "c7.layer.4.weight", "c7.layer.5.weight", "c7.layer.5.bias", "c7.layer.5.running_mean", "c7.layer.5.running_var", "c7.layer.5.num_batches_tracked", "u3.layer.weight", "u3.layer.bias", "c8.layer.0.weight", "c8.layer.1.weight", "c8.layer.1.bias", "c8.layer.1.running_mean", "c8.layer.1.running_var", "c8.layer.1.num_batches_tracked", "c8.layer.4.weight", "c8.layer.5.weight", "c8.layer.5.bias", "c8.layer.5.running_mean", "c8.layer.5.running_var", "c8.layer.5.num_batches_tracked", "u4.layer.weight", "u4.layer.bias", "c9.layer.0.weight", "c9.layer.1.weight", "c9.layer.1.bias", "c9.layer.1.running_mean", "c9.layer.1.running_var", "c9.layer.1.num_batches_tracked", "c9.layer.4.weight", "c9.layer.5.weight", "c9.layer.5.bias", "c9.layer.5.running_mean", "c9.layer.5.running_var", "c9.layer.5.num_batches_tracked", "out.weight", "out.bias".
这个问题是由于之前训练模型时,保存的是未包装在 nn.DataParallel 中的原始模型权重,
故保存的 state_dict 中键名没有 module. 前缀;
而之后训练时又使用多卡,将模型包装在 nn.DataParallel 中,
这会导致键名不匹配的问题。
针对此问题的解决办法:
- 加载之前单卡训练的模型权重时,为键名添加 `module.` 前缀:
  这样可以与当前 nn.DataParallel 包装的模型兼容。
- 在保存模型时去掉 `module.` 前缀,
  以便之后在不使用 nn.DataParallel 时也可以直接加载。
import os
import tqdm
from torch import nn, optim
import torch
from torch.utils.data import DataLoader
from data import *
from net import *
from torchvision.utils import save_image
# Select which physical GPUs are visible to this process (must be set before CUDA init)
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # use GPU 0 and GPU 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path = 'params/unet.pth'  # checkpoint file used for resuming training
data_path = r'data/'  # dataset root consumed by MyDataset
save_path = 'train_image'  # directory where preview images are written during training
# Helpers that normalize state_dict key names when switching between single-GPU and DataParallel:
def remove_module_prefix(state_dict):
    """Return a copy of *state_dict* with any leading 'module.' stripped from keys.

    nn.DataParallel prefixes every parameter key with 'module.'; stripping it
    makes the weights loadable into a bare (unwrapped) model.
    """
    stripped = {}
    for key, tensor in state_dict.items():
        new_key = key[len('module.'):] if key.startswith('module.') else key
        stripped[new_key] = tensor
    return stripped
def add_module_prefix(state_dict):
    """Return a copy of *state_dict* whose every key starts with 'module.'.

    Keys that already carry the prefix are kept unchanged, so the result is
    compatible with a model wrapped in nn.DataParallel.
    """
    return {
        (key if key.startswith('module.') else 'module.' + key): value
        for key, value in state_dict.items()
    }
if __name__ == '__main__':
    # NOTE(review): the original comment said "+1 because background is a class" —
    # confirm 255 (not 256 or the true class count) matches the dataset's labels.
    num_classes = 255
    data_loader = DataLoader(MyDataset(data_path), batch_size=5, shuffle=True)
    net = UNet(num_classes)
    # Wrap for multi-GPU training; DataParallel prefixes all parameter keys with 'module.'
    net = nn.DataParallel(net)
    net = net.to(device)

    if os.path.exists(weight_path):
        # map_location keeps the load working whether the checkpoint was written
        # on GPU or CPU (a GPU-saved file would otherwise fail on a CPU-only host).
        state_dict = torch.load(weight_path, map_location=device)
        # Checkpoints are stored without the 'module.' prefix; add it back so the
        # keys match the DataParallel-wrapped model.
        if not list(state_dict.keys())[0].startswith('module.'):
            state_dict = add_module_prefix(state_dict)
        net.load_state_dict(state_dict)
        print('Successfully loaded weights!')
    else:
        print('Failed to load weights.')

    # Ensure output directories exist before the first save_image / torch.save call.
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(os.path.dirname(weight_path) or '.', exist_ok=True)

    # Optimizer and loss
    opt = optim.Adam(net.parameters())
    loss_fun = nn.CrossEntropyLoss()
    # Decay the learning rate by 10x every 10 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)

    epoch = 1
    while epoch < 200:
        for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):
            image, segment_image = image.to(device), segment_image.to(device)

            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image.long())

            opt.zero_grad()
            train_loss.backward()
            opt.step()

            if i % 5 == 0:
                print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            _image = image[0]
            _segment_image = torch.unsqueeze(segment_image[0], 0) * 255
            _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255
            # Replicate the single-channel masks to 3 channels so they can be
            # stacked alongside the RGB input image.
            _segment_image = _segment_image.repeat(3, 1, 1)
            _out_image = _out_image.repeat(3, 1, 1)
            # input / ground truth / prediction, side by side in one image
            img = torch.stack([_image, _segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')

        # Step the scheduler once per epoch
        scheduler.step()

        if epoch % 20 == 0:
            # Save without the 'module.' prefix so the checkpoint also loads
            # into a bare (non-DataParallel) model.
            state_dict = remove_module_prefix(net.state_dict())
            # BUGFIX: the old pattern f'{weight_path}_epoch_{epoch}.pth' produced
            # names like 'unet.pth_epoch_20.pth'; keep a single clean extension.
            base, ext = os.path.splitext(weight_path)
            torch.save(state_dict, f'{base}_epoch_{epoch}{ext}')
            # BUGFIX: also refresh weight_path itself — the resume branch above
            # reads weight_path, which the old code never updated.
            torch.save(state_dict, weight_path)
            print('Save successfully!')

        epoch += 1
解释:
- `remove_module_prefix`:用于从 state_dict 的键名中移除 `module.` 前缀。
- `add_module_prefix`:用于在 state_dict 的键名中添加 `module.` 前缀。
- 加载模型权重时:先检查第一个键是否以 `module.` 开头,如果不是,
  则调用 `add_module_prefix` 添加前缀,然后再加载权重。
- 保存模型权重时:调用 `remove_module_prefix` 移除前缀,
  以便之后可以在不使用 nn.DataParallel 时直接加载。
总结:之后训练模型时,我们在保存权重时移除 `module.` 前缀,
在加载权重时再按需添加 `module.` 前缀。
这样保存的模型,不论之后在单卡还是多卡上继续训练,都可以正常加载。
其实明白了原理就很简单,无非就是添加或删除 `module.` 前缀的问题。