从train和test的txt文件中获取每个图片的路径以及其对应的标签
博主的训练和测试文件中的路径和标签的表示形式如下:
导入相关库
import torch
from torch.utils.data.dataset import Dataset
from PIL import Image
import numpy as np
from torchvision import transforms
从文件中读取图片和相应的标签
依据生成图片和标签的格式获取相应的图片和标签,在这里因为图片路径和标签之间以空格区分,所以使用空格进行划分。
def read_txt(path):
imgs, labels = [], []
with open(path, 'a') as f:
for line in f.readlines():
im, label = line.strip().split(' ')
imgs.append(im)
labels.append(int(label))
return imgs, labels
对传入的数据集进行处理
处理训练和测试数据集
class MyCustomDataset(Dataset):
def __init__(self, root_path, class_num=4, transform=None):
self.txt_path = root_path
self.img_path, self.img_label = read_txt(root_path)
self.transform=transform
def __getitem__(self, index):
img_path = self.img_path[index]
img = Image.open(img_path)
if self.transform is not None:
img = self.transform(img)
label = np.array(self.img_label[index])
label = torch.from_numpy(label).type(torch.long)
return img, label
def __len__(self):
return len(self.img_path)
在主函数中将上述模块导入,并使用其处理测试和训练数据集
train_data = MyCustomDataset(args.img_tr, transform=train_transform)
test_data = MyCustomDataset(args.img_te, transform=test_transform)
依据训练和测试数据集所在的路径生成相应的训练和测试txt文件(标签和图片路径以空格划分)
导入相关库
import os
import sys
from pathlib import Path
import os
遍历文件夹获取图片路径并生成相应的标签
训练和测试文件夹中的类别目录如下:
def listfiles(rootDir , txtfile, label=0):
ftxtfile = open(txtfile, 'w')
list_dirs = os.walk(rootDir)
count = 0
dircount = 0
label = -1
# 遍历文件夹中图片,并读取路径信息
for root, dirs, files in list_dirs:
for f in files:
ftxtfile.write(os.path.join(root, f) + ' ' + str(label)+ '\n')
print(os.path.join(root, f) + ' ' + str(label) + '\n')
count += 1
label += 1
将数据集路径以及生成测试或训练txt文件的路径传入函数
if __name__ == '__main__':
listfiles(r'autodl-tmp/mydata/dataset/AID/AID_test5', r'autodl-tmp/mydata/dataset/AID/AID_test5/test.txt')
使用时换成自己文件夹所在的位置即可。