MNIST和SVHN数据集配对

这篇博客介绍了两个经典的数据集——MNIST和SVHN,分别用于手写数字识别和真实世界图像的数字识别。文章详细阐述了两个数据集的特点,包括图像大小和数量,并提供了数据集的下载链接。此外,还展示了如何使用PyTorch库将这两个数据集进行随机配对,以创建更具挑战性的训练集。这有助于在深度学习模型中提升泛化能力。
摘要由CSDN通过智能技术生成

数据集说明

MNIST数据集

MNIST手写数字数据集是基础的数字分类数据集,MNIST的训练集中包含了60000个数字,测试集中包含10000个数字。每张图都是灰度图,尺寸都为固定的28*28。
在这里插入图片描述

SVHN数据集

SVHN 是一个从谷歌街景图像中的门牌号获得的真实世界的图像数据集,用于开发机器学习和对象识别算法。

包含73257 个用于训练的数字,26032 个用于测试的数字,以及 531131 个额外的、难度稍低的样本,用作额外的训练数据。

有两种格式:

  1. 具有字符级边界框的原始图像。
  2. 以单个字符为中心的类似 MNIST 的 32×32 图像(许多图像的侧面包含一些干扰物)

格式一:

格式二:

通过下面的方式下载导入的是第二种格式的图片,若要自行下载数据集可到官网:http://ufldl.stanford.edu/housenumbers/

配对

利用torch.utils.data中的Dataset可以灵活的实现MNIST和SVHN的随机配对:

from torch.utils.data import Dataset
from torchvision import transforms
from torchvison.datasets import MNIST, SVHN
import random


class MNIST_SVHN(Dataset):
    data_transform = {
        'mnist': transforms.Compose([
            transforms.Resize([32, 32]),
            transforms.ToTensor(),
        ]),
        'svhn': transforms.ToTensor()
    }

    def __init__(self,
                 data_path: str,
                 split: str
                 ):
        if split == 'train':
            self.mnist = MNIST(data_path, train=True, download=False,
                               transform=self.data_transform['mnist'])
            self.svhn = SVHN(data_path, split='train', download=False,
                             transform=self.data_transform['svhn'])
        elif split == 'test':
            self.mnist = MNIST(data_path, train=False, download=False,
                               transform=self.data_transform['mnist'])
            self.svhn = SVHN(data_path, split='test', download=False,
                             transform=self.data_transform['svhn'])

    def __len__(self):
        return len(self.svhn)

    def __getitem__(self, idx):
        svhn_data, svhn_label = self.svhn[idx]
        mnist_l, mnist_l_idx = self.mnist.targets.sort()
        cor_label_list = mnist_l_idx[mnist_l == svhn_label]
        len_list = len(cor_label_list)
        random_idx = random.randrange(len_list)
        cor_minst_idx = cor_label_list[random_idx]
        mnist_data = self.mnist[cor_minst_idx][0]

        return [mnist_data, svhn_data], svhn_label
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值