torchvision.transforms.ToTensor()不缩放问题,没有变化

本文讲述了ToTensor()函数在PyTorch中的作用,包括数据类型转换、范围调整和格式调整。作者在处理MNIST和FashionMNIST数据集时遇到ToTensor()未正确缩放的问题,通过先转化为numpy数组再进行转换和归一化解决了这一问题,并提供了CIFAR10的数据预处理示例。
摘要由CSDN通过智能技术生成

简而言之,ToTensor()主要有三个作用:

  1. 将 PIL Image 或 numpy.ndarray 转为 tensor
  2. 将数据范围从[0, 255]转换为[0.0, 1.0]
  3. 将 HWC 的图像格式转为 CHW 的 tensor 格式。CNN训练时需要的数据格式是[N,C,N,W],也就是说经过 ToTensor() 处理的图像可以直接输入到CNN网络中,不需要再进行reshape。

但是,在加载MNIST数据集时,发现即便传入了transform参数,img并未像预期那样被压缩到(0,1)

具体代码如下:

        apply_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=apply_transform )

        test_dataset_all = datasets.MNIST(data_dir, train=False, download=True,
                                          transform=apply_transform )

在参考了torchvision.transforms.ToTensor()不缩放问题
torchvision.transforms 数据预处理:ToTensor()
后,我了解了totensor的工作机制,发现,加载的数据集类型并不满足totensor转化的输入格式要求,所以我把数据集先转化为了np.array,再对其应用totensor进行转化和归一化到[0,1],代码如下:

        train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                       transform=None)

        test_dataset_all = datasets.MNIST(data_dir, train=False, download=True,
                                          transform=None)
        train_dataset = transforms.ToTensor()(np.array(train_dataset.data))
        train_dataset = transforms.Normalize([0.1307, ], [0.3081, ])(train_dataset)

其中,0.1307是mnist数据集的均值,0.3081是mnist数据集的标准差,可以通过以下方式求得:

# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='data/', train=True, transform=None, download=True)
test_dataset = datasets.MNIST(root='data/', train=False, transform=None, download=True)

# 获取像素数组
train_data = train_dataset.data
test_data = test_dataset.data

# 转换为浮点数数组
train_data_float = train_data.float()
test_data_float = test_data.float()

# 计算均值和标准差
mean = train_data_float.mean()
std = train_data_float.std()

# 定义归一化和缩放的转换
normalize = transforms.Normalize(mean=[mean], std=[std])

这里,也记录一下FasionMnist和cifar10的归一化数据

        train_dataset = datasets.FashionMNIST(data_dir, train=True, download=True,
                                              transform=None)

        test_dataset_all = datasets.FashionMNIST(data_dir, train=False, download=True,
                                                 transform=None)
        train_dataset = transforms.ToTensor()(np.array(train_dataset))
        train_dataset = transforms.Normalize([0.2860, ], [0.3530, ])(train_dataset)
        #cifar10是[0.5,0.5,0.5],[0.5,0.5,0.5]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值