pytorch BatchNorm1d 输入二维和三维数据的区别

25 篇文章 1 订阅
14 篇文章 0 订阅

在阅读KPConv-PyTorch源码时,发现其对torch.nn.BatchNorm1d进行了封装。

class BatchNormBlock(nn.Module):

    def __init__(self, in_dim, use_bn, bn_momentum):
        """
        Initialize a batch normalization block. If network does not use batch normalization, replace with biases.
        :param in_dim: dimension input features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        """
        super(BatchNormBlock, self).__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.in_dim = in_dim
        if self.use_bn:
            self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
            #self.batch_norm = nn.InstanceNorm1d(in_dim, momentum=bn_momentum)
        else:
            self.bias = Parameter(torch.zeros(in_dim, dtype=torch.float32), requires_grad=True)
        return

    def reset_parameters(self):
        nn.init.zeros_(self.bias)

    def forward(self, x):	
        if self.use_bn:
			# x: [num_of_point, dim]
            x = x.unsqueeze(2)
            x = x.transpose(0, 2)
            # x: [1, dim, num_of_point]
            x = self.batch_norm(x)
            x = x.transpose(0, 2)
            return x.squeeze()	# x: [num_of_point, dim]
        else:
            return x + self.bias

    def __repr__(self):
        return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim,
 

x输入时维度为[num_of_point, dim]
经过变换为[1, dim, num_of_point]再输入到batch_norm,return时还原维度

通过阅读torch.nn.BatchNorm1d官方文档发现:

  • 当输入为(N, C, L)时,计算维度 (N, L) 上的统计数据进行归一化。按照维度C恢复全局方差偏置。
  • 当输入为(N, L)时,计算维度N切片上的统计数据进行归一化。按照维度L恢复全局方差偏置。

那么为什么不直接

# x: [num_of_point, dim]
x = self.batch_norm(x)

实验

data与batch_norm采用真实样本与预训练模型参数

self.batch_norm = self.batch_norm.eval()
# data: [num_of_point, dim]

a = self.batch_norm(data)

x = data.unsqueeze(2)
x = x.transpose(0, 2)
x = self.batch_norm(x)
x = x.transpose(0, 2)
b = x.squeeze()

print(a-b)
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SubBackward0>)

结果

当我对代码进行更改并运行后发现运行速度大幅下降,一开始为还以为是其他问题,后来测试发现:

import time
t = time.time()
for i in range(1000):
	# x: [num_of_point, dim]
	x = x.unsqueeze(2)
	x = x.transpose(0, 2)
	x = self.batch_norm(x)
	x = x.transpose(0, 2)
	x = x.squeeze()
print(time.time() - t)	# 0.7s

t = time.time()
for i in range(1000):
	# x: [num_of_point, dim]
	x = self.batch_norm(x)
print(time.time() - t)	# 2.1s

补充测试:

t = time.time()
for i in range(1000):
	# x: [num_of_point, dim]
	x = x.transpose(0, 2)
	x = self.batch_norm(x)
	x = x.transpose(0, 2)
print(time.time() - t)	# 0.7s

batch_norm处理2个维度数据要比处理3个维度慢3倍,官方代码中并没有提到输入2维和3维的有什么不同,但是既然这样就回退更改,并记录下这个问题。

补充

上述实验章节中的ab也并非完全相同

print(torch.max(a-b))	# 7.4506e-09
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值