数据格式要求
The 4 dimensions of input_patch are <batch size, image height, image width, image channel> respectively. In Pytorch, the input channel should be in the second dimension. That’s why the permutation is required.
After the permutation, the 4 dimensions of in_img will be <batch size, image channel, image height, image width> 1.
将 numpy格式图像转化为相应的tensor格式
在喂入深度学习网络之前,需要对数据格式进行转换,需要将 N x H x W x C
的 numpy格式图像转化为相应的 tensor格式 N x C x H x W
2:
def toTensor(img):
img = torch.from_numpy(img.transpose((0, 3, 1, 2)))
return img.float().div(255).unsqueeze(0)
按要求对输入数据进行格式转换
使用.permute()
函数对数据的各维度进行调整,并使用.unsqueeze()
函数对其进行升维,代码及演示如下:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 2022/7/12 12:26
# @FileName: unsqueeze.py
# @Software: PyCharm
import torch
x = torch.rand(8,128,192)
print(x.shape) # torch.Size([8, 128, 192])
x1 = torch.unsqueeze(x, -1)
print(x1.shape) # torch.Size([8, 128, 192, 1])
print(x1[0].shape) # torch.Size([128, 192, 1])
print(x1[1].shape) # torch.Size([128, 192, 1])
print(x1[-1].shape) # torch.Size([128, 192, 1])
如果我的这篇文章帮助到了你,那我也会感到很高兴,一个人能走多远,在于与谁同行。