pytorch-构建自己的dataset类

如今,与keras 、tf相比,pytorch高效的资源利用率,越来越多的Aier应用pytorch。我来讲讲初入pytorch最重要的东西:dataset

网上有很多介绍pytorch dataset类的文章,不过大多数都是讲解某一类任务的数据集模型建立。不太具有泛化性,本文将提出一个通用的数据集接口解决技巧,供大家参考

实验环境:

python==3.7.3

ubuntu==16.04

pytorch==1.1.0


dataset类

为什么dataset是初入pytorch最重要的东西?因为我们复现项目的时候,最需要改的就是数据集其他调调参改改模型问题都不大。

如果弄明白了pytorch中dataset类,你可以创建适应任意模型的数据集接口

所谓数据集,无非就是一组{x:y}的集合吗,你只需要在这个类里说明“有一组{x:y}的集合”就可以了。

对于图像分类任务,图像+分类

对于目标检测任务,图像+bbox、分类

对于超分辨率任务,低分辨率图像+超分辨率图像

对于文本分类任务,文本+分类

...

你只需定义好这个项目的x和y是什么好了,上面都是扯闲篇,我们直接看dataset代码:

class Dataset(object):
    """An abstract class representing a Dataset.
    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """
 
    def __getitem__(self, index):
        raise NotImplementedError
 
    def __len__(self):
        raise NotImplementedError
 
    def __add__(self, other):
        return ConcatDataset([self, other])

 

上面的代码是pytorch给出的官方代码,其中__getitem__和__len__是子类必须继承的。

很好解释,pytorch给出的官方代码限制了标准,你要按照它的标准进行数据集建立首先,__getitem__就是获取样本对,模型直接通过这一函数获得一对样本对{x:y}__len__是指数据集长度。

自己建立一个dataset试试:

class MyDataSet(Dataset):
    def __init__(self):
        self.sample_list = ...
 
    def __getitem__(self, index):
        x= ...
        y= ...
        return x, y
 
    def __len__(self):
        return len(self.sample_list)

咱只需按照需求把模板填完就Ok了,那么为什么说这个模板使用于各种任务的数据集建造呢?还得依靠一个trick:通过txt文件映射

举个实例,假设我要给一个分类器训练喂数据,我的数据是images+number的组合,比如{img:3},这代表这个图像应该分在“3”类。我怎么写代码呢?

from torch.utils.data import Dataset
 
class MyDataSet(Dataset):
    def __init__(self, dataset_type, transform=None, update_dataset=False):
        """
        dataset_type: ['train', 'test']
        """
 
        dataset_path = '/home/muzhan/projects/dataset/'
 
        if update_dataset:
            make_txt_file(dataset_path)  # update datalist
 
        self.transform = transform
        self.sample_list = list()
        self.dataset_type = dataset_type
        f = open(dataset_path + self.dataset_type + '/datalist.txt')
        lines = f.readlines()
        for line in lines:
            self.sample_list.append(line.strip())
        f.close()
 
    def __getitem__(self, index):
        item = self.sample_list[index]
        # img = cv2.imread(item.split(' _')[0])
        img = Image.open(item.split(' _')[0])
        if self.transform is not None:
            img = self.transform(img)
        label = int(item.split(' _')[-1])
        return img, label
 
    def __len__(self):
        return len(self.sample_list)

上面有个transform参数,用于对数据集进行预处理的,可以根据项目选择使用。

上面有一个make_txt_file的函数需要说明一下,这个函数可以在数据集目录下创建一个txt文件,代表x和y的映射关系。这个函数大家可以自己写,一个简单脚本而已,我就不共享代码了 。(如有需要,留言告知)

我给大家看一下我的datalist.txt中的几行:

/home/shiwuzhe/projects/dataset/test/250_04.png _0
/home/shiwuzhe/projects/dataset/test/250_05.png _7
/home/shiwuzhe/projects/dataset/test/250_06.png _3
/home/shiwuzhe/projects/dataset/test/250_07.png _2
/home/shiwuzhe/projects/dataset/test/250_08.png _2
/home/shiwuzhe/projects/dataset/test/250_09.png _3
/home/shiwuzhe/projects/dataset/test/250_10.png _4
/home/shiwuzhe/projects/dataset/test/250_11.png _0
/home/shiwuzhe/projects/dataset/test/250_12.png _9

 

这样就可以理解我在__getitem__函数中解析x和y的方法吧,在文本中用字符串' _'隔开,当然你可以用其他字符,能够保证剪切字符串不出错即可。

我们需要测试这个dataset类是否成功:

if __name__ == '__main__':
    ds = MyDataSet()
    print(ds.__len__())
    img, gt = ds.__getitem__(34) # get the 34th sample
    print(type(img))
    print(gt)

 

上面有输出,并且和你数据集一致,那证明这个dataset类是成功的。

有了这个,用DataLoader函数就可以加载我们的数据集了。

 

  • 24
    点赞
  • 117
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
下面是一个简单的 PyTorch 手写体识别示例代码: ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset from torchvision.datasets import MNIST from torchvision.transforms import ToTensor # 加载数据集 train_dataset = MNIST(root='data', train=True, transform=ToTensor(), download=True) test_dataset = MNIST(root='data', train=False, transform=ToTensor(), download=True) # 构建数据加载器 train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) # 定义模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv = nn.Sequential( nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), ) self.fc = nn.Sequential( nn.Linear(64 * 7 * 7, 128), nn.ReLU(), nn.Linear(128, 10), ) def forward(self, x): x = self.conv(x) x = x.view(x.size(0), -1) x = self.fc(x) return x # 定义模型、损失函数和优化器 model = Net() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) # 训练模型 for epoch in range(10): running_loss = 0.0 for i, (inputs, labels) in enumerate(train_loader, 0): optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() # 打印训练信息 print('Epoch %d loss: %.3f' % (epoch+1, running_loss/len(train_loader))) # 测试模型 correct = 0 total = 0 with torch.no_grad(): for inputs, labels in test_loader: outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() # 打印测试结果 print('Accuracy on %d test images: %.2f %%' % (total, 100 * correct / total)) ``` 在这个示例代码中,我们首先使用 PyTorch 内置的 MNIST 数据集下载工具加载了训练集和测试集,并构建了数据加载器。然后,我们定义了一个包含两个卷积层和两个全连接层的卷积神经网络,并使用交叉熵损失函数和 Adam 优化器进行模型训练。最后,我们使用测试集对模型进行测试,并计算出了模型在测试集上的准确率。 这个示例代码只是一个简单的手写体识别的 PyTorch 实现,你可以根据自己的需求进行修改和扩展。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

DLANDML

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值