Pytorch搭建模型需要注意一个问题

Pytorch搭建模型需要注意一个问题

【问题描述】

我在搭建一个包括多分支的网络,由于分支数量不确定需要用到容器;最初直接用的list,出现了子模块的self.training属性无法和模型自身同步的问题。

【解决方式】

需要将list转为torch.nn.ModuleList类别,才能同步self.training属性。搭建网络时,子模块都需要使用torch定义的container存放。如果是串接就用torch.nn.Sequential (这个很熟悉,相信都不会犯错);如果不是串接,自己定义forward方式的话,要用torch.nn.ModuleList而不能用Python中原生的容器。

【代码示例】

以下示意代码仅供说明这种情况:

import torch
import torch.nn as nn

class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
    
    def forward(self, x):
        if self.training:
            print('running sub-module in training mode.')
            return x
        else:
            print('running sub-module in evaluation mode.')
            return x+1


class Model(nn.Module):
    def __init__(self, is_container=False, num_of_subs=3):
        super(Model, self).__init__()
        
        # self.sub_modules = nn.ModuleList()
        # for i in range(num_of_subs):
        #     self.sub_modules.append(SubModule())

        self.sub_modules = []
        for i in range(num_of_subs):
            self.sub_modules.append(SubModule())
        if is_container:       
            self.sub_modules = nn.ModuleList(self.sub_modules)
    
    def forward(self, x):
        outputs = []
        for md in self.sub_modules:
            outputs.append(md(x))
        return sum(outputs)


if __name__ == '__main__':
    x = 0
    print('[WRONG]-------------------------')
    model = Model(is_container=False)
    model.eval()        
    print('model is training mode.' if model.training else 'model is evaluation mode.')
    with torch.no_grad():
        y = model(x)
        print(y)

    print('[RIGHT]-------------------------')
    model = Model(is_container=True)
    model.eval()
    print('model is training mode.' if model.training else 'model is evaluation mode.')
    with torch.no_grad():
        y = model(x)
        print(y)

输出如下:

[WRONG]-------------------------
model is evaluation mode.
running sub-module in training mode.
running sub-module in training mode.
running sub-module in training mode.
0
[RIGHT]-------------------------
model is evaluation mode.
running sub-module in evaluation mode.
running sub-module in evaluation mode.
running sub-module in evaluation mode.
3

通过结果可以发现,就算指定了model.eval()并且加上了torch.no_grad()限制,若不使用torch定义的容器存放子模块,它是无法同步模型的training和evaluation状态的。此种情况下,当需要针对子模块在训练和推理阶段做不同操作时就会产生错误。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

蔡逸超

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值