参考文章:https://cloud.tencent.com/developer/article/1433735
数据集目录
– data
    – train
        – 0.jpg
        – 1.jpg
        – …
    – test
        – 0.jpg
        – 1.jpg
        – …
    – train.txt
    – test.txt
train.txt 内容:每行一个样本,格式为“相对于数据集根目录的图片路径 + 空格 + 整数标签”,例如 train/0.jpg 0
DataSet重写代码
- 加载数据
- 重写 __getitem__() 函数
- 重写 __len__() 函数
from torch.utils.data import Dataset
import os
from PIL import Image
class MyDataset(Dataset):
    """Image dataset driven by a plain-text index file.

    The index file (``train.txt`` or ``test.txt`` inside ``root``) holds one
    sample per line: an image path relative to ``root`` followed by an
    integer label, separated by whitespace.
    """

    def __init__(self, root, train=True, transform=None):
        """Read the index file and keep its ``(path, label)`` pairs.

        Args:
            root: Directory that contains the index files and the images.
            train: Read ``train.txt`` when true, otherwise ``test.txt``.
            transform: Optional callable applied to every loaded image.

        Raises:
            OSError: If the index file cannot be opened.
            IndexError: If a non-blank line has fewer than two fields.
            ValueError: If the second field of a line is not an integer.
        """
        super().__init__()
        self.root = root
        self.train = train
        self.transform = transform

        index_name = 'train.txt' if train else 'test.txt'
        imgs = []
        # os.path.join works whether or not ``root`` ends with a separator;
        # the context manager closes the file even if parsing fails.
        with open(os.path.join(root, index_name), 'r') as index_file:
            for line in index_file:
                word = line.split()  # [relative image path, label, ...]
                if not word:
                    # Skip blank lines (e.g. a trailing empty line).
                    continue
                imgs.append((word[0], int(word[1])))
        self.imgs = imgs

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded from disk, converted to RGB and, when a
        transform was supplied, passed through it.
        """
        fn, label = self.imgs[index]
        # Image.open is lazy and keeps the file open; the with-block releases
        # the handle once the RGB copy has been created.
        with Image.open(os.path.join(self.root, fn)) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.imgs)
使用重写的Dataset类
from torch.utils.data import DataLoader
from torchvision import transforms
if __name__ == '__main__':
    # Resize every image to 256x256, then convert it to a tensor.
    tran = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    # ``transform`` must be passed by keyword: a positional argument placed
    # after keyword arguments is a SyntaxError.
    tr_set = MyDataset(root='./data/', train=True, transform=tran)
    tr_loader = DataLoader(
        dataset=tr_set,
        shuffle=True,
        batch_size=32,
    )