本文整理汇总了Python中torchvision.datasets.SVHN类的典型用法代码示例。如果您正苦于以下问题:Python datasets.SVHN类的具体用法?Python datasets.SVHN怎么用?Python datasets.SVHN使用的例子?那么恭喜您, 这里精选的代码示例或许可以为您提供帮助。您也可以进一步了解该类所在模块torchvision.datasets的用法示例。
在下文中一共展示了datasets.SVHN类的26个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: get_svhn
点赞 6
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_svhn(train, get_dataset=False, batch_size=cfg.batch_size):
    """Return the SVHN train/test split, or a shuffling DataLoader over it.

    Args:
        train: truthy selects the 'train' split, falsy the 'test' split.
        get_dataset: if true, return the dataset object instead of a loader.
        batch_size: loader batch size. Note the default is read from
            ``cfg.batch_size`` once, when this function is defined.

    Returns:
        The ``datasets.SVHN`` instance when ``get_dataset`` is true, otherwise
        a ``torch.utils.data.DataLoader`` with ``shuffle=True``.
    """
    # Images -> tensor, then per-channel normalisation with config statistics.
    pre_process = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=cfg.dataset_mean, std=cfg.dataset_std),
    ])
    split = 'train' if train else 'test'
    dataset = datasets.SVHN(root=cfg.data_root,
                            split=split,
                            transform=pre_process,
                            download=True)
    if get_dataset:
        return dataset
    return torch.utils.data.DataLoader(dataset=dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
开发者ID:corenel,项目名称:pytorch-atda,代码行数:24,
示例2: get_loader
点赞 6
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_loader(config):
    """Build and return DataLoaders for the SVHN and MNIST datasets.

    Args:
        config: object exposing ``image_size``, ``svhn_path``, ``mnist_path``,
            ``batch_size`` and ``num_workers`` attributes.

    Returns:
        tuple: ``(svhn_loader, mnist_loader)``; both loaders shuffle.
    """
    # transforms.Scale was deprecated in favour of transforms.Resize (identical
    # behaviour for an int size) and no longer exists in current torchvision.
    svhn_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    # MNIST tensors have a single channel. A 3-element Normalize fails to
    # broadcast in-place onto a (1, H, W) tensor in current torchvision, so
    # normalise the one channel with the same mean/std (which is what older
    # torchvision effectively did by zipping over channels).
    mnist_transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=svhn_transform)
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=mnist_transform)

    svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                              batch_size=config.batch_size,
                                              shuffle=True,
                                              num_workers=config.num_workers)
    mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                               batch_size=config.batch_size,
                                               shuffle=True,
                                               num_workers=config.num_workers)
    return svhn_loader, mnist_loader
开发者ID:yunjey,项目名称:mnist-svhn-transfer,代码行数:23,
示例3: get_targets
点赞 6
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_targets(dataset):
    """Return the targets of ``dataset`` as a tensor, ignoring target transforms.

    Recurses through ``TransformedDataset``, ``data.Subset`` and
    ``data.ConcatDataset`` wrappers down to the underlying torchvision dataset
    and reads its stored labels directly.

    Raises:
        NotImplementedError: if the dataset type is not recognised.
    """
    if isinstance(dataset, TransformedDataset):
        return get_targets(dataset.dataset)
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])
    if isinstance(dataset, (datasets.MNIST, datasets.ImageFolder)):
        return torch.as_tensor(dataset.targets)
    if isinstance(dataset, datasets.SVHN):
        # SVHN keeps its labels in `.labels` (a NumPy array) rather than
        # `.targets`. Convert so every branch returns a tensor; otherwise
        # torch.cat in the ConcatDataset branch rejects the array.
        return torch.as_tensor(dataset.labels)
    raise NotImplementedError(f"Unknown dataset {dataset}!")
开发者ID:BlackHC,项目名称:BatchBALD,代码行数:20,
示例4: get_dataset
点赞 6
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def get_dataset(self):
    """Load SVHN via torchvision.datasets.SVHN (train + extra for training).

    Each split is stored in its own directory under ``datasets/SVHN/`` and is
    requested with ``download=True``. The 'train' and 'extra' splits both use
    ``self.train_transforms`` and are concatenated into the training set. The
    'test' split uses ``self.val_transforms`` and serves as the validation set.

    Returns:
        tuple: ``(trainset, valset)`` where ``trainset`` is a
        ``torch.utils.data.ConcatDataset`` of the train and extra splits and
        ``valset`` is the SVHN test split.
    """
    trainset = datasets.SVHN('datasets/SVHN/train/', split='train', transform=self.train_transforms,
                             target_transform=None, download=True)
    valset = datasets.SVHN('datasets/SVHN/test/', split='test', transform=self.val_transforms,
                           target_transform=None, download=True)
    extraset = datasets.SVHN('datasets/SVHN/extra', split='extra', transform=self.train_transforms,
                             target_transform=None, download=True)
    trainset = torch.utils.data.ConcatDataset([trainset, extraset])
    return trainset, valset
开发者ID:MrtnMndt,项目名称:OCDVAEContinualLearning,代码行数:20,
示例5: __init__
点赞 5
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def __init__(self, root, train=True,
             transform=None, target_transform=None, download=False):
    """Load an SVHN split, class-balancing the training split by subsampling.

    Args:
        root: dataset root directory, forwarded to the parent constructor.
        train: truthy selects the 'train' split, falsy the 'test' split.
        transform, target_transform, download: forwarded unchanged.

    For the 'train' split, every class is randomly subsampled (using NumPy's
    global RNG) down to the size of the rarest class. The kept indices are
    shuffled, then used to re-index ``self.labels`` and ``self.data``. The
    'test' split is left as loaded.
    """
    split = 'train' if train else 'test'
    super(SVHN, self).__init__(root, split=split, transform=transform,
                               target_transform=target_transform, download=download)
    if split != 'train':
        return

    labels = self.labels.squeeze()
    # Exact per-class counts. np.histogram(bins=num_cls) only lined classes up
    # with bins when labels were contiguous integers.
    label_set, counts = np.unique(labels, return_counts=True)
    num_cls = len(label_set)
    min_num = counts.min()

    # One row of `min_num` randomly chosen sample indices per class. Rows are
    # addressed by class position, not `label % num_cls`, so arbitrary label
    # values cannot collide on the same row.
    ind = np.zeros((num_cls, min_num), dtype=int)
    for row, label in enumerate(label_set):
        binary_ind = np.where(labels == label)[0]
        np.random.shuffle(binary_ind)
        ind[row, :] = binary_ind[:min_num]
    ind = ind.flatten()

    # Shuffle 100 times. One shuffle is already uniform; the repeats are kept
    # only so seeded runs reproduce the original ordering.
    for _ in range(100):
        np.random.shuffle(ind)
    self.labels = self.labels[ind]
    self.data = self.data[ind]
开发者ID:jhoffman,项目名称:cycada_release,代码行数:34,
示例6: __init__
点赞 5
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def __init__(self, root, train=True,
             transform=None, target_transform=None, download=False):
    """Translate a boolean ``train`` flag into the parent's ``split`` argument.

    A truthy ``train`` requests the 'train' split, otherwise 'test'. All other
    arguments are forwarded unchanged to the parent SVHN constructor
    (presumably torchvision.datasets.SVHN — confirm against the class header).
    """
    split = 'train' if train else 'test'
    super(SVHN, self).__init__(root,
                               split=split,
                               transform=transform,
                               target_transform=target_transform,
                               download=download)
开发者ID:jhoffman,项目名称:cycada_release,代码行数:10,
示例7: __init__
点赞 5
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def __init__(self):
    """Initialise the base meta-info, then set the SVHN-specific fields."""
    super(SVHNMetaInfo, self).__init__()
    # Presumably the image count of SVHN's 'train' split — confirm at use site.
    self.num_training_samples = 73257
    self.dataset_class = SVHNFine
    self.root_dir_name = "svhn"
    self.label = "SVHN"
开发者ID:osmr,项目名称:imgclsmob,代码行数:8,
示例8: svhn_loaders
点赞 5
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def svhn_loaders(batch_size):
    """Return ``(train_loader, test_loader)`` for SVHN stored under ``./data``.

    Images are converted with ``transforms.ToTensor()`` and targets are passed
    through ``replace_10_with_0``. Only the training loader shuffles; both
    loaders use pinned memory.
    """
    def _load(split_name):
        # Both splits share root, download flag and transforms.
        return datasets.SVHN("./data", split=split_name, download=True,
                             transform=transforms.ToTensor(),
                             target_transform=replace_10_with_0)

    train_set, test_set = _load('train'), _load('test')
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
    test_loader = make_loader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)
    return train_loader, test_loader
开发者ID:locuslab,项目名称:convex_adversarial,代码行数:8,
示例9: LoadSVHN
点赞 5
# 需要导入模块: from torchvision import datasets [as 别名]
# 或者: from torchvision.datasets import SVHN [as 别名]
def LoadSVHN(data_root, batch_size=32, split='train', shuffle=True, drop_last=True):
    """Return a DataLoader over one SVHN split, creating ``data_root`` if needed.

    Args:
        data_root: directory for the SVHN files; created if it does not exist.
        batch_size: loader batch size.
        split: SVHN split name passed to ``datasets.SVHN`` (e.g. 'train', 'test').
        shuffle: whether the loader shuffles each epoch.
        drop_last: drop the final incomplete batch. Defaults to True, matching
            the previous hard-coded value. Pass False for evaluation so no
            test samples are silently skipped.

    Returns:
        A ``DataLoader`` yielding ``ToTensor``-converted images and labels.
    """
    # exist_ok=True removes the race between an exists() check and makedirs().
    os.makedirs(data_root, exist_ok=True)
    svhn_dataset = datasets.SVHN(data_root, split=split, download=True,
                                 transform=transforms.ToTensor())
    return DataLoader(svhn_dataset, batch_size=batch_size, shuffle=shuffle,
                      drop_last=drop_last)
开发者ID:Alexander-H-Liu,