face_dataloader

from torch.utils.data import DataLoader
from torchvision import transforms as trans
from torchvision.datasets import ImageFolder
import numpy as np
e_preprocess(tensor):
    return tensor*0.5+0.5

def get_train_dataset(imgs_folder):
    train_transform = trans.Compose([
        trans.RandomHorizontalFlip(),
        trans.ToTensor(),
        trans.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
    ])
    ds = ImageFolder(imgs_folder,train_transform)
    class_num = ds[-1][1]+1
    return ds,class_num
def get_train_loader(conf):
    ds,class_num = get_train_dataset(conf.imgpath/'imgs')
    loader = DataLoader(ds,batch_size=conf.batch_size,shuffle=True,pin_memory=conf.pin_memory,num_workers=conf.num_workers)
    return loader,class_num
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值