[n, h, w, c]转[n, c, h, w]
由于Pytorch默认图片数据格式为[n, c, h, w]、rgb,因此若数据集为[n, h, w, c]格式时需要进行转换:
# 为便于处理,先转为numpy数组
x = np.array(trainset['train_set_x'])
x.shape, type(x)
# ((1080, 64, 64, 3), numpy.ndarray)
x = x.transpose([0, 3, 1, 2])
x = t.as_tensor(x)
x.size()
# torch.Size([1080, 3
原创
2020-08-12 17:27:31 ·
1710 阅读 ·
0 评论