from torch.utils.data import DataLoader
from torchvision import transforms as trans
from torchvision.datasets import ImageFolder
import numpy as np
def de_preprocess(tensor, mean=0.5, std=0.5):
    """Undo a ``Normalize(mean, std)`` step, e.g. the one in ``get_train_dataset``.

    ``transforms.Normalize`` computes ``(x - mean) / std``; this returns
    ``x * std + mean``. With the default ``mean=0.5, std=0.5`` it inverts the
    normalization used by the training transform in this file.

    Args:
        tensor: Normalized value(s). Anything supporting ``*`` and ``+`` with
            floats works (e.g. a torch tensor, a numpy array or a float).
        mean: Mean that was subtracted during normalization.
        std: Standard deviation that was divided out during normalization.

    Returns:
        The de-normalized value(s).
    """
    # NOTE(review): the scraped header read "e_preprocess(tensor):" (missing
    # "def d"), which is a syntax error. Restored as ``de_preprocess``, since it
    # is the inverse of the preprocessing Normalize step.
    return tensor * std + mean
def get_train_dataset(imgs_folder):
    """Build the training ``ImageFolder`` dataset and count its classes.

    Each image is randomly flipped horizontally, converted to a tensor, and
    normalized per channel with mean 0.5 and std 0.5. The three-element
    ``Normalize`` arguments assume 3-channel images.

    Args:
        imgs_folder: Root directory passed to ``ImageFolder``, which expects one
            sub-folder per class.

    Returns:
        ``(ds, class_num)``: the dataset and the number of classes.
    """
    train_transform = trans.Compose([
        trans.RandomHorizontalFlip(),
        trans.ToTensor(),
        trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    ds = ImageFolder(imgs_folder, train_transform)
    # Count the class folders ImageFolder discovered instead of using
    # ``ds[-1][1] + 1``. That version loaded, decoded and randomly transformed
    # the last image only to read its label. It could also undercount if the
    # highest-indexed class folders contain no images.
    class_num = len(ds.classes)
    return ds, class_num
def get_train_loader(conf):
    """Create a shuffled training ``DataLoader`` from a config object.

    The dataset is built from ``conf.imgpath / 'imgs'``. The ``/`` join suggests
    ``conf.imgpath`` is a ``pathlib.Path`` (TODO confirm against the config).
    The loader settings come from ``conf.batch_size``, ``conf.pin_memory`` and
    ``conf.num_workers``.

    Returns:
        ``(loader, class_num)``, where ``class_num`` is the value returned by
        ``get_train_dataset``.
    """
    images_root = conf.imgpath / 'imgs'
    dataset, num_classes = get_train_dataset(images_root)

    loader_options = dict(
        batch_size=conf.batch_size,
        shuffle=True,
        pin_memory=conf.pin_memory,
        num_workers=conf.num_workers,
    )
    train_loader = DataLoader(dataset, **loader_options)
    return train_loader, num_classes
# face_dataloader
# Latest recommended article published 2022-06-05 20:35:36.