解决办法
file = open(file_path, 'r',encoding='utf-8')
class MyDataset(Dataset):
def __init__(self, txt, data_path=None, transform=None, target_transform=None, loader=default_loader):
super(MyDataset, self).__init__() # 对继承父类的属性初始化
# 在__init__()方法中得到图像的路径,然后将图像路径组成一个数组
file_path = data_path + txt
file = open(file_path, 'r',encoding='utf-8')
imgs = []
for line in file:
line = line.split()
# print(line[0].rstrip(',')) # img
# print(line[1].rstrip('\n')) # label
imgs.append((line[0].rstrip(','), line[1].rstrip('\n')))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.data_path = data_path
def __getitem__(self, index):
# 按照索引读取每个元素的具体内容
imgName, label = self.imgs[index]
imgPath = self.data_path + imgName
img = self.loader(imgPath)
if self.transform is not None:
img = self.transform(img) # 数据标签转换为Tensor
label = torch.from_numpy(np.array(int(label)))
return img, label
def __len__(self):
# 数据集的图片数量
return len(self.imgs)