pytorch之nn.Upsample函数

在PyTorch中,nn.Upsample是一个用于上采样(即放大)输入张量(tensor)的模块。上采样是许多计算机视觉任务中的关键步骤,特别是在图像分割、超分辨率和生成对抗网络(GANs)等领域。nn.Upsample可以通过不同的方法来放大输入张量,包括最近邻插值(nearest neighbor interpolation)、线性插值(对于二维数据,等同于双线性插值;对于三维数据,等同于三线性插值)、双三次插值(bicubic interpolation,但需注意PyTorch的nn.Upsample默认不支持双三次插值,除非使用torch.nn.functional.interpolate)等。

基本用法

nn.Upsample的基本用法如下:

import torch.nn as nn  
  
# 创建一个Upsample实例  
# size=(H, W) 是目标大小,scale_factor=(h_scale, w_scale) 是缩放比例  
# mode='nearest' 指定插值模式,其他常用选项包括'linear', 'bilinear', 'trilinear'等  
# align_corners=True 在某些插值模式下会影响输出形状  
upsample = nn.Upsample(size=None, scale_factor=2, mode='bilinear', align_corners=False)  
  
# 假设input是一个形状为[batch_size, channels, height, width]的张量  
# 应用upsample  
output = upsample(input)

注意点

  • size与scale_factor:这两个参数是互斥的,即只能指定其中一个。如果指定了size,则输出张量的大小将直接设置为size;如果指定了scale_factor,则输出张量的大小将是输入张量大小乘以scale_factor
  • 插值模式:不同的插值模式适用于不同的场景和数据类型。例如,最近邻插值速度快但质量较差,适合像素级的任务;双线性插值(对于二维数据)和双三次插值(如果可用)则能提供更平滑的放大效果,适合需要高质量图像的任务。
  • align_corners:这个参数在mode='linear''bilinear''trilinear'时有效。如果设置为True,则输出的四个角点会与输入的对齐,这可能会改变输出的大小。默认值为False

替代方案

从PyTorch 1.3版本开始,推荐使用torch.nn.functional.interpolate函数作为nn.Upsample的替代,因为它提供了更多的灵活性和功能,比如直接支持双三次插值。

import torch.nn.functional as F  
  
# 使用torch.nn.functional.interpolate进行上采样  
output = F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False)

 torch.nn.functional.interpolatenn.Upsample在功能上非常相似,但前者提供了更多的选项和灵活性,因此在新代码中更推荐使用

函数说明

Upsample

CLASS torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None)

上采样一个给定的多通道的 1D (temporal,如向量数据), 2D (spatial,如jpg、png等图像数据) or 3D (volumetric,如点云数据)数据
假设输入数据的格式为minibatch x channels x [optional depth] x [optional height] x width。因此对于一个空间spatial输入,我们期待着4D张量的输入,即minibatch x channels x height x width。而对于体积volumetric输入,我们则期待着5D张量的输入,即minibatch x channels x depth x height x width

对于上采样有效的算法分别有对 3D, 4D和 5D 张量输入起作用的 最近邻、线性,、双线性, 双三次(bicubic)和三线性(trilinear)插值算法

你可以给定scale_factor来指定输出为输入的scale_factor倍或直接使用参数size指定目标输出的大小(但是不能同时制定两个)

参数:

  • size (int or Tuple[int] or Tuple[intint] or Tuple[intintint]optional) – 根据不同的输入类型制定的输出大小

  • scale_factor (float or Tuple[float] or Tuple[floatfloat] or Tuple[floatfloatfloat]optional) – 指定输出为输入的多少倍数。如果输入为tuple,其也要制定为tuple类型

  • mode (stroptional) – 可使用的上采样算法,有'nearest''linear''bilinear''bicubic' and 'trilinear'默认使用'nearest'

  • align_corners (booloptional) – 如果为True,输入的角像素将与输出张量对齐,因此将保存下来这些像素的值。仅当使用的算法为'linear''bilinear'or 'trilinear'时可以使用。默认设置为False

输入输出形状:

注意:

当align_corners = True时,线性插值模式(线性、双线性、双三线性和三线性)不按比例对齐输出和输入像素,因此输出值可以依赖于输入的大小。这是0.3.1版本之前这些模式的默认行为。从那时起,默认行为是align_corners = False,如下图:

上面的图是source pixel为4*4上采样为target pixel为8*8的两种情况,这就是对齐和不对齐的差别,会对齐左上角元素,即设置为align_corners = True时输入的左上角元素是一定等于输出的左上角元素。但是有时align_corners = False时左上角元素也会相等,官网上给的例子就不太能说明两者的不同(也没有试出不同的例子,大家理解这个概念就行了)

如果您想下采样/常规调整大小,您应该使用interpolate()方法,这里的上采样方法已经不推荐使用了。

UpsamplingNearest2d

CLASS torch.nn.UpsamplingNearest2d(size=None, scale_factor=None)

专门用于2D数据的线性插值算法,参数等跟上面的差不多,省略

形状:

UpsamplingBilinear2d

CLASS torch.nn.UpsamplingBilinear2d(size=None, scale_factor=None)

专门用于2D数据的双线性插值算法,参数等跟上面的差不多,省略

形状:

注意:最好还是使用nn.functional.interpolate(..., mode='bilinear', align_corners=True)

import torch
from torch import nn
input = torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2)
print(input)
print(input.shape)
m = nn.Upsample(scale_factor=2, mode='nearest')
print(m(input))
print(m(input).shape)
# m = nn.Upsample(size=(3,5), mode='bilinear',align_corners=True)
# print(m(input))
# print(m(input).shape)
m = nn.UpsamplingNearest2d(scale_factor=2)
print(m(input))
print(m(input).shape)
m = nn.UpsamplingBilinear2d(scale_factor=2)
print(m(input))
print(m(input).shape)

输出:
tensor([[[[1., 2.],
          [3., 4.]]]])
torch.Size([1, 1, 2, 2])
tensor([[[[1., 1., 2., 2.],
          [1., 1., 2., 2.],
          [3., 3., 4., 4.],
          [3., 3., 4., 4.]]]])
torch.Size([1, 1, 4, 4])
tensor([[[[1., 1., 2., 2.],
          [1., 1., 2., 2.],
          [3., 3., 4., 4.],
          [3., 3., 4., 4.]]]])
torch.Size([1, 1, 4, 4])
tensor([[[[1.0000, 1.3333, 1.6667, 2.0000],
          [1.6667, 2.0000, 2.3333, 2.6667],
          [2.3333, 2.6667, 3.0000, 3.3333],
          [3.0000, 3.3333, 3.6667, 4.0000]]]])
torch.Size([1, 1, 4, 4])

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浩瀚之水_csdn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值