PyTorch - torch.nn.Upsample

PyTorch - torch.nn.Upsample

flyfish

上采样

输入是

minibatch x channels x height x width

输出是

H × scale_factor
W × scale_factor

本来名字是上采样,还可以根据参数使用变成下采样
与torch.nn.functional.interpolate对比
torch.nn.functional.interpolate完全可以替代该函数

import torch
import torch.nn as nn

input = torch.arange(0, 16).view(1,1,4,4).float()
print((input))
# tensor([[[[ 0.,  1.,  2.,  3.],
#           [ 4.,  5.,  6.,  7.],
#           [ 8.,  9., 10., 11.],
#           [12., 13., 14., 15.]]]])
m = nn.Upsample(scale_factor=2, mode='nearest')
print(m(input))
# tensor([[[[ 0.,  0.,  1.,  1.,  2.,  2.,  3.,  3.],
#           [ 0.,  0.,  1.,  1.,  2.,  2.,  3.,  3.],
#           [ 4.,  4.,  5.,  5.,  6.,  6.,  7.,  7.],
#           [ 4.,  4.,  5.,  5.,  6.,  6.,  7.,  7.],
#           [ 8.,  8.,  9.,  9., 10., 10., 11., 11.],
#           [ 8.,  8.,  9.,  9., 10., 10., 11., 11.],
#           [12., 12., 13., 13., 14., 14., 15., 15.],
#           [12., 12., 13., 13., 14., 14., 15., 15.]]]])
m = nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False)  #align_corners = False是默认参数
print(m(input))
# tensor([[[[ 0.0000,  0.2500,  0.7500,  1.2500,  1.7500,  2.2500,  2.7500,
#             3.0000],
#           [ 1.0000,  1.2500,  1.7500,  2.2500,  2.7500,  3.2500,  3.7500,
#             4.0000],
#           [ 3.0000,  3.2500,  3.7500,  4.2500,  4.7500,  5.2500,  5.7500,
#             6.0000],
#           [ 5.0000,  5.2500,  5.7500,  6.2500,  6.7500,  7.2500,  7.7500,
#             8.0000],
#           [ 7.0000,  7.2500,  7.7500,  8.2500,  8.7500,  9.2500,  9.7500,
#            10.0000],
#           [ 9.0000,  9.2500,  9.7500, 10.2500, 10.7500, 11.2500, 11.7500,
#            12.0000],
#           [11.0000, 11.2500, 11.7500, 12.2500, 12.7500, 13.2500, 13.7500,
#            14.0000],
#           [12.0000, 12.2500, 12.7500, 13.2500, 13.7500, 14.2500, 14.7500,
#            15.0000]]]])
m = nn.Upsample(scale_factor=2, mode='bilinear',align_corners=True)
print(m(input))
# tensor([[[[ 0.0000,  0.4286,  0.8571,  1.2857,  1.7143,  2.1429,  2.5714,
#             3.0000],
#           [ 1.7143,  2.1429,  2.5714,  3.0000,  3.4286,  3.8571,  4.2857,
#             4.7143],
#           [ 3.4286,  3.8571,  4.2857,  4.7143,  5.1429,  5.5714,  6.0000,
#             6.4286],
#           [ 5.1429,  5.5714,  6.0000,  6.4286,  6.8571,  7.2857,  7.7143,
#             8.1429],
#           [ 6.8571,  7.2857,  7.7143,  8.1429,  8.5714,  9.0000,  9.4286,
#             9.8571],
#           [ 8.5714,  9.0000,  9.4286,  9.8571, 10.2857, 10.7143, 11.1429,
#            11.5714],
#           [10.2857, 10.7143, 11.1429, 11.5714, 12.0000, 12.4286, 12.8571,
#            13.2857],
#           [12.0000, 12.4286, 12.8571, 13.2857, 13.7143, 14.1429, 14.5714,
#            15.0000]]]])
m = nn.Upsample(scale_factor=0.5, mode='bilinear',align_corners=False)
print(m(input))
#可以起到下采样的作用

# tensor([[[[ 2.5000,  4.5000],
#           [10.5000, 12.5000]]]])

m = nn.Upsample(scale_factor=0.5, mode='nearest')
print(m(input))

# tensor([[[[ 0.,  2.],
#           [ 8., 10.]]]])
相关推荐
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页