从python中的一些特殊方法讲到pytorch的官方例子mnist(主要针对pytorch的自定义dataset中的几个特殊函数进行说明)

本文详细介绍了PyTorch中自定义Dataset时涉及的特殊方法,包括__getitem__、__repr__和__str__,并结合MNIST官方例子进行说明。__getitem__用于获取数据和标签,是通过索引访问数据的关键;__repr__和__str__则在打印对象时调用,返回打印内容。通过实例代码展示了这些方法的使用和返回结果。
摘要由CSDN通过智能技术生成

__str__(self)

  • 该方法中必须有一个return
  • 调用该方法的时机是print对象时
  • return的内容就是print打印的内容

__repr__(self)

  • 该方法与__str__(self)方法一样都是打印时调用的
  • 其中也必须含有return
  • 下面用pytorch官方例子mnist来进行说明:
#每一个自定义的dataset都必须继承类torch.utils.data.Dataset,还必须重写__len__(self)方法和__getitem__(self,index)方法。
#顺便可能会重写__repr__(self)方法,这个不是必须,下面的例子就是自定义dataset类--MNIST类的部分代码:
def __repr__(self):
    fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
    fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
    tmp = 'train' if self.train is True else 'test'
    fmt_str += '    Split: {}\n'.format(tmp)
    fmt_str += '    Root Location: {}\n'.format(self.root)
    tmp = '    Transforms (if any): '
    fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    tmp = '    Target Transforms (if any): '
    fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
    return fmt_str
#当用如下代码实例化MNIST类之后,运用print函数就能调用MNIST类的__repr__(self)函数
d = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))]))
print(d)#此时就会调用函数__repr__(self)函数
#输出结果如下:
Dataset MNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ../data
    Transforms (if any): Compose(
                             ToTensor()
                             Normalize(mean=(0.1307,), std=(0.3081,))
                         )
    Target Transforms (if any): None

__getitem__(self,index)

  • 既然说到了__getitem__(self,index),也对其进行讲解。该方法也必须有返回return。
  • 该方法的调用方式是对象名后面加上中括号,中括号中填上索引值。
  • 下面以代码的方式进行展示:
#跟上面代码一样,我们实例化了自定义dataset,然后用索引来取值:
d = datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))]))
print(d[0])#此时就会调用函数__getitem__(self,index)函数
  • 打印的结果如下,里面包含了一张28×28大小的图像,就是MNIST数据集里面每张图像的大小;还有后面一个一维的tensor就是这张图像的标签索引值。
  • 从打印结果也能够看出,__getitem__(self,index)函数的主要功能就是获取数据和标签,并将其返回。
(tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值