学习记录 | 使用nn.ModuleList()的必要性

在我的代码中想自己实现一个可变层数的多层感知机(MLP),即全连接网络。写了一个类如下:

class MLPLayer(nn.Module):
    def __init__(self, in_dim, out_dim, mlp_hid_dim=None):
        super(MLPLayer, self).__init__()
        if mlp_hid_dim is None:
            mlp_hid_dim = [64]
        self.fc_i = nn.Linear(in_dim, mlp_hid_dim[0])
        self.num_layers = len(mlp_hid_dim)
        self.fc = [nn.Linear(mlp_hid_dim[i], mlp_hid_dim[i + 1]) for i in range(self.num_layers - 1)]
        self.fc_o = nn.Linear(mlp_hid_dim[-1], out_dim)

    def forward(self, x):
        x = F.relu(self.fc_i(x))
        for i in range(self.num_layers - 1):
            x = F.relu(self.fc[i](x))
        x = self.fc_o(x)
        return x

平时能正常跑起来,没发现问题。最近将其放到GPU上跑时,出现报错:

 RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm) Step #0.00Error: tcpip::Socket::recvAndCheck @ recv: peer shutdown

 看提示是说有的数据在CPU而有的在GPU上了。但我明明把所有模型和数据都加上了.to(device)了呀?一番排查发现只要将上面隐藏层改为1层就不会报错了。但为什么呢?好,下面有请我们永远滴神——ChatGPT老师进行解答:

看起来问题出在MLPLayer类中的self.fc列表上。在PyTorch中,nn.ModuleList用于存储模型的子模块(如层或其他模块),这样当调用.to(device)时,所有子模块也会随之移动到指定的设备。然而,在你的代码中,self.fc是一个普通的Python列表,它不会自动处理子模块的设备迁移。

要解决这个问题,你需要将self.fc从Python列表转换为nn.ModuleList。这样,当你调用.to(device)时,nn.ModuleList中的所有层也会被移动到相应的设备上。下面是如何修改你的MLPLayer类来实现这一点:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPLayer(nn.Module):
    def __init__(self, in_dim, out_dim, mlp_hid_dim=None):
        super(MLPLayer, self).__init__()
        if mlp_hid_dim is None:
            mlp_hid_dim = [64]
        self.fc_i = nn.Linear(in_dim, mlp_hid_dim[0])
        self.num_layers = len(mlp_hid_dim)
        # 将self.fc转换为nn.ModuleList
        self.fc = nn.ModuleList([nn.Linear(mlp_hid_dim[i], mlp_hid_dim[i + 1]) for i in range(self.num_layers - 1)])
        self.fc_o = nn.Linear(mlp_hid_dim[-1], out_dim)

    def forward(self, x):
        x = F.relu(self.fc_i(x))
        for layer in self.fc:  # 直接遍历nn.ModuleList
            x = F.relu(layer(x))
        x = self.fc_o(x)
        return x

 破案了,原来是我用了python list来装每层全连接层,而导致在整个模型.to(device)时无法将其移动到GPU上。此时应使用nn.ModuleList()将普通列表转一下,才能正确地被torch处理。基础不牢地动山摇,指不定什么时候写的bug就出现并绊你一脚……学习了!

  • 9
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值