深度学习模型训练中.module 后缀问题。

Traceback (most recent call last): File "train.py", line 41, in <module> net.load_state_dict(state_dict) File "/home/cgq/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for DataParallel: Missing key(s) in state_dict: "module.c1.layer.0.weight", "module.c1.layer.1.weight", "module.c1.layer.1.bias", "module.c1.layer.1.running_mean", "module.c1.layer.1.running_var", "module.c1.layer.4.weight", "module.c1.layer.5.weight", "module.c1.layer.5.bias", "module.c1.layer.5.running_mean", "module.c1.layer.5.running_var", "module.d1.layer.0.weight", "module.d1.layer.1.weight", "module.d1.layer.1.bias", "module.d1.layer.1.running_mean", "module.d1.layer.1.running_var", "module.c2.layer.0.weight", "module.c2.layer.1.weight", "module.c2.layer.1.bias", "module.c2.layer.1.running_mean", "module.c2.layer.1.running_var", "module.c2.layer.4.weight", "module.c2.layer.5.weight", "module.c2.layer.5.bias", "module.c2.layer.5.running_mean", "module.c2.layer.5.running_var", "module.d2.layer.0.weight", "module.d2.layer.1.weight", "module.d2.layer.1.bias", "module.d2.layer.1.running_mean", "module.d2.layer.1.running_var", "module.c3.layer.0.weight", "module.c3.layer.1.weight", "module.c3.layer.1.bias", "module.c3.layer.1.running_mean", "module.c3.layer.1.running_var", "module.c3.layer.4.weight", "module.c3.layer.5.weight", "module.c3.layer.5.bias", "module.c3.layer.5.running_mean", "module.c3.layer.5.running_var", "module.d3.layer.0.weight", "module.d3.layer.1.weight", "module.d3.layer.1.bias", "module.d3.layer.1.running_mean", "module.d3.layer.1.running_var", "module.c4.layer.0.weight", "module.c4.layer.1.weight", "module.c4.layer.1.bias", "module.c4.layer.1.running_mean", "module.c4.layer.1.running_var", "module.c4.layer.4.weight", "module.c4.layer.5.weight", "module.c4.layer.5.bias", "module.c4.layer.5.running_mean", 
"module.c4.layer.5.running_var", "module.d4.layer.0.weight", "module.d4.layer.1.weight", "module.d4.layer.1.bias", "module.d4.layer.1.running_mean", "module.d4.layer.1.running_var", "module.c5.layer.0.weight", "module.c5.layer.1.weight", "module.c5.layer.1.bias", "module.c5.layer.1.running_mean", "module.c5.layer.1.running_var", "module.c5.layer.4.weight", "module.c5.layer.5.weight", "module.c5.layer.5.bias", "module.c5.layer.5.running_mean", "module.c5.layer.5.running_var", "module.u1.layer.weight", "module.u1.layer.bias", "module.c6.layer.0.weight", "module.c6.layer.1.weight", "module.c6.layer.1.bias", "module.c6.layer.1.running_mean", "module.c6.layer.1.running_var", "module.c6.layer.4.weight", "module.c6.layer.5.weight", "module.c6.layer.5.bias", "module.c6.layer.5.running_mean", "module.c6.layer.5.running_var", "module.u2.layer.weight", "module.u2.layer.bias", "module.c7.layer.0.weight", "module.c7.layer.1.weight", "module.c7.layer.1.bias", "module.c7.layer.1.running_mean", "module.c7.layer.1.running_var", "module.c7.layer.4.weight", "module.c7.layer.5.weight", "module.c7.layer.5.bias", "module.c7.layer.5.running_mean", "module.c7.layer.5.running_var", "module.u3.layer.weight", "module.u3.layer.bias", "module.c8.layer.0.weight", "module.c8.layer.1.weight", "module.c8.layer.1.bias", "module.c8.layer.1.running_mean", "module.c8.layer.1.running_var", "module.c8.layer.4.weight", "module.c8.layer.5.weight", "module.c8.layer.5.bias", "module.c8.layer.5.running_mean", "module.c8.layer.5.running_var", "module.u4.layer.weight", "module.u4.layer.bias", "module.c9.layer.0.weight", "module.c9.layer.1.weight", "module.c9.layer.1.bias", "module.c9.layer.1.running_mean", "module.c9.layer.1.running_var", "module.c9.layer.4.weight", "module.c9.layer.5.weight", "module.c9.layer.5.bias", "module.c9.layer.5.running_mean", "module.c9.layer.5.running_var", "module.out.weight", "module.out.bias". 
Unexpected key(s) in state_dict: "c1.layer.0.weight", "c1.layer.1.weight", "c1.layer.1.bias", "c1.layer.1.running_mean", "c1.layer.1.running_var", "c1.layer.1.num_batches_tracked", "c1.layer.4.weight", "c1.layer.5.weight", "c1.layer.5.bias", "c1.layer.5.running_mean", "c1.layer.5.running_var", "c1.layer.5.num_batches_tracked", "d1.layer.0.weight", "d1.layer.1.weight", "d1.layer.1.bias", "d1.layer.1.running_mean", "d1.layer.1.running_var", "d1.layer.1.num_batches_tracked", "c2.layer.0.weight", "c2.layer.1.weight", "c2.layer.1.bias", "c2.layer.1.running_mean", "c2.layer.1.running_var", "c2.layer.1.num_batches_tracked", "c2.layer.4.weight", "c2.layer.5.weight", "c2.layer.5.bias", "c2.layer.5.running_mean", "c2.layer.5.running_var", "c2.layer.5.num_batches_tracked", "d2.layer.0.weight", "d2.layer.1.weight", "d2.layer.1.bias", "d2.layer.1.running_mean", "d2.layer.1.running_var", "d2.layer.1.num_batches_tracked", "c3.layer.0.weight", "c3.layer.1.weight", "c3.layer.1.bias", "c3.layer.1.running_mean", "c3.layer.1.running_var", "c3.layer.1.num_batches_tracked", "c3.layer.4.weight", "c3.layer.5.weight", "c3.layer.5.bias", "c3.layer.5.running_mean", "c3.layer.5.running_var", "c3.layer.5.num_batches_tracked", "d3.layer.0.weight", "d3.layer.1.weight", "d3.layer.1.bias", "d3.layer.1.running_mean", "d3.layer.1.running_var", "d3.layer.1.num_batches_tracked", "c4.layer.0.weight", "c4.layer.1.weight", "c4.layer.1.bias", "c4.layer.1.running_mean", "c4.layer.1.running_var", "c4.layer.1.num_batches_tracked", "c4.layer.4.weight", "c4.layer.5.weight", "c4.layer.5.bias", "c4.layer.5.running_mean", "c4.layer.5.running_var", "c4.layer.5.num_batches_tracked", "d4.layer.0.weight", "d4.layer.1.weight", "d4.layer.1.bias", "d4.layer.1.running_mean", "d4.layer.1.running_var", "d4.layer.1.num_batches_tracked", "c5.layer.0.weight", "c5.layer.1.weight", "c5.layer.1.bias", "c5.layer.1.running_mean", "c5.layer.1.running_var", "c5.layer.1.num_batches_tracked", "c5.layer.4.weight", "c5.layer.5.weight", 
"c5.layer.5.bias", "c5.layer.5.running_mean", "c5.layer.5.running_var", "c5.layer.5.num_batches_tracked", "u1.layer.weight", "u1.layer.bias", "c6.layer.0.weight", "c6.layer.1.weight", "c6.layer.1.bias", "c6.layer.1.running_mean", "c6.layer.1.running_var", "c6.layer.1.num_batches_tracked", "c6.layer.4.weight", "c6.layer.5.weight", "c6.layer.5.bias", "c6.layer.5.running_mean", "c6.layer.5.running_var", "c6.layer.5.num_batches_tracked", "u2.layer.weight", "u2.layer.bias", "c7.layer.0.weight", "c7.layer.1.weight", "c7.layer.1.bias", "c7.layer.1.running_mean", "c7.layer.1.running_var", "c7.layer.1.num_batches_tracked", "c7.layer.4.weight", "c7.layer.5.weight", "c7.layer.5.bias", "c7.layer.5.running_mean", "c7.layer.5.running_var", "c7.layer.5.num_batches_tracked", "u3.layer.weight", "u3.layer.bias", "c8.layer.0.weight", "c8.layer.1.weight", "c8.layer.1.bias", "c8.layer.1.running_mean", "c8.layer.1.running_var", "c8.layer.1.num_batches_tracked", "c8.layer.4.weight", "c8.layer.5.weight", "c8.layer.5.bias", "c8.layer.5.running_mean", "c8.layer.5.running_var", "c8.layer.5.num_batches_tracked", "u4.layer.weight", "u4.layer.bias", "c9.layer.0.weight", "c9.layer.1.weight", "c9.layer.1.bias", "c9.layer.1.running_mean", "c9.layer.1.running_var", "c9.layer.1.num_batches_tracked", "c9.layer.4.weight", "c9.layer.5.weight", "c9.layer.5.bias", "c9.layer.5.running_mean", "c9.layer.5.running_var", "c9.layer.5.num_batches_tracked", "out.weight", "out.bias".

这个问题是由于之前训练模型时,保存的是未包装在 nn.DataParallel 中的原始模型权重,因此保存的模型键名中没有 module. 前缀;而之后改用多卡训练时,模型被包装在 nn.DataParallel 中,这就导致了键名不匹配的问题。

针对此问题,解决办法如下:

  • 加载之前单卡训练的模型权重并添加 module. 前缀:这样可以与当前 nn.DataParallel 包装的模型兼容。
  • 在保存模型时去掉 module. 前缀,以便之后可以在不使用 nn.DataParallel 时也可以直接加载。
import os
import tqdm
from torch import nn, optim
import torch
from torch.utils.data import DataLoader
from data import *
from net import *
from torchvision.utils import save_image

# Restrict training to GPUs 0 and 1 (must be set before any CUDA call initializes the driver)
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # use GPU 0 and 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Checkpoint file to resume from and save to.
weight_path = 'params/unet.pth'
# Dataset root directory (consumed by MyDataset; exact layout defined in data.py).
data_path = r'data/'
# Directory where per-iteration visualization images are written.
save_path = 'train_image'

# 修改加载权重的代码,在加载权重时检查并修改键名:
def remove_module_prefix(state_dict):
    """Return a copy of *state_dict* whose keys have any leading 'module.'
    (added by nn.DataParallel) stripped, so a bare model can load it."""
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): tensor
        for key, tensor in state_dict.items()
    }

def add_module_prefix(state_dict):
    """Return a copy of *state_dict* with 'module.' prepended to every key
    that does not already carry it, matching nn.DataParallel's naming."""
    prefix = 'module.'
    return {
        (key if key.startswith(prefix) else prefix + key): tensor
        for key, tensor in state_dict.items()
    }

if __name__ == '__main__':
    # NOTE(review): the original comment claimed "+1 is for background", but 255
    # looks like a raw pixel value rather than a class count — confirm against
    # UNet's output head and the dataset's label encoding.
    num_classes = 255
    data_loader = DataLoader(MyDataset(data_path), batch_size=5, shuffle=True)
    net = UNet(num_classes)

    # Wrap in DataParallel so training can use every visible GPU.
    net = nn.DataParallel(net)
    net = net.to(device)

    if os.path.exists(weight_path):
        # map_location prevents a crash when CUDA-saved weights are loaded on a
        # machine without a GPU (or with a different device layout).
        state_dict = torch.load(weight_path, map_location=device)

        # Weights saved from a bare model lack the 'module.' prefix that the
        # DataParallel wrapper expects; add it when missing. Guard against an
        # empty state_dict before peeking at the first key.
        if state_dict and not next(iter(state_dict)).startswith('module.'):
            state_dict = add_module_prefix(state_dict)

        net.load_state_dict(state_dict)
        print('Successfully loaded weights!')
    else:
        print('Failed to load weights.')

    # Optimizer and loss function.
    opt = optim.Adam(net.parameters())
    loss_fun = nn.CrossEntropyLoss()

    # Decay the learning rate by 10x every 10 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)

    epoch = 1
    while epoch < 200:
        for i, (image, segment_image) in enumerate(tqdm.tqdm(data_loader)):
            image, segment_image = image.to(device), segment_image.to(device)
            out_image = net(image)
            train_loss = loss_fun(out_image, segment_image.long())
            opt.zero_grad()
            train_loss.backward()
            opt.step()

            if i % 5 == 0:
                print(f'{epoch}-{i}-train_loss===>>{train_loss.item()}')

            # Visualization triptych: input | ground truth | prediction,
            # with masks scaled to the 0-255 range.
            _image = image[0]
            _segment_image = torch.unsqueeze(segment_image[0], 0) * 255
            _out_image = torch.argmax(out_image[0], dim=0).unsqueeze(0) * 255

            # Repeat single-channel masks to 3 channels so they stack with RGB.
            _segment_image = _segment_image.repeat(3, 1, 1)
            _out_image = _out_image.repeat(3, 1, 1)

            img = torch.stack([_image, _segment_image, _out_image], dim=0)
            save_image(img, f'{save_path}/{i}.png')

        # Step the LR scheduler once per epoch.
        scheduler.step()

        if epoch % 20 == 0:
            # Strip the 'module.' prefix so the checkpoint loads with or
            # without DataParallel.
            state_dict = remove_module_prefix(net.state_dict())
            # Bug fix: the old name f'{weight_path}_epoch_{epoch}.pth' produced
            # 'params/unet.pth_epoch_N.pth' (double extension) and never matched
            # weight_path, so the script could never resume from its own
            # checkpoints. Save a resumable copy at weight_path plus a
            # per-epoch snapshot with a clean name.
            torch.save(state_dict, weight_path)
            base, ext = os.path.splitext(weight_path)
            torch.save(state_dict, f'{base}_epoch_{epoch}{ext}')
            print('Save successfully!')
        epoch += 1

解释:

  1. remove_module_prefix:用于从 state_dict 中移除 module. 前缀。
  2. add_module_prefix:用于在 state_dict 中添加 module. 前缀。
  3. 加载模型权重时
    • 先检查第一个键是否以 module. 开头,如果不是,则调用 add_module_prefix 添加前缀,然后加载权重。
  4. 保存模型权重时
    • 调用 remove_module_prefix 移除前缀,以便之后可以在不使用 nn.DataParallel 时直接加载。

针对后面,模型训练的时候。 我们可以在保存时,移除 module. 前缀。 在加载时,添加上module. 前缀。 这样,保存的模型,不论我们在单卡上进行训练,或者多卡上进行训练时,都可以正常加载。

其实明白了原理就很简单,无非就是在键名上添加或删除 module. 前缀的问题。

  • 22
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
深度学习模型训练环境的代码可以包括以下几个方面: 1. 数据预处理:包括数据集的读取、数据清洗、数据归一化、数据增强等操作。 2. 模型定义:包括神经网络结构的定义、各层参数的设置、激活函数的选择等。 3. 损失函数和优化器的定义:选择适合问题的损失函数和优化器,如交叉熵损失函数、Adam优化器等。 4. 训练过程:包括模型训练、验证和测试。训练过程需要设置学习率、批次大小、训练轮数等超参数,并记录训练过程的损失和准确率等指标。 5. 可视化和保存:可视化训练过程的损失和准确率等指标,保存训练好的模型以便后续使用。 一个示例代码如下: ```python import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms # 数据预处理 train_transforms = transforms.Compose( [transforms.RandomRotation(30), transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) test_transforms = transforms.Compose( [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) train_dataset = datasets.ImageFolder('train/', transform=train_transforms) test_dataset = datasets.ImageFolder('test/', transform=test_transforms) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False) # 模型定义 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.fc1 = nn.Linear(256 * 28 * 28, 1024) self.fc2 = nn.Linear(1024, 2) self.dropout = nn.Dropout(p=0.5) self.relu = nn.ReLU() def forward(self, x): x = self.pool(self.relu(self.conv1(x))) x = self.pool(self.relu(self.conv2(x))) x = self.pool(self.relu(self.conv3(x))) x = x.view(-1, 256 * 28 * 28) x = self.dropout(self.relu(self.fc1(x))) x = self.fc2(x) return x model = Net() # 损失函数和优化器的定义 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 训练过程 for epoch in range(10): running_loss = 0.0 for i, 
data in enumerate(train_loader, 0): inputs, labels = data optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() print('[%d] loss: %.3f' % (epoch + 1, running_loss / len(train_loader))) # 测试过程 correct = 0 total = 0 with torch.no_grad(): for data in test_loader: images, labels = data outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the test images: %d %%' % (100 * correct / total)) # 模型保存 torch.save(model.state_dict(), 'model.pth') ``` 这段代码实现了一个简单的图像分类模型,使用了数据增强技术和深度网络结构,采用交叉熵损失函数和Adam优化器进行训练,最后保存了训练好的模型

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Stuomasi_xiaoxin

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值