[Pytorch系列-38]: 工具集 - torchvision预定义模型的两种模式model.train和model.eval的表面和本质区别

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121176467


目录

第1章 为什么需要讨论model.train()和model.eval的区别。

1.1 利用torchvision.model预定义的模型

1.2 预定义模型的定义:定义模型非常方便!

1.3 预定义模型的两种模式

第2章 两种模式的区别

2.1 代码执行结果的区别

2.2 train模式与eval模式的本质区别

第3章 函数原型说明

3.1 model.train(mode=True)

3.2 model.eval (mode=True)

第4章 源码解析

4.1 model.train(mode=True)

4.2 model.eval (mode=True)



第1章 为什么需要讨论model.train()和model.eval的区别。

1.1 利用torchvision.model预定义的模型

在前面的文章中,我们都在探讨如何手工搭建卷积神经网络,训练网络模型。

很显然,对于一些知名的模型,手工搭建的效率较低,且容易出错。

因此,利用pytorch提供的搭建好的知名模型,是一个不错的选择。

[Pytorch系列-37]:工具集 - torchvision库详解(数据集、数据预处理、模型)_文火冰糖(王文兵)的博客-CSDN博客作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客本文网址:目录第1章Pytorch常见的工具集简介第2章Pytorch的torchvision工具集简介第3章torchvision.datasets 简介3.1 简介3.2 支持的数据集列表第4章torchvision.models简介4.1 简介4.2 支持的模型4.3构造具有随机权重的模型4.4 使用预预训练好的模型第5章 torchvision.tr...https://blog.csdn.net/HiWangWenBing/article/details/121149809

1.2 预定义模型的定义:定义模型非常方便!

(1)案例1

# 2-3 使用torchvision.models定义神经网络
net_a = models.alexnet(num_classes = 10)
print(net_a)
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)

(2)案例2:

net_b = models.vgg16(num_classes = 10)
print(net_b)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=10, bias=True)
  )
)

1.3 预定义模型的两种模式

预定义模型支持两种工作模式

  • model.train:适用于训练模式
  • model.eval:适用于预测模式

第2章 两种模式的区别

2.1 代码执行结果的区别

(1)train模式

# 2-4 定义网络预测输出
# 测试网络是否能够工作
print("定义测试数据")
input = torch.randn(1, 3, 224, 224)
print(input.shape)

print("")
net_a.train()    #设置在训练模式,不设置,模式为训练模式

print("net_a的输出方法1:")
out = net_a(input)
print(out.shape)
print(out)

print("")
print("net_a的输出方法2:")
out = net_a.forward(input)
print(out)

定义测试数据
torch.Size([1, 3, 224, 224])

net_a的输出方法1:
torch.Size([1, 10])
tensor([[ 1.3295e-02,  3.6363e-05,  9.0679e-04,  2.2399e-03,  2.4310e-03,
         -6.0430e-03, -9.2381e-03, -8.3501e-03, -7.8011e-03, -7.4917e-03]],
       grad_fn=<AddmmBackward>)

net_a的输出方法2:
tensor([[ 0.0195, -0.0109, -0.0017, -0.0027,  0.0131, -0.0014, -0.0212, -0.0030,
         -0.0114, -0.0041]], grad_fn=<AddmmBackward>)

备注:

连续两次的预测结果是不完全一样的。

(2)评估模式

# 2-4 定义网络预测输出
# 测试网络是否能够工作
print("定义测试数据")
input = torch.randn(1, 3, 224, 224)
print(input.shape)

print("")
net_a.eval()    #设置在训练模式,不设置,模式为训练模式

print("net_a的输出方法1:")
out = net_a(input)
print(out.shape)
print(out)

print("")
print("net_a的输出方法2:")
out = net_a.forward(input)
print(out)
定义测试数据
torch.Size([1, 3, 224, 224])

net_a的输出方法1:
torch.Size([1, 10])
tensor([[ 0.0106, -0.0096,  0.0060,  0.0017,  0.0017, -0.0004, -0.0031, -0.0126,
         -0.0078, -0.0079]], grad_fn=<AddmmBackward>)

net_a的输出方法2:
tensor([[ 0.0106, -0.0096,  0.0060,  0.0017,  0.0017, -0.0004, -0.0031, -0.0126,
         -0.0078, -0.0079]], grad_fn=<AddmmBackward>)

备注:

连续两次的预测结果是完全一样的!

(3)train模式预eval模式的区别

  • model.train:连续两次的预测结果是不完全一样的。
  • model.eval:连续两次的预测结果是完全一样的!

为什么呢?

2.2 train模式与eval模式的本质区别

(1)根本原因

对于一些含有BatchNormDropout等层的模型,model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层的处理上不同。

(2)train模式

如果模型中有BN层(Batch Normalization)和 Dropout,需要在训练时添加model.train(),启用 Batch Normalization Dropout的特殊处理的功能。

model.train()是保证BN层能够用到每一批数据的均值和方差。

对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,导致在该模式下,每次预测,使用的网络参数连接不尽相同。这就导致,每次预测输出的结果不完全一样。

备注:

至于Dropout和Batch Normalization,请查看相关的文档,文本不再叙述。

(3)eval模式

如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。关闭 Batch Normalization 和 Dropout的功能。

model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。

对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
 

第3章 函数原型说明

torch.nn — PyTorch 1.10.0 documentationhttps://pytorch.org/docs/stable/nn.html?highlight=module%20eval#torch.nn.Module.eval

3.1 model.train(mode=True)

train

3.2 model.eval (mode=True)

å¨è¿éæå¥å¾çæè¿°

第4章 源码解析

4.1 model.train(mode=True)

    def train(self, mode=True):
        r"""Sets the module in training mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self

4.2 model.eval (mode=True)

def eval(self):
    return self.train(False)

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121176467

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

文火冰糖的硅基工坊

你的鼓励是我前进的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值