神经网络 AI torch 构造自己的数据集(包含标签或者不包含标签)

AI learning 用于学习train,验证val的图片目录结构形式如下:(一般为自己构造的图像数据的目录)

这是一个简单的图像二分类问题,两个类别为正常(normal)或者异常(abnormal)。

数据集分为:train 训练集,val 验证集,test 测试集。

data---

        ---train               

                -----abnormal

                        ----001.jpg

                        ----002.jpg

                        ----....

                -----normal            

                        ----001.jpg

                        ----002.jpg

                        ----....

        ---val            

                -----abnormal

                        ----001.jpg

                        ----002.jpg

                        ----....

                -----normal

                        ----001.jpg

                        ----002.jpg

                        ----....

        ---test

                ----001.jpg

                ----002.jpg

                ----....

使用Dataset 继承,需要重新写自己的dataset函数,包含标签(abnormal 为标签1, normal 为标签0),有标签的情况主要是用于学习和验证使用。

from torch.utils.data import Dataset
from torchvision import transforms

def get_label(root, phase):
    label_list =[]
    img_list1 = []
    img_root = os.path.join(root,phase)
    imgs = os.listdir(img_root)
    for im in imgs:
        image_list = os.listdir(os.path.join(img_root,im))
        for img_path in image_list:
            img_list = os.path.join(os.path.join(img_root, im),img_path)
            label = 1 if img_list.split('\\')[-2] == 'abnormal' else 0
            label_list.append(label)
            img_list1.append(img_list)
    return img_list1, label_list

class MyData(Dataset):
    def __init__(self, root_dir, phase, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.phase = phase
        self.data = self.load_img(self)
    def load_img(self):
        image_list, label_list=get_label(self.root_dir,self.phase)
        data =[]
        for im in range(len(image_list)):
            img = Image.open(image_list[im]).convert('RGB')
            sample =(img,label_list[im])
            data.append(sample)
        return data
    def __len__(self):
        return len(self.data)
    def __getitem__(self,index):
        image_info,img_label = self.data[index]
        if self.transform:
            sample = self.transform(image_info)
        else:
            sample = image_info
        return sample,img_label

无标签的情况,主要是来进行测试用。

def get_images(root):
    img_list1 = []
    img_root = os.path.join(root)
    imgs = os.listdir(img_root)
    for im in imgs:
        image_list = os.path.join(img_root,im)
        img_list1.append(image_list)
    return img_list1

class MyDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = self.load_img()
    def load_img(self):
        image_list =get_images(self.root_dir)
        data =[]
        for im in image_list:
            img = Image.open(im).convert('RGB')
            data.append(img)
        return data
    def __len__(self):
        return len(self.data)
    def __getitem__(self,index):
        image_info = self.data[index]
        if self.transform:
            sample = self.transform(image_info)
            return sample
        else:
            return image_info

调用

from torch.utils.data import DataLoader,Dataset

data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()])
train_dataset = MyData('data','train', transform = data_transform)
val_dataset = MyData('data','val', transform = data_transform)
test_dataset = MyDataset('data\\test', transform = data_transform)
test_loader = Dataloader(test_dataset, batch_size = 32)
for step,data in enumerate(test_loader):
    images = data
  # [预测代码]

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值