AI learning:用于训练(train)和验证(val)的图片目录结构如下(通常是自己构造的图像数据目录)。
这是一个简单的图像二分类问题,两个类别为正常(normal)和异常(abnormal)。
数据集分为:train 训练集、val 验证集、test 测试集;其中 test 目录下不再按类别分子目录,图片直接放在 test 目录中。
data
---train
------abnormal
---------001.jpg
---------002.jpg
---------...
------normal
---------001.jpg
---------002.jpg
---------...
---val
------abnormal
---------001.jpg
---------002.jpg
---------...
------normal
---------001.jpg
---------002.jpg
---------...
---test
------001.jpg
------002.jpg
------...
继承 torch 的 Dataset 类,编写自己的数据集类。有标签的情况(abnormal 为标签 1,normal 为标签 0)主要用于训练和验证。
import os

from torch.utils.data import Dataset
from torchvision import transforms
def get_label(root, phase):
    """Collect image paths and binary labels for one labelled split.

    Expects the layout ``root/phase/<class_name>/<image files>``. Images in a
    class directory named ``abnormal`` get label 1; images in any other class
    directory get label 0. Plain files sitting directly in ``root/phase`` are
    ignored.

    Args:
        root: dataset root directory (e.g. ``'data'``).
        phase: split sub-directory name (e.g. ``'train'`` or ``'val'``).

    Returns:
        A tuple ``(img_paths, labels)`` of two parallel lists, ordered by
        class-directory name and then by file name.
    """
    img_paths = []
    labels = []
    phase_root = os.path.join(root, phase)
    # os.listdir returns entries in arbitrary order; sort for a reproducible
    # sample order across runs and machines.
    for class_name in sorted(os.listdir(phase_root)):
        class_dir = os.path.join(phase_root, class_name)
        if not os.path.isdir(class_dir):
            # Skip stray files next to the class folders instead of crashing
            # in os.listdir on a non-directory.
            continue
        # Derive the label from the class directory name itself. Splitting the
        # full path on '\\' only worked with Windows path separators.
        label = 1 if class_name == 'abnormal' else 0
        for file_name in sorted(os.listdir(class_dir)):
            img_paths.append(os.path.join(class_dir, file_name))
            labels.append(label)
    return img_paths, labels
class MyData(Dataset):
    """Labelled image dataset for the train/val splits.

    Every image found by ``get_label(root_dir, phase)`` is opened, converted to
    RGB and kept in memory when the dataset is constructed.

    NOTE(review): relies on ``Image`` (Pillow) being imported at module level;
    no such import is visible in this file.

    Args:
        root_dir: dataset root directory (e.g. ``'data'``).
        phase: split sub-directory name (e.g. ``'train'`` or ``'val'``).
        transform: optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self, root_dir, phase, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.phase = phase
        # load_img is a bound method. Passing ``self`` explicitly as well
        # raised "load_img() takes 1 positional argument but 2 were given".
        self.data = self.load_img()

    def load_img(self):
        """Return a list of ``(rgb_image, label)`` pairs for this split."""
        image_list, label_list = get_label(self.root_dir, self.phase)
        data = []
        for path, label in zip(image_list, label_list):
            # The ``with`` block closes the file handle. convert() returns a new
            # image object, which is what gets stored.
            with Image.open(path) as img:
                data.append((img.convert('RGB'), label))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)``, applying ``self.transform`` to the image if set."""
        image, label = self.data[index]
        if self.transform:
            image = self.transform(image)
        return image, label
无标签的情况,主要用于测试(test 集)。
def get_images(root):
    """List the image files directly inside ``root`` (the unlabelled test split).

    Args:
        root: directory containing the images, e.g. ``os.path.join('data', 'test')``.

    Returns:
        Sorted list of full file paths. Sub-directories are skipped because
        they cannot be opened as images.
    """
    # os.listdir order is arbitrary; sort so repeated runs enumerate the
    # images in the same order.
    candidates = (os.path.join(root, name) for name in sorted(os.listdir(root)))
    return [path for path in candidates if os.path.isfile(path)]
class MyDataset(Dataset):
    """Unlabelled image dataset, intended for the test split.

    Every image returned by ``get_images(root_dir)`` is opened, converted to
    RGB and kept in memory when the dataset is constructed.

    NOTE(review): relies on ``Image`` (Pillow) being imported at module level;
    no such import is visible in this file.

    Args:
        root_dir: directory containing the images.
        transform: optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = self.load_img()

    def load_img(self):
        """Return the list of RGB images found in ``self.root_dir``."""
        data = []
        for path in get_images(self.root_dir):
            # The ``with`` block closes the file handle. convert() returns a new
            # image object, which is what gets stored.
            with Image.open(path) as img:
                data.append(img.convert('RGB'))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return the image at ``index``, transformed by ``self.transform`` if set."""
        image = self.data[index]
        if self.transform:
            return self.transform(image)
        return image
调用示例:
from torch.utils.data import DataLoader, Dataset

# Resize every image to 224x224, then convert it to a tensor.
data_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
train_dataset = MyData('data', 'train', transform=data_transform)
val_dataset = MyData('data', 'val', transform=data_transform)
# os.path.join instead of the hard-coded 'data\\test', so the path also works
# on systems whose separator is '/'.
test_dataset = MyDataset(os.path.join('data', 'test'), transform=data_transform)
# The class is DataLoader (capital L); `Dataloader` raised NameError.
test_loader = DataLoader(test_dataset, batch_size=32)
for step, data in enumerate(test_loader):
    # MyDataset yields images only (no labels), so each batch is just images.
    images = data
    # [prediction code goes here]