PyTorch Dataset 的 ImageFolder
ImageFolder 例子
def load_data(root_dir, domain, batch_size):
    """Build a shuffled DataLoader over the image folder ``root_dir + domain``.

    Every image is converted to single-channel grayscale, resized to 28x28
    and turned into a tensor before batching.

    Args:
        root_dir: Path prefix of the dataset. It is concatenated with
            ``domain`` as-is (no separator is inserted), so it must already
            end with a path separator.
        domain: Name of the sub-directory to load, in the layout expected by
            ``datasets.ImageFolder``.
        batch_size: Number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` that shuffles the data, uses two
        worker processes and drops the final incomplete batch.
    """
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize([28, 28]),
        transforms.ToTensor(),
        # Grayscale() produces ONE channel, so mean/std need exactly one
        # entry. Three-entry statistics make Normalize fail with a
        # broadcast-shape error on current torchvision releases.
        transforms.Normalize(mean=(0,), std=(1,)),
    ])
    image_folder = datasets.ImageFolder(
        root=root_dir + domain,
        transform=transform,
    )
    loader = torch.utils.data.DataLoader(
        dataset=image_folder,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True,
    )
    return loader
# Source-domain loader: the 'amazon' folder under rootdir, batched with the
# first entry of BATCH_SIZE.
data_src = data_loader.load_data(
    root_dir=rootdir,
    domain='amazon',
    batch_size=BATCH_SIZE[0],
)

# Run epochs 1..N_EPOCH (inclusive) behind a tqdm progress bar; train()
# returns the model that is fed into the next epoch.
for e in tqdm(range(1, N_EPOCH + 1)):
    model = train(
        model=model,
        optimizer=optimizer,
        epoch=e,
        data_src=data_src,
        data_tar=data_tar,
    )
【PyTorch学习笔记】14:划分训练-验证-测试集,使用正则化项
torch.utils.data.random_split源码