How dropout is implemented in code
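```python
import torch

def dropout_layer(x, dropout):
    # Inverted dropout: zero each element with probability `dropout`, then
    # scale the survivors by 1 / (1 - dropout) so that E[output] == x.
    assert 0 <= dropout <= 1
    if dropout == 1:  # every element is dropped
        return torch.zeros_like(x)
    if dropout == 0:  # every element is kept
        return x
    # torch.rand samples uniformly from [0, 1), so each element survives
    # with probability exactly 1 - dropout (torch.randn would not give this).
    mask = (torch.rand(x.shape, device=x.device) > dropout).float()
    return mask * x / (1.0 - dropout)
```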
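`dropout_layer` applies dropout every time it is called. In practice dropout is active only while training and becomes the identity at evaluation time. This works without extra rescaling because the 1 / (1 - dropout) factor is already applied during training.

The sketch below wraps the same inverted-dropout logic in an `nn.Module` that checks `self.training`. It assumes PyTorch. The class name `InvertedDropout` is hypothetical, and its `p` plays the role of `dropout` above. The usage lines check that roughly a fraction `p` of elements is zeroed, that the mean is preserved, and that `eval()` turns the layer off.

```python
import torch
from torch import nn


class InvertedDropout(nn.Module):
    """Hypothetical module: inverted dropout, active only in training mode."""

    def __init__(self, p: float):
        super().__init__()
        assert 0 <= p <= 1
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # After model.eval() (or with p == 0) dropout is the identity.
        if not self.training or self.p == 0:
            return x
        if self.p == 1:
            return torch.zeros_like(x)
        # Keep each element with probability 1 - p (uniform samples in [0, 1)).
        mask = (torch.rand(x.shape, device=x.device) > self.p).to(x.dtype)
        # Rescale the survivors so the expected output equals x.
        return mask * x / (1.0 - self.p)


torch.manual_seed(0)
layer = InvertedDropout(p=0.5)
x = torch.ones(100_000)

layer.train()
y = layer(x)
print((y == 0).float().mean().item())  # fraction dropped, close to 0.5
print(y.mean().item())                 # close to 1.0

layer.eval()
print(torch.equal(layer(x), x))        # True
```

Doing the scaling at training time, rather than multiplying by 1 - p at inference, is the same convention used by PyTorch's built-in `nn.Dropout`. A trained model can therefore be evaluated without touching its weights.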