摘要
本文使用纯 Python 和 PyTorch 对比 dropout 函数及其反向传播.
相关原理和详细解释, 请参考文章:
dropout函数详解及反向传播中的梯度求导
系列文章索引:
https://blog.csdn.net/oBrightLamp/article/details/85067981
正文
import torch
import numpy as np
class Dropout:
"""
http://arxiv.org/abs/1207.0580
"""
def __init__(self, dropout_ratio=0.5):
self.dropout_ratio = dropout_ratio
self.train_flg = True
self.mask = None
def __call__(self, x, manual_mask=None, train_flg=True):
if train_flg:
if manual_mask is None:
self.mask = np.random.rand(*x.shape) > self.dropout_ratio
else:
self.mask = manual_mask
out =