初次使用时,还不会使用,该函数可以对一个batch的图像数据进行上下采样,但是有一些地方需要注意,否则会报错!!!!
我想要对(64, 256, 256)大小的图像进行2倍下采样,于是有了下面输入
import torch
from torch import nn
import torch.nn.functional as F
if __name__ == '__main__':
input = torch.randn(64, 256, 256)
print(input)
outputs = F.interpolate(input, scale_factor=0.5, mode='nearest')
print(outputs)
print(outputs. Shape)
输出结果是:
torch.Size([64, 256, 128])
很明显不是我想要的(64,128,128)
经过修改,应该添加一个维度:
import torch
from torch import nn
import torch.nn.functional as F
if __name__ == '__main__':
input = torch.randn(1, 64, 256, 256) # b, c, h, w
print(input)
outputs = F.interpolate(input, scale_factor=0.5, mode='nearest')
print(outputs)
print(outputs. Shape)
运行结果为:
torch.Size([1, 64, 128, 128])
这样就对了,但是后面要恢复到原来的维度,使用.squeeze(0)函数去掉一个维度。