下面的代码带有 validation（验证集）部分，并且 labeled 和 unlabeled 数据使用各自独立的 dataloader。如果不需要验证集，但仍需要分别为 labeled 和 unlabeled 数据设置 loader，可以参考下面的代码进行修改，删除其中的 validation 部分即可。
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
import numpy as np
from functools import reduce
from operator import __or__
def load_data_val(path, args):
if args.dataset == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif args.dataset == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif args.dataset == 'svhn':
mean = [x / 255 for x in [127.5, 127.5, 127.5]]
std = [x / 255 for x in [127.5, 127.5, 127.5]]
elif args.dataset == 'mnist':
mean = (0.5,)
std = (0.5,)
elif args.dataset == 'stl10':
assert False, 'Do not finish stl10 code'
elif args.dataset == 'imagenet':
assert False, 'Do not finish imagenet code'
else:
assert False, "Unknow dataset : {}".format(args.dataset)
if args.dataset == 'svhn':
train_transform = transforms.Compose([
transform