1. make_grid
torchvision.utils.make_grid() tiles several images into a single grid image. Its signature is:
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False)
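tensor is a 4D tensor of shape (B, C, H, W), i.e. a batch of images (not (B, H, W, C)); nrow is the number of images per row, so the grid has ceil(B/nrow) rows and nrow columns (only B columns if B < nrow); padding is the number of pixels inserted between neighbouring images and around the border of the grid; normalize=True shifts the images into the range 0-1 using the minimum and maximum pixel values; range is a tuple (min, max) used for that normalization instead of the minimum and maximum computed from the data (newer torchvision releases name this argument value_range); scale_each=True normalizes every image separately instead of using a single minimum and maximum for the whole batch.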
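To make the size arithmetic and the normalize/scale_each behaviour concrete, here is a minimal sketch in plain Python and NumPy. The helpers `grid_shape` and `normalize_batch` are hypothetical and do not call torchvision; they follow my reading of the torchvision source (output width `xmaps * (W + padding) + padding`, height `ymaps * (H + padding) + padding`, single-channel input repeated to 3 channels, a batch of one image returned without padding, values clamped to [min, max] and then rescaled), which may differ slightly between versions.

```python
import math
import numpy as np

def grid_shape(b, c, h, w, nrow=8, padding=2):
    """Hypothetical helper: (C, H_out, W_out) of the grid built from B images of size C x H x W."""
    if c == 1:
        c = 3                          # single-channel images are repeated to 3 channels
    if b == 1:
        return (c, h, w)               # a single image is returned without padding
    xmaps = min(nrow, b)               # number of columns
    ymaps = math.ceil(b / xmaps)       # number of rows
    return (c, ymaps * (h + padding) + padding, xmaps * (w + padding) + padding)

def normalize_batch(batch, value_range=None, scale_each=False):
    """Hypothetical helper mimicking normalize=True on a (B, C, H, W) array."""
    def norm(img):
        if value_range is not None:
            low, high = value_range
        else:
            low, high = float(img.min()), float(img.max())
        img = np.clip(img, low, high)                # clamp to [low, high]
        return (img - low) / max(high - low, 1e-5)   # map linearly to [0, 1]

    batch = np.asarray(batch, dtype=np.float64)
    if scale_each:
        return np.stack([norm(img) for img in batch])  # min/max per image
    return norm(batch)                                 # one min/max for the whole batch

print(grid_shape(10, 3, 32, 32))  # (3, 70, 274): 8 images in the first row, 2 in the second

batch = np.array([[[[0.0, 1.0]]], [[[10.0, 20.0]]]])  # two 1x1x2 "images"
print(normalize_batch(batch).ravel())                   # shared min/max across the batch
print(normalize_batch(batch, scale_each=True).ravel())  # each image stretched to [0, 1]
```

With a shared minimum and maximum, the dim first image stays close to 0 while the bright one fills the upper half of the range; scale_each=True stretches both to the full 0-1 range, which is why it is useful when images have very different intensity ranges.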
2. ReflectionPad2d
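nn.ReflectionPad2d() is a padding layer that pads the input with a reflection of its own border. The padding can be given in several ways: a single int, applied to all four sides, or an asymmetric 4-tuple, as shown below: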
import torch
import torch.nn as nn

input = torch.randn(64, 3, 220, 220)  # input size (N, C, H, W)
# 4-tuple
pad = nn.ReflectionPad2d((3, 3, 5, 5))  # left, right, top, bottom
output = pad(input)  # size(64, 3, 230, 226): H = 220 + 5 + 5, W = 220 + 3 + 3
# int
pad = nn.ReflectionPad2d(3)  # 3 pixels on every side
output = pad(input)  # size(64, 3, 226, 226)
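Reflection padding mirrors the input around its edge pixels without repeating the edge itself: a row [0, 1, 2, 3] padded by 2 on each side becomes [2, 1, 0, 1, 2, 3, 2, 1]. The sketch below is a hypothetical pure-Python re-implementation on nested lists (the function names are my own, no torch needed); it pads the width of every row first and then mirrors whole rows to pad the height. It also shows why each padding amount must be smaller than the corresponding input dimension: there must be enough pixels to mirror.

```python
def reflect_pad_1d(seq, left, right):
    """Mirror `seq` around its first and last element (the edge element is not repeated)."""
    n = len(seq)
    if left >= n or right >= n:
        raise ValueError("reflection padding must be smaller than the input size")
    return seq[left:0:-1] + seq + seq[-2:-2 - right:-1]

def reflection_pad2d(img, padding):
    """Hypothetical helper: pad an H x W nested list the way nn.ReflectionPad2d pads (H, W)."""
    if isinstance(padding, int):
        padding = (padding,) * 4
    left, right, top, bottom = padding
    rows = [reflect_pad_1d(row, left, right) for row in img]      # widen every row
    return [list(r) for r in reflect_pad_1d(rows, top, bottom)]   # then mirror whole rows

img = [[0, 1, 2],
       [3, 4, 5],
       [6, 7, 8]]
for row in reflection_pad2d(img, 1):
    print(row)
# [4, 3, 4, 5, 4]
# [1, 0, 1, 2, 1]
# [4, 3, 4, 5, 4]
# [7, 6, 7, 8, 7]
# [4, 3, 4, 5, 4]
```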
Adapted from the blog post: https://blog.csdn.net/LemonTree_Summer/article/details/81433638
3. BatchNorm2d
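The official documentation is shown in the figure below:
[Figure: official PyTorch documentation of nn.BatchNorm2d]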
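Batch normalization computes a mean and a variance for each channel across the whole batch. For an input of shape [B, C, H, W], it takes the first channel of images 1 to B (B two-dimensional feature maps in total) and computes the mean and variance over all of their elements, i.e. over B*H*W values; it then does the same for the second channel, and so on up to the last channel. BN therefore produces C means and C variances, and (with affine=True) it also learns 2C parameters: a scale γ and a shift β for each channel. The word "batch" refers to the fact that these statistics are computed across images over the entire batch rather than within a single image.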
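The per-channel computation described above can be sketched in NumPy. The function `batchnorm2d_train` is hypothetical and is only meant to mirror what nn.BatchNorm2d does in training mode according to the PyTorch documentation: normalize with the biased batch variance plus eps, apply a per-channel scale γ and shift β, and update the running statistics as `(1 - momentum) * running + momentum * batch_stat`, storing the unbiased variance in `running_var`.

```python
import numpy as np

def batchnorm2d_train(x, gamma, beta, running_mean, running_var, momentum=0.1, eps=1e-5):
    """Hypothetical training-mode BatchNorm2d forward pass on a (B, C, H, W) array."""
    axes = (0, 2, 3)                           # reduce over batch and spatial dims, keep channels
    n = x.shape[0] * x.shape[2] * x.shape[3]   # B*H*W values per channel
    mean = x.mean(axis=axes)                   # C means
    var = x.var(axis=axes)                     # C biased variances, used for normalization
    shape = (1, -1, 1, 1)                      # broadcast per-channel values over (B, H, W)
    x_hat = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    y = gamma.reshape(shape) * x_hat + beta.reshape(shape)  # 2C learnable parameters
    new_running_mean = (1 - momentum) * running_mean + momentum * mean
    new_running_var = (1 - momentum) * running_var + momentum * var * n / (n - 1)  # unbiased
    return y, new_running_mean, new_running_var

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 2, 2))
c = x.shape[1]
y, rm, rv = batchnorm2d_train(x, np.ones(c), np.zeros(c), np.zeros(c), np.ones(c), momentum=1.0)
print(y.mean(axis=(0, 2, 3)))  # close to 0 for every channel
print(rm)                      # equals the per-channel batch means because momentum=1
```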
An experiment:
import torch
import torch.nn as nn

m = nn.BatchNorm2d(3, momentum=1)  # first argument is num_features, i.e. the number of channels C
input = torch.randn(4, 3, 2, 2)
output = m(input)  # a training-mode forward pass also updates m.running_mean
# print(output)
# per-channel mean over all B*H*W = 4*2*2 = 16 elements
a = (input[0, 0, :, :]+input[1, 0, :, :]+input[2, 0, :, :]+input[3, 0, :, :]).sum()/16
b = (input[0, 1, :, :]+input[1, 1, :, :]+input[2, 1, :, :]+input[3, 1, :, :]).sum()/16
c = (input[0, 2, :, :]+input[1, 2, :, :]+input[2, 2, :, :]+input[3, 2, :, :]).sum()/16
print('The mean value of the first channel is %f' % a.data)
print('The mean value of the second channel is %f' % b.data)
print('The mean value of the third channel is %f' % c.data)
print('The output mean value of the BN layer is %f, %f, %f' % (m.running_mean.data[0], m.running_mean.data[1], m.running_mean.data[2]))
print(m)
Output:
The mean value of the first channel is -0.276033
The mean value of the second channel is -0.365808
The mean value of the third channel is -0.107249
The output mean value of the BN layer is -0.276033, -0.365808, -0.107249
BatchNorm2d(3, eps=1e-05, momentum=1, affine=True, track_running_stats=True)
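The experiment sets momentum=1 because PyTorch updates the running mean as running_mean = (1 - momentum) * running_mean + momentum * batch_mean, starting from 0. With momentum=1 the running mean is simply replaced by the current batch mean, which is why running_mean matches the manually computed channel means; with the default momentum=0.1, a single forward pass would move it only 10% of the way. A small worked example using the first-channel mean printed above (the helper name is hypothetical):

```python
def update_running_mean(running_mean, batch_mean, momentum):
    """Hypothetical helper: the exponential moving average used for BN running statistics."""
    return (1 - momentum) * running_mean + momentum * batch_mean

batch_mean = -0.276033  # first-channel mean printed by the experiment above
print(update_running_mean(0.0, batch_mean, momentum=1))    # -0.276033: replaced by the batch mean
print(update_running_mean(0.0, batch_mean, momentum=0.1))  # about -0.0276 with the default momentum
```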
Part of the code above is adapted from the blog post: http://www.mamicode.com/info-detail-2378483.html