数据集说明
MNIST数据集
MNIST手写数字数据集是基础的数字分类数据集,MNIST的训练集中包含了60000个数字,测试集中包含10000个数字。每张图都是灰度图,尺寸都为固定的28*28。
SVHN数据集
SVHN 是一个从谷歌街景图像中的门牌号获得的真实世界的图像数据集,用于开发机器学习和对象识别算法。
包含73257 个用于训练的数字,26032 个用于测试的数字,以及 531131 个额外的、难度稍低的样本,用作额外的训练数据。
有两种格式:
- 具有字符级边界框的原始图像。
- 以单个字符为中心的类似 MNIST 的 32×32 图像(许多图像的侧面包含一些干扰物)
格式一:
格式二:
通过下面的方式下载导入的是第二种格式的图片,若要自行下载数据集可到官网:http://ufldl.stanford.edu/housenumbers/
配对
利用torch.utils.data中的Dataset可以灵活的实现MNIST和SVHN的随机配对:
from torch.utils.data import Dataset
from torchvision import transforms
from torchvison.datasets import MNIST, SVHN
import random
class MNIST_SVHN(Dataset):
data_transform = {
'mnist': transforms.Compose([
transforms.Resize([32, 32]),
transforms.ToTensor(),
]),
'svhn': transforms.ToTensor()
}
def __init__(self,
data_path: str,
split: str
):
if split == 'train':
self.mnist = MNIST(data_path, train=True, download=False,
transform=self.data_transform['mnist'])
self.svhn = SVHN(data_path, split='train', download=False,
transform=self.data_transform['svhn'])
elif split == 'test':
self.mnist = MNIST(data_path, train=False, download=False,
transform=self.data_transform['mnist'])
self.svhn = SVHN(data_path, split='test', download=False,
transform=self.data_transform['svhn'])
def __len__(self):
return len(self.svhn)
def __getitem__(self, idx):
svhn_data, svhn_label = self.svhn[idx]
mnist_l, mnist_l_idx = self.mnist.targets.sort()
cor_label_list = mnist_l_idx[mnist_l == svhn_label]
len_list = len(cor_label_list)
random_idx = random.randrange(len_list)
cor_minst_idx = cor_label_list[random_idx]
mnist_data = self.mnist[cor_minst_idx][0]
return [mnist_data, svhn_data], svhn_label