简而言之,ToTensor()主要有三个作用:
- 将 PIL Image 或 numpy.ndarray 转为 tensor
- 将数据范围从[0, 255]转换为[0.0, 1.0]
- 将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。
但是,在加载MNIST数据集时,发现即便传入了transform参数,img并未像预期那样被压缩到(0,1)
具体代码如下:
apply_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST(data_dir, train=True, download=True,
transform=apply_transform )
test_dataset_all = datasets.MNIST(data_dir, train=False, download=True,
transform=apply_transform )
在参考了torchvision.transforms.ToTensor()不缩放问题
和torchvision.transforms 数据预处理:ToTensor()
后,我了解了totensor的工作机制,发现,加载的数据集类型并不满足totensor转化的输入格式要求,所以我把数据集先转化为了np.array,再对其应用totensor进行转化和归一化到[0,1],代码如下:
train_dataset = datasets.MNIST(data_dir, train=True, download=True,
transform=None)
test_dataset_all = datasets.MNIST(data_dir, train=False, download=True,
transform=None)
train_dataset = transforms.ToTensor()(np.array(train_dataset.data))
train_dataset = transforms.Normalize([0.1307, ], [0.3081, ])(train_dataset)
其中,0.1307是mnist数据集的均值,0.3081是mnist数据集的标准差,可以通过以下方式求得:
# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='data/', train=True, transform=None, download=True)
test_dataset = datasets.MNIST(root='data/', train=False, transform=None, download=True)
# 获取像素数组
train_data = train_dataset.data
test_data = test_dataset.data
# 转换为浮点数数组
train_data_float = train_data.float()
test_data_float = test_data.float()
# 计算均值和标准差
mean = train_data_float.mean()
std = train_data_float.std()
# 定义归一化和缩放的转换
normalize = transforms.Normalize(mean=[mean], std=[std])
这里,也记录一下FasionMnist和cifar10的归一化数据
train_dataset = datasets.FashionMNIST(data_dir, train=True, download=True,
transform=None)
test_dataset_all = datasets.FashionMNIST(data_dir, train=False, download=True,
transform=None)
train_dataset = transforms.ToTensor()(np.array(train_dataset))
train_dataset = transforms.Normalize([0.2860, ], [0.3530, ])(train_dataset)
#cifar10是[0.5,0.5,0.5],[0.5,0.5,0.5]