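R-Drop stands for Regularized Dropout.
It was proposed to address the inconsistency between training and testing (inference) in Dropout: training runs a randomly thinned sub-model, while inference runs the full model.
Dropout is essentially a form of ensemble learning: each random mask selects a different sub-network, so during training many neural networks that share the same weights are trained at once.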
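The two claims above can be made concrete with a small NumPy experiment. This is a minimal sketch under assumed toy settings: the 8-64-4 network, the helpers `forward`, `softmax` and `random_mask`, and the seed are all hypothetical. Each training pass samples a mask and therefore runs a different sub-model. Inference runs the full network once. Because softmax is nonlinear, the average over sub-models is in general not the same as the full network's output.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy network: 8 inputs -> 64 ReLU hidden units -> 4 class logits
W1 = rng.standard_normal((64, 8)) / np.sqrt(8)
W2 = rng.standard_normal((4, 64)) / np.sqrt(64)
x = rng.standard_normal(8)
p = 0.5  # dropout probability


def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()


def forward(x, mask=None):
    """Full network if mask is None, else the sub-network selected by `mask`."""
    h = np.maximum(0, W1 @ x)
    if mask is not None:
        h = h * mask / (1 - p)  # inverted dropout: rescale the kept units
    return softmax(W2 @ h)


def random_mask():
    return rng.binomial(1, 1 - p, size=64)  # 1 = keep the unit, 0 = drop it


# Training: each pass samples a mask, i.e. a different weight-sharing sub-model
sub1 = forward(x, random_mask())
sub2 = forward(x, random_mask())

# Inference: a single deterministic pass through the full network
full = forward(x)

# The implicit ensemble: average prediction over many sampled sub-models
ensemble = np.mean([forward(x, random_mask()) for _ in range(10_000)], axis=0)

print("sub-model 1: ", sub1.round(3))
print("sub-model 2: ", sub2.round(3))
print("ensemble avg:", ensemble.round(3))
print("full network:", full.round(3))
```

Running it shows two different distributions for the same input (`sub1` vs `sub2`), plus a gap between the ensemble average and the full network. That gap is the training/inference mismatch described above.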
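R-Drop requires the different sub-models produced by Dropout to output distributions that are consistent with one another.
Specifically, for each training sample, R-Drop minimizes the KL divergence between the output distributions of two such sub-models.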
in each mini-batch training, each data sample goes through the forward pass twice, and each pass is processed by a different sub model by randomly dropping out some hidden units. R-Drop forces the two distributions for the same data sample outputted by the two sub models to be consistent with each other, through minimizing the bidirectional Kullback-Leibler (KL) divergence between the two distributions
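Written as a per-sample objective, following the formulation in the R-Drop paper: $\mathcal{P}_1^w$ and $\mathcal{P}_2^w$ denote the output distributions that the two dropout sub-models (with shared weights $w$) produce for the sample $(x_i, y_i)$. The hyperparameter $\alpha$ weights the KL term, and $D_{KL}(P \,\|\, Q) = \sum_k P(k) \log \frac{P(k)}{Q(k)}$.

$$
\mathcal{L}_{NLL}^i = -\log \mathcal{P}_1^w(y_i \mid x_i) - \log \mathcal{P}_2^w(y_i \mid x_i)
$$

$$
\mathcal{L}_{KL}^i = \frac{1}{2}\Big( D_{KL}\big(\mathcal{P}_1^w(y_i \mid x_i) \,\|\, \mathcal{P}_2^w(y_i \mid x_i)\big) + D_{KL}\big(\mathcal{P}_2^w(y_i \mid x_i) \,\|\, \mathcal{P}_1^w(y_i \mid x_i)\big) \Big)
$$

$$
\mathcal{L}^i = \mathcal{L}_{NLL}^i + \alpha \cdot \mathcal{L}_{KL}^i
$$

Averaging both KL directions makes the penalty symmetric, so neither sub-model is treated as the fixed target.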
Code implementation
A simplified NumPy simulation demonstrating the principle:
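import numpy as np

# Inverted dropout: zero each unit with probability p, rescale the survivors by 1/(1-p)
def dropout(h, p):
    mask = np.random.binomial(1, 1 - p, h.shape)
    return h * mask / (1 - p)

# Simulated two-layer network in training mode (dropout on); x: (in_dim,) or (batch, in_dim)
def train(p, x, w1, b1, w2, b2):
    layer1 = dropout(np.maximum(0, x @ w1.T + b1), p)
    layer2 = dropout(np.maximum(0, layer1 @ w2.T + b2), p)
    return layer2

# Numerically stable log-softmax over the last axis
def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

# Mean negative log-likelihood of the true labels
def nll(log_probs, labels):
    return -log_probs[np.arange(len(labels)), labels].mean()

# KL(P || Q) computed from log-probabilities, averaged over the batch
def kl(log_p, log_q):
    return (np.exp(log_p) * (log_p - log_q)).sum(axis=-1).mean()

# Simulated two-layer network trained with R-Drop; x: (batch, in_dim), labels: (batch,)
def train_r_dropout(p, x, labels, w1, b1, w2, b2, alpha=1.0):
    bs = x.shape[0]
    x = np.concatenate((x, x), axis=0)  # every sample appears twice in the batch
    # ----------- original Dropout part unchanged -----------
    layer2 = train(p, x, w1, b1, w2, b2)  # each copy gets its own random masks
    # -------------------------------------------------------
    log_probs = log_softmax(layer2)  # treat layer2 as class logits
    lp1, lp2 = log_probs[:bs], log_probs[bs:]
    nll1 = nll(lp1, labels)
    nll2 = nll(lp2, labels)
    kl_loss = 0.5 * (kl(lp1, lp2) + kl(lp2, lp1))  # bidirectional KL
    loss = nll1 + nll2 + alpha * kl_loss
    return loss

# Inference: no dropout and no rescaling (inverted dropout already rescaled during training)
def test(x, w1, b1, w2, b2):
    layer1 = np.maximum(0, x @ w1.T + b1)
    layer2 = np.maximum(0, layer1 @ w2.T + b2)
    return layer2

input = np.random.randn(5, 4)
w1 = np.random.rand(30, 20)
b1 = np.random.rand(30)
w2 = np.random.rand(40, 30)
b2 = np.random.rand(40)
output1 = train(p=0.5, x=input.reshape(-1), w1=w1, b1=b1, w2=w2, b2=b2)
print(output1)
output2 = test(x=input.reshape(-1), w1=w1, b1=b1, w2=w2, b2=b2)
print(output2)

# R-Drop loss on a batch of 3 samples with random labels among the 40 outputs
x_batch = np.random.randn(3, 20)
labels = np.random.randint(0, 40, size=3)
print(train_r_dropout(p=0.5, x=x_batch, labels=labels, w1=w1, b1=b1, w2=w2, b2=b2))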
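A minimal self-contained check of what the KL term measures, assuming a hypothetical 20-30-10 toy classifier. The names `forward`, `rdrop_kl`, `W1`, `W2` and the seed are illustrative only. With p = 0, both copies of each sample follow the same path, and the bidirectional KL term is zero up to floating-point noise. With p = 0.5, the independent masks make the copies disagree, and the term becomes positive. That disagreement is exactly what R-Drop penalizes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy classifier: 20 inputs -> 30 ReLU hidden units -> 10 class logits
W1 = rng.standard_normal((30, 20)) / np.sqrt(20)
W2 = rng.standard_normal((10, 30)) / np.sqrt(30)


def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def forward(x, p):
    """One training-mode pass; every row of x gets its own dropout mask."""
    h = np.maximum(0, x @ W1.T)
    h = h * rng.binomial(1, 1 - p, size=h.shape) / (1 - p)  # inverted dropout
    return log_softmax(h @ W2.T)


def rdrop_kl(x, p):
    """Bidirectional KL between the two passes of each sample, averaged over the batch."""
    bs = x.shape[0]
    log_probs = forward(np.concatenate([x, x], axis=0), p)  # one batched call
    lp1, lp2 = log_probs[:bs], log_probs[bs:]
    kl_12 = (np.exp(lp1) * (lp1 - lp2)).sum(axis=-1)  # KL(P1 || P2) per sample
    kl_21 = (np.exp(lp2) * (lp2 - lp1)).sum(axis=-1)  # KL(P2 || P1) per sample
    return (0.5 * (kl_12 + kl_21)).mean()


x = rng.standard_normal((4, 20))
print("p = 0.0:", rdrop_kl(x, 0.0))  # no dropout: both copies take the same path
print("p = 0.5:", rdrop_kl(x, 0.5))  # independent masks: the copies disagree
```

Duplicating the batch instead of calling the network twice keeps a single forward call. Because the masks are sampled per row, the two copies of a sample still behave like two different sub-models.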