# Load a single image from `path` and turn it into a tensor.
img = Image.open(path)
# Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
# under its current name and also exists in older Pillow releases.
img = img.resize((200, 100), Image.LANCZOS)
transf = transforms.ToTensor()
img_tensor = transf(img)

# Original author's note: the first dimension is the batch dimension.
# NOTE(review): torch.normal() requires mean/std (and usually size) arguments,
# so these two calls raise TypeError as written. The intended shapes are not
# visible here, so they are left as placeholders -- fill them in before running.
train_x = torch.normal()
train_y = torch.normal()

# Pair features with targets and serve them in shuffled mini-batches.
train_set = Data.TensorDataset(train_x, train_y)
train_iter = Data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
def load_data_fashion_mnist(batch_size, num_workers=0, resize=None, root="../data"):
    """Build training and test DataLoaders for the Fashion-MNIST dataset.

    Args:
        batch_size: Batch size used for both loaders.
        num_workers: Forwarded to both DataLoaders as ``num_workers``.
        resize: Optional size forwarded to ``transforms.Resize``. When truthy,
            the resize step runs before ``transforms.ToTensor``.
        root: Directory passed to ``torchvision.datasets.FashionMNIST`` as
            ``root``. Both splits are requested with ``download=True``.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader is
        created with ``shuffle=True``, the test loader with ``shuffle=False``.
    """
    pipeline = [transforms.ToTensor()]
    if resize:
        # Resize must operate on the image before it is converted to a tensor,
        # hence it is inserted at the front of the pipeline.
        pipeline.insert(0, transforms.Resize(resize))
    transform = transforms.Compose(pipeline)

    mnist_train = torchvision.datasets.FashionMNIST(
        root=root,
        train=True,
        transform=transform,
        download=True,
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root=root,
        train=False,
        transform=transform,
        download=True,
    )

    train_loader = torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        mnist_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    return train_loader, test_loader
将图片转换为 DataLoader,以及数据集加载函数
最新推荐文章于 2023-04-03 21:17:19 发布