半监督之数据加载

本文探讨了在半监督学习中如何处理数据,特别是将无标签样本标记为-1以在交叉熵计算中忽略。内容涉及数据加载方法,包括只包含训练集和测试集的情况,以及使用torchvision.dataset处理非图片数据。
摘要由CSDN通过智能技术生成

半监督数据加载:把需要设置为无标签样本的标签设置为-1,这样可以在交叉熵的时候设置忽略-1的标签

class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)

数据加载多种多样:

1.只有train, test,没有validattion。通过 torchvision.dataset非图片。

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from functools import reduce
from operator import __or__
from torch.utils.data.sampler import Sampler
import itertools
import numpy as np

def load_data(path, args, NO_LABEL=-1):
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'mnist':
        mean = (0.5, )
        std = (0.5, )
    elif args.dataset == 'stl10':
        assert False, 'Do not finish stl10 code'
    elif args.dataset == 'imagenet':
        assert False, 'Do not finish imagenet code'
    else:
        assert False, "Unknow dataset : {}".format(args.dataset)

    if args.dataset == 'svhn':
        train_transform = transforms.Compose([
             transforms.RandomCrop(32, padding=2),
             transforms.ToTensor(),
             transforms.Normalize(mean
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值