半监督数据加载：把需要作为无标签样本的那部分数据的标签设置为 -1，这样在计算交叉熵时就可以通过 ignore_index 忽略标签为 -1 的样本。
# Summed cross-entropy; targets equal to NO_LABEL (the samples marked as
# unlabeled) are excluded via ignore_index and add nothing to the loss.
class_criterion = nn.CrossEntropyLoss(ignore_index=NO_LABEL, reduction='sum')
数据加载的方式多种多样：
1. 只有 train、test，没有 validation。通过 torchvision.datasets 加载（非图片文件）。
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from functools import reduce
from operator import __or__
from torch.utils.data.sampler import Sampler
import itertools
import numpy as np
def load_data(path, args, NO_LABEL=-1):
if args.dataset == 'cifar10':
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
std = [x / 255 for x in [63.0, 62.1, 66.7]]
elif args.dataset == 'cifar100':
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
std = [x / 255 for x in [68.2, 65.4, 70.4]]
elif args.dataset == 'svhn':
mean = [x / 255 for x in [127.5, 127.5, 127.5]]
std = [x / 255 for x in [127.5, 127.5, 127.5]]
elif args.dataset == 'mnist':
mean = (0.5, )
std = (0.5, )
elif args.dataset == 'stl10':
assert False, 'Do not finish stl10 code'
elif args.dataset == 'imagenet':
assert False, 'Do not finish imagenet code'
else:
assert False, "Unknow dataset : {}".format(args.dataset)
if args.dataset == 'svhn':
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=2),
transforms.ToTensor(),
transforms.Normalize(mean