关于tf.contrib.layers.batch_norm的pytorch替代torch.nn.BatchNorm1d() 的一些注意事项

本文讲述了在TensorFlow(1.14)和PyTorch(1.13.1)中,如何在处理三维输入数据时正确使用torch.nn.BatchNorm1d进行归一化,以及与tf.contrib.layers.batch_norm的差异,强调了对输入数据最后两维转置的必要性。
摘要由CSDN通过智能技术生成

tf版本为1.14,torch为1.13.1,总结,可以用 torch.nn.BatchNorm1d() 替代,但三维输入数据时需要在应用 BatchNorm1d 前对最后两维进行转置,应用后再转置回来。二维输入时可直接替代。

三维输入数据

data = [[[1.,2.,3.,4.],[9.,8.,7.,6.]],[[4.,2.,3.,5.],[4.,5.,6.,3.]],[[1.,5.,3.,9.],[7.,5.,4.,2.]]]

tf_data = tf.constant(data)

torch_data = torch.tensor(data)

# 使用 tf.contrib.layers.batch_norm 进行归一化
output_tensor = tf.contrib.layers.batch_norm(tf_data, is_training=True)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
config = tf.ConfigProto(allow_soft_placement = True)
init = tf.global_variables_initializer()
with tf.Session(config = config) as sess:
    sess.run(init)
    tf_result = sess.run(output_tensor)
print(f'tf.contrib.layers.batch_norm输出为{tf_result}')

# 使用 torch.nn.BatchNorm1d 进行归一化
batch_norm_layer = torch.nn.BatchNorm1d(2)# 2为data形状(3,2,4)的第二维
torch_result = batch_norm_layer(torch_data)
print(f'torch.nn.BatchNorm1d输出为{torch_result}')

结果是

显然不对,这里需要对pytorch版本的代码进行修改,参考

详解torch.nn.BatchNorm1d的具体计算过程

对三维数据data的最后两维进行转置,应用BatchNorm1d后再变回来即可,代码如下

data = [[[1.,2.,3.,4.],[9.,8.,7.,6.]],[[4.,2.,3.,5.],[4.,5.,6.,3.]],[[1.,5.,3.,9.],[7.,5.,4.,2.]]]

tf_data = tf.constant(data)

torch_data = torch.tensor(data)

# 使用 tf.contrib.layers.batch_norm 进行归一化
output_tensor = tf.contrib.layers.batch_norm(tf_data, is_training=True)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
config = tf.ConfigProto(allow_soft_placement = True)
init = tf.global_variables_initializer()
with tf.Session(config = config) as sess:
    sess.run(init)
    tf_result = sess.run(output_tensor)
print(f'tf.contrib.layers.batch_norm输出为{tf_result}')

# 使用 torch.nn.BatchNorm1d 进行归一化
batch_norm_layer = torch.nn.BatchNorm1d(4)# 4为data后俩维转置后,形状为(3,4,2)的第二维
torch_result = batch_norm_layer(torch_data.permute(0,2,1)).permute(0,2,1)
print(f'torch.nn.BatchNorm1d输出为{torch_result}')

注意BatchNorm1d(x)中的x始终是输入数据的第 2 维,输出结果为

可以看到除了精度,是相同的结果。

二维输入数据

data = [[1.,2.,3.,4.],[9.,8.,7.,6.],[4.,2.,3.,5.],[4.,5.,6.,3.],[1.,5.,3.,9.],[7.,5.,4.,2.]]

tf_data = tf.constant(data)

torch_data = torch.tensor(data)

# 使用 tf.contrib.layers.batch_norm 进行归一化
output_tensor = tf.contrib.layers.batch_norm(tf_data, is_training=True)
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
config = tf.ConfigProto(allow_soft_placement = True)
init = tf.global_variables_initializer()
with tf.Session(config = config) as sess:
    sess.run(init)
    tf_result = sess.run(output_tensor)
print(f'tf.contrib.layers.batch_norm输出为{tf_result}')

# 使用 torch.nn.BatchNorm1d 进行归一化
batch_norm_layer = torch.nn.BatchNorm1d(4)# 4为data(形状为(6,4))的第二维
torch_result = batch_norm_layer(torch_data)
print(f'torch.nn.BatchNorm1d输出为{torch_result}')

结果除了精度,其余完全相同

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值