While studying recently, I used the MNIST data from torchvision and ran into a question.
Preface
import torchvision

train_dataset = torchvision.datasets.MNIST(root='../MNISTdata',
                                           train=True,
                                           download=True)
for i, (data, labels) in enumerate(train_dataset):
    if i % 60000 == 0:  # the MNIST training set has 60000 samples in total
        print(type(data))
"""
Output:
<class 'PIL.Image.Image'>
"""
type(train_dataset.data)
"""
Output:
torch.Tensor
"""
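As shown above, the data I read directly from the MNIST dataset (train_dataset.data) has a different type from the data I get by iterating over the dataset. This really confused me.
So what causes this difference?
Only after digging into the MNIST source code (mnist.py) did I find that it converts each sample to an image when the sample is fetched.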
def __init__(self, root, train=True, transform=None, target_transform=None,
             download=False):
    super(MNIST, self).__init__(root, transform=transform,
                                target_transform=target_transform)
    self.train = train  # training set or test set

    if download:
        self.download()

    if not self._check_exists():
        raise RuntimeError('Dataset not found.' +
                           ' You can use download=True to download it')

    if self.train:
        data_file = self.training_file
    else:
        data_file = self.test_file
    self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
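Above is MNIST's initializer (__init__). It loads self.data and self.targets from a processed file with torch.load, so both are tensors. (This is mnist.py from the torchvision version I was using; newer versions read the data differently.)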
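To see why .data ends up as a tensor, note that the initializer reads a file written with torch.save, and torch.load hands back exactly the tensors that were saved. Below is a small sketch with synthetic data. The file name and the tensor contents are illustrative assumptions, not the real processed MNIST file.

```python
import os
import tempfile

import torch

# Synthetic stand-in for MNIST: 5 grayscale 28x28 images as uint8, plus labels.
data = torch.randint(0, 256, (5, 28, 28), dtype=torch.uint8)
targets = torch.tensor([3, 1, 4, 1, 5])

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "training.pt")  # hypothetical file name
    # Save a (data, targets) tuple, then load it back the way __init__ does.
    torch.save((data, targets), path)
    loaded_data, loaded_targets = torch.load(path)

print(type(loaded_data), loaded_data.dtype, tuple(loaded_data.shape))
```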
def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], int(self.targets[index])

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
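Above is MNIST's method for fetching a single sample (__getitem__), which is what iteration calls under the hood. These two lines convert the data from a tensor to a PIL Image: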
img, target = self.data[index], int(self.targets[index])
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
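A minimal dataset that copies the logic of the __getitem__ excerpt reproduces the whole puzzle without downloading MNIST. The class name TinyMNISTLike and the all-zero data are hypothetical, made up for illustration. The sketch omits mode='L' because Pillow already infers the 'L' mode from a 2-D uint8 array.

```python
import torch
from PIL import Image
from torch.utils.data import Dataset


class TinyMNISTLike(Dataset):
    """Hypothetical dataset mirroring the MNIST __getitem__ excerpt."""

    def __init__(self, data, targets, transform=None):
        self.data = data          # uint8 tensor, shape (N, 28, 28)
        self.targets = targets    # integer tensor, shape (N,)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        # A 2-D uint8 array becomes an 8-bit grayscale ("L") PIL image.
        img = Image.fromarray(img.numpy())
        if self.transform is not None:
            img = self.transform(img)
        return img, target


ds = TinyMNISTLike(torch.zeros(4, 28, 28, dtype=torch.uint8),
                   torch.tensor([0, 1, 2, 3]))
print(type(ds.data))   # the stored data is a tensor
print(type(ds[0][0]))  # an indexed sample is a PIL Image

# Plain iteration falls back to __getitem__(0), (1), ... until IndexError,
# so it also yields PIL Images, just like the loop at the top of the post.
sample_types = {type(img) for img, _ in ds}
```

In practice, passing transform=torchvision.transforms.ToTensor() when creating the MNIST dataset runs the conversion inside __getitem__, so iteration then yields tensors instead of PIL Images.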
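Main text
There are many ways to do these conversions; the methods below are just one option for each direction.
1. Converting a tensor to an image (PIL)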
tensor→numpy→Image
from PIL import Image

array = tensor.numpy()  # the tensor must be on the CPU (call .detach().cpu() first if needed)
image = Image.fromarray(array)
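A quick check with a synthetic 28x28 uint8 tensor (a stand-in for one MNIST image) confirms that the result is a grayscale PIL image. The .detach().cpu() calls are my own addition. They do nothing for a plain CPU tensor, but .numpy() fails without them on a GPU tensor or on one that tracks gradients.

```python
import torch
from PIL import Image

# Synthetic grayscale "image": values 0..255 repeating, as uint8, shape (28, 28).
tensor = (torch.arange(28 * 28) % 256).to(torch.uint8).reshape(28, 28)

# detach() and cpu() are no-ops here but needed for GPU or grad-tracking tensors.
array = tensor.detach().cpu().numpy()
image = Image.fromarray(array)  # 2-D uint8 array -> mode "L"

print(image.mode, image.size)  # L (28, 28)
```

If the tensor is a float in [0, 1] (for example, the output of ToTensor), multiply it by 255 and convert it to uint8 before calling Image.fromarray. Otherwise a float32 array becomes a 32-bit floating-point ("F") image.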
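2. Converting an image to a tensor
Use torchvision.transforms:
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor()])  # PIL image (H x W) -> float tensor (C x H x W), values in [0, 1]
tensor = transform(image)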