eval()与train()(结合源码理解)

eval()与train()(结合源码理解)

结论

先上结论:eval和training实际上就是需要判断BN层(dorpout类似)中全局统计特征moving_mean和moving_variance是否需要迭代更新

1.eval()与train()的实质

请添加图片描述

pytorch.org上eval的源码
从源码上不难看出,eval()和trian()实际上就是遍历了所有子模块,然后把其中的training属性给设置为Ture和False。

2.结合BN层理解

那么这个training属性有什么用呢,我们直接定位BN层的源代码,参考理解
请添加图片描述
抛开中间的部分,就看绿色的部分可以发现module中的training到了这里,通过training给bn_training进行了赋值。看到这会很自然产生两个问题,这个training怎么来的,这个bn_training又是什么捏。
training从何而来
我们看看上面图片中,这个定义的类叫_BatchNorm,它继承自_NormBase而这个_NormBase又继承自module,如下图所示:
在这里插入图片描述
所以module的training属性也就自然而然继承了过来,接下来看看bn_training又有什么用。
在先前起到的_NormBase类的forward中的return结果中,我们找到了bn_training的用处,如下图所示:
在这里插入图片描述
这个参数有什么用呢,自然就要追溯一下F.batch_norm方法是如何使用的
参考飞桨的开发文档如下:
在这里插入图片描述
通过BN层的原理我们可以知道,要对每个Batch进行标准化需要的是两部分,一部分是全局统计特征(moving_mean和moving_variance),一部分是局部也就是minibatch的统计特征,μβ 和 σ2β。
前一部分通过对每个batch的数据特征计算迭代得到(类似参数的更新,但不完全是,后面会提到),后一部分对当前batch计算即可得到。

所以综上可以得到,eval和training实际上就是需要判断这个全局统计特征moving_mean和moving_variance是否需要迭代更新,这不也就对应了train和validation的两个阶段吗:)。

另外提到的BN层的迭代类似参数但不是的原因是:moving_mean和moving_variance不是模型的参数,而是作为一个buffer保存在模型中,所以其更新与参数的更新(loss反传)不一样,与require_grad要区别开哦。(这里后续会继续开一篇文章讲解buffer、parameter、module的相关)

以上就是eval()和train()的理解,感兴趣的小伙伴可以点点关注蹲蹲后续的一些分享,同时欢迎各位大佬评论区留言指正!:)

参考资料:
[1].飞桨参考文档https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/BatchNorm1D_cn.html#cn-api-nn-batchnorm1d
[2].pytorch官网
https://pytorch.org/

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
【资源介绍】 基于预训练模型BERT的阅读理解python源码+使用说明.zip 1. Prepare data, the virtual python environment and install the package in requirements.txt 2. Run the command below to fine tune ```bash python bert_qa.py \ --model_type bert \ --model_name_or_path bert-base-chinese \ --do_train \ --do_eval \ --do_lower_case \ --train_file ~/cmrc2018_train.json \ --predict_file ~/cmrc2018_trial.json \ --per_gpu_train_batch_size 12 \ --learning_rate 3e-5 \ --num_train_epochs 2.0 \ --max_seq_length 384 \ --doc_stride 128 \ --output_dir ~/chinese-qa-with-bert/output \ --save_steps 200 ``` #### 4. Load the fine-tuned params and run this command to interactive ```bash python interactive.py \ --output_dir ~/chinese-qa-with-bert/output \ --model_type bert \ --predict_file ~/chinese-qa-with-bert/input.json \ --state_dict ~/chinese-qa-with-bert/output/pytorch_model.bin \ --model_name_or_path bert-base-chinese ``` 【备注】 该项目是个人毕设/课设/大作业项目,代码都经过本地调试测试,功能ok才上传,高分作品,可快速上手运行!欢迎下载使用,可用于小白学习、进阶。 该资源主要针对计算机、通信、人工智能、自动化等相关专业的学生、老师或从业者下载使用,亦可作为期末课程设计、课程大作业、毕业设计等。 项目整体具有较高的学习借鉴价值!基础能力强的可以在此基础上修改调整,以实现不同的功能。 欢迎下载使用,也欢迎交流学习!
基于paddle框架的图像超分和降噪实现源码+项目使用说明.zip 基于paddle框架的图像超分和降噪实现源码+项目使用说明.zip 基于paddle框架的图像超分和降噪实现源码+项目使用说明.zip 【资源说明】 该项目是个人毕设项目源码,评审分达到95分,调试运行正常,确保可以运行!放心下载使用。 该项目资源主要针对计算机、自动化等相关专业的学生或从业者下载使用,也可作为期末课程设计、课程大作业、毕业设计等。 具有较高的学习借鉴价值!基础能力强的可以在此基础上修改调整,以实现类似其他功能。 项目建议安装anaconda,在anaconda中配置环境,然后导入到pycharm中运行 ## 1. 模型 * [x] [Unet]() ## 2. 数据准备 ## 3. 训练 此模型支持单机单卡和单机多卡训练,以下使用`UNet`跑`denoise`举例。 ### 3.1 单机单卡 ``` python tools/train.py -c ./configs/denoise/unet_watermark.yaml ``` ### 3.2 单机多卡 ``` python -m paddle.distributed.launch --gpus=0,1,2,3 tools/train.py -c ./configs/denoise/unet_watermark.yaml ``` 所有的训练日志都默认保存在 `./output/UNet/train.log` ## 4. 实验结果 ## 5. 评估 此模型支持单机单卡和单机多卡评估,以下使用`UNet`跑`denoise`举例,生成的模型位置在`/output/UNet/best_model.pdparams` ### 5.1 单机单卡 ``` python tools/eval.py -c ./configs/denoise/unet_watermark.yaml -o Global.pretrained_model=./output/UNet/best_model.pdparams ``` ### 5.2 单机多卡 ``` python -m paddle.distributed.launch --gpus=0,1,2,3 tools/eval.py -c ./configs/denoise/unet_watermark_4gpu.yaml -o Global.pretrained_model=./output/UNet/best_model.pdparams ``` 所有的评估日志都默认保存在 `./output/UNet/eval.log` ## 6. 推理 ### 6.1 模型导出 首先需要导出推理模型,例如训练好的模型参数在`./output/UNet/best_model.pdparams`,命令为 ``` python tools/export_model.py -c ./configs/denoise/unet_watermark.yaml -o Global.pretrained_model=./output/UNet/best_model.pdparams ``` 模型将自动导出到`Global.output_dir/Arch.name/inference`,其中`Global.output_dir`和`Arch.name`均在配置文件`.yaml`中。 ### 6.2 模型推理 模型导出后将使用测试数据集对模型进行推理,例如所有的测试文件都在`./test_data`中,运行命令 ``` python tools/inference.py -c ./configs/denoise/unet_watermark.yaml -o Data.Test.path=./test_data ``` 模型会将推理的结果放入`Global.output_dir/Arch.name/Img`中。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值