PyTorch计算机视觉入门:从官方数据集到自定义数据集的获取

一、PyTorch与计算机视觉简介

PyTorch是一个开源的深度学习框架,其动态图的特性非常适合快速实验和模型原型设计。在计算机视觉任务中,如图像分类、目标检测、图像分割等,PyTorch提供了丰富的API和预训练模型,帮助开发者快速搭建和优化模型。

二、使用官方数据集

1. 数据集准备

PyTorch附带了torchvision库,它不仅包含了常用的计算机视觉模型,还有对经典数据集(如CIFAR-10、CIFAR-100、MNIST、ImageNet等)的便捷访问。以MNIST为例,您可以这样加载数据集:

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='data/',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='data/',
                                          train=False,
                                          transform=transforms.ToTensor())

数据将会保存在data路径下

三、生成自己的数据集合

1、使用官方数据集从压缩包转成图片跟标签。

def convert_to_img(train=True):
    if(train):
        f=open('./data/train.txt','w')
        data_path='./data/data_train/'
        if(not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i,(img,label) in enumerate(zip(train_set[0],train_set[1])):
            img_path=data_path+str(i)+'.jpg'
            print('train_img_path:', img_path, 'train_img_num:', i)
            io.imsave(img_path,img.numpy())
            f.write(str(label.item()) + '\n')
        f.close()
    else:
        f = open('./data/test.txt', 'w')
        data_path = './data/data_test/'
        if (not os.path.exists(data_path)):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(test_set[0], test_set[1])):
            img_path = data_path + str(i) + '.jpg'
            print('test_img_path:', img_path, 'test_img_num:', i)
            io.imsave(img_path, img.numpy())
            f.write(str(label.item()) + '\n')
        f.close()

最终我们便将官方数据集合转成自己的数据集,可以自行使用。最终的数据的组成如下:

四、构建自定义数据集

当标准数据集不能满足特定需求时,创建自定义数据集变得尤为重要。

1. 数据集结构

首先,您需要按照一定的结构组织您的数据。一般建议为每个类别创建单独的文件夹,文件夹内存放对应类别的图片。

2. 编写数据集类

继承torch.utils.data.Dataset,实现__len____getitem__方法:

class CustomImageDataset(Dataset):
    def __init__(self, data_path, model, transform=None, target_transform=None):
        self.data_path = data_path
        self.model = model
        self.img_labels = []
        self.image_lists =[]
        self.transform = transform
        self.target_transform = target_transform
        self.obtain_label_image()

    def __len__(self):
        return len(self.img_labels)


    def __getitem__(self, idx):
        img = Image.open(self.image_lists[idx])
        image = np.array(img)
        label = self.img_labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def obtain_label_image(self):
        if(self.model == "train"):
            # 指定文件夹路径
            folder_path = self.data_path + 'data_train/'
            # 获取文件夹中的文件列表
            file_list = os.listdir(folder_path)
            for i in range(len(file_list)):
                image_path = folder_path + str(i) +".jpg"
                #print(image_path)
                self.image_lists.append(image_path)

            file_path = self.data_path + 'train.txt'  # 替换为实际文件路径
            with open(file_path, 'r') as file:
                # 逐行读取文件内容
                for line in file:
                    # 处理每一行的数据,例如打印或存储
                    self.img_labels.append(int(line.strip()))  # 使用strip()方法去除行末的换行符
        if (self.model == "test"):
            # 指定文件夹路径
            folder_path = self.data_path + 'data_test/'
            # 获取文件夹中的文件列表
            file_list = os.listdir(folder_path)
            for i in range(len(file_list)):
                image_path = folder_path + str(i) +".jpg"
                #print(image_path)
                self.image_lists.append(image_path)

            file_path = self.data_path + 'test.txt'  # 替换为实际文件路径
            with open(file_path, 'r') as file:
                # 逐行读取文件内容
                for line in file:
                    # 处理每一行的数据,例如打印或存储
                    self.img_labels.append(int(line.strip()))  # 使用strip()方法去除行末的换行符

通过以上步骤,您已成功使用PyTorch从官方数据集过渡到了自定义数据集的训练流程,这是进行计算机视觉项目定制化研究和应用的重要起点。随着实践的深入,您将能够更熟练地利用PyTorch的强大功能,探索更多计算机视觉的前沿应用。

关注我的公众号auto_driver_ai(Ai fighting), 第一时间获取更新内容。

  • 9
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch允许您创建自定义数据集以便于加载和处理您自己的数据。以下是一个简单的示例来创建自定义数据集: 首先,您需要导入必要的库: ```python import torch from torch.utils.data import Dataset ``` 然后,创建一个继承自`Dataset`类的自定义数据集类,并实现以下方法: - `__init__`:初始化数据集,例如加载数据或设置转换。 - `__len__`:返回数据集的大小。 - `__getitem__`:根据给定的索引返回一个样本。 下面是一个示例,假设您有一组图像数据和相应的标签: ```python class CustomDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, index): sample = self.data[index] label = self.labels[index] # 在这里进行必要的数据转换 return sample, label ``` 在上面的示例中,`data`是图像数据的列表,`labels`是相应的标签的列表。然后,您可以在`__getitem__`方法中执行必要的数据转换,例如将图像转换为张量或应用任何其他预处理步骤。 要使用自定义数据集,您可以创建一个实例并将其传递给`DataLoader`类: ```python # 假设您有图像数据和标签 data = [...] # 图像数据列表 labels = [...] # 标签列表 # 创建自定义数据集实例 dataset = CustomDataset(data, labels) # 创建数据加载器 dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) ``` 现在,您可以使用`dataloader`来迭代加载批量的数据,并在训练模型时使用它们。 这只是一个简单的示例,您可以根据您的需求进行更多的自定义和扩展。希望这可以帮助到您!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值