The __init__, __call__, __len__, __getitem__, __setitem__ and __delitem__ methods in a class

Contents

My own understanding:

__init__()

__call__()

__len__()

The sequence protocol:

__getitem__:

__setitem__:

__delitem__:

References from others:

Code corresponding to the "My own understanding" part


My own understanding:

__init__()

The code below uses classes. Here I explain two of the methods they define: __init__(self, output_size) and __call__(self, sample).

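Stepping through it in the debugger several times, I saw that scale = Rescale(256) assigns 256 to output_size. Constructing the instance only runs __init__() to initialize it; __call__() is not invoked at that point.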
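The same observation can be reproduced without a debugger. This is a minimal sketch; the class name TracedRescale is hypothetical and not part of the tutorial. It records which special method runs and when.

```python
class TracedRescale:
    """Hypothetical class that logs which special method has run."""

    def __init__(self, output_size):
        self.output_size = output_size
        self.events = ["__init__"]      # recorded once, at construction time

    def __call__(self, sample):
        self.events.append("__call__")  # recorded only when the instance is called
        return sample


scale = TracedRescale(256)
print(scale.events)        # ['__init__']  (construction does not call __call__)
scale({"image": None})     # calling the instance like a function
print(scale.events)        # ['__init__', '__call__']
```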

scale = Rescale(256)
crop = RandomCrop(128)

__call__()

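__call__() is invoked when the instance itself is called like a function. In the main script this happens explicitly in transformed_sample = tsfrm(sample), and implicitly in the loop below: iterating the DataLoader with enumerate makes it fetch samples via transformed_dataset[i], the dataset's __getitem__ applies its transform to each sample (the tutorial's FaceLandmarksDataset does this with self.transform(sample)), and transforms.Compose.__call__ then calls the __call__() of Rescale, RandomCrop and ToTensor in turn. Note that sample_batched is the resulting batch (a dict of tensors), not an instance of Rescale or RandomCrop.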

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

So a class instance can also become a callable object, just like a function, as long as its class defines a __call__() method.
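A short illustration of that point, using a hypothetical Multiplier class: the built-in callable() reports True for such an instance, obj(x) is equivalent to obj.__call__(x), and the instance can be passed anywhere a function is expected, for example to map().

```python
class Multiplier:
    """Hypothetical example: a callable object that remembers its factor."""

    def __init__(self, factor):
        self.factor = factor          # configuration stored once, in __init__

    def __call__(self, x):
        return x * self.factor        # runs every time the instance is called


double = Multiplier(2)
print(callable(double))               # True, because the class defines __call__
print(double(21))                     # 42, same as double.__call__(21)
print(list(map(double, [1, 2, 3])))   # [2, 4, 6]
```

This is the same pattern the transform classes use: parameters such as output_size go into __init__, and the per-sample work goes into __call__.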

__len__()

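The built-in len() function returns the "length" of an object instance.

Note: for len() to work, the class must provide the special method __len__(), which returns the number of elements as a non-negative integer.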

class Students(object):

    def __init__(self, *args):
        self.names = args

    def __len__(self):
        return len(self.names)

s = Students('tom', 'jack')
print( len(s) )
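The example above prints 2. As a small extension (the class NoLen is hypothetical), the sketch below also shows two related behaviors of the built-ins: truth testing falls back to __len__() when a class defines no __bool__(), and calling len() on an object without __len__() raises TypeError.

```python
class Students(object):
    def __init__(self, *args):
        self.names = args

    def __len__(self):
        return len(self.names)


class NoLen(object):
    """Hypothetical class without __len__."""
    pass


s = Students('tom', 'jack')
print(len(s))              # 2
print(bool(Students()))    # False: no __bool__, so bool() uses __len__() == 0

raised = False
try:
    len(NoLen())           # no __len__ defined
except TypeError:
    raised = True
print(raised)              # True
```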

The sequence protocol

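An object whose class defines __getitem__ is treated as a sequence: it supports obj[index], and even without __iter__ it can be iterated with a for loop, because Python then calls __getitem__(0), __getitem__(1), ... until IndexError is raised.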
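A minimal sketch of that fallback, using a hypothetical Squares class that defines only __getitem__ (no __iter__ and no __len__). Raising IndexError is what tells the for loop, list() and the in operator that the sequence has ended.

```python
class Squares:
    """Hypothetical sequence: defines only __getitem__."""

    def __getitem__(self, index):
        if index >= 5:
            raise IndexError(index)   # signals the end of the sequence
        return index * index


sq = Squares()
print(sq[3])         # 9, plain indexing
print(list(sq))      # [0, 1, 4, 9, 16], iteration via __getitem__(0), (1), ...
print(9 in sq)       # True, "in" also falls back to this iteration
print(7 in sq)       # False, the search stops when IndexError is raised
```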


__getitem__

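If a class stores a sequence in one of its attributes, __getitem__ lets you read an element of that sequence attribute with obj[index]. An instance of a class that defines __getitem__() is called a custom sequence object.

In the code below, the two for loops (one over test, one over the plain list names) print the same output, which confirms the statement above: an object that defines __getitem__ is treated as a sequence.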

__setitem__

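If a class stores a sequence in one of its attributes, __setitem__ lets you modify an element of that sequence attribute with obj[key] = value.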


__delitem__

If a class stores a sequence in one of its attributes, __delitem__ lets you delete an element of that sequence attribute with del obj[key].

class Test:
    def __init__(self, names):
        self.names = names

    def __getitem__(self, item):
        return self.names[item]

    def __setitem__(self, key, value):
        self.names[key] = value

    def __delitem__(self, key):
        del self.names[key]


names = ["zhangsan", "lisi", "wangwu", "zhaoliu"]
test = Test(names)
print("++++++++++++++++++++")
for name in test:
    print(name)
for name in names:
    print(name)
print("++++++++++++++++++++")
print(test[1])
print(test[1:3])
test[1] = "test"
print(test.names)
del test[1]
print(test.names)
print()


"""
Output:
++++++++++++++++++++
zhangsan
lisi
wangwu
zhaoliu
++++++++++++++++++++
lisi
['lisi', 'wangwu']
['zhangsan', 'test', 'wangwu', 'zhaoliu']
['zhangsan', 'wangwu', 'zhaoliu']


"""

References from others:

Link: https://www.cnblogs.com/lyu454978790/p/8630215.html

Specifically, see this:

>>> class Reader():
...     def __init__(self, name, nationality):
...         self.name = name
...         self.nationality = nationality
...     def __call__(self):
...         print('Reader: %s    Nationality: %s' % (self.name, self.nationality))
...
>>> r = Reader('Annie', 'Chinese')
>>> r()
Reader: Annie    Nationality: Chinese

The __call__() method can also take arguments.

Define a Reader class whose instances can be called directly and which also counts the number of readers:

>>> class Reader():
...     count = 0
...     def __init__(self, name, nationality):
...         self.name = name
...         self.nationality = nationality
...         Reader.count += 1
...     def __call__(self, behave):
...         print('Reader: %s' % self.name)
...         print('Nationality: %s' % self.nationality)
...         print('%s is being %s.' % (self.name, behave))
...         print('The total number of readers is %s.' % Reader.count)
...
>>> a = Reader('Annie', 'Chinese')
>>> a('Nice')
Reader: Annie
Nationality: Chinese
Annie is being Nice.
The total number of readers is 1.
>>> b = Reader('Adam', 'American')
>>> b('Silly')
Reader: Adam
Nationality: American
Adam is being Silly.
The total number of readers is 2.      # incremented automatically
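In the Reader example the counter is a class attribute updated in __init__, so it counts instances. A callable object can also keep per-instance state that changes on every call. The sketch below uses a hypothetical CallCounter class that wraps a function and counts how often it has been called, something a plain function cannot do without a global variable.

```python
class CallCounter:
    """Hypothetical example: wraps a function and counts its calls."""

    def __init__(self, func):
        self.func = func
        self.calls = 0                 # per-instance state kept between calls

    def __call__(self, *args, **kwargs):
        self.calls += 1                # updated on every call, not on construction
        return self.func(*args, **kwargs)


counted_abs = CallCounter(abs)
print(counted_abs(-3))     # 3
print(counted_abs(-5))     # 5
print(counted_abs.calls)   # 2
```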

Code source: https://pytorch.apachecn.org/docs/1.4/5.html

Code corresponding to the "My own understanding" part

Main script

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from Rescale import Rescale,RandomCrop,ToTensor
# Ignore warnings
import warnings

from FaceLandmarksDataset import FaceLandmarksDataset

warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
landmarks_frame = pd.read_csv('./data/faces/face_landmarks.csv')

n = 65
img_name = landmarks_frame.iloc[n, 0]
landmarks = landmarks_frame.iloc[n, 1:]
landmarks = np.asarray(landmarks)
landmarks = landmarks.astype('float').reshape(-1, 2)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.001)  # pause a bit so that plots are updated

plt.figure()
show_landmarks(io.imread(os.path.join('data/faces/', img_name)),
               landmarks)
plt.show()

face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                    root_dir='data/faces/')

fig = plt.figure()

for i in range(len(face_dataset)):
    sample = face_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = face_dataset[65]
for i, tsfrm in enumerate([scale, crop, composed]):
    transformed_sample = tsfrm(sample)

    ax = plt.subplot(1, 3, i + 1)
    plt.tight_layout()
    ax.set_title(type(tsfrm).__name__)
    show_landmarks(**transformed_sample)

plt.show()

transformed_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',
                                           root_dir='data/faces/',
                                           transform=transforms.Compose([
                                               Rescale(256),
                                               RandomCrop(224),
                                               ToTensor()
                                           ]))

for i in range(len(transformed_dataset)):
    sample = transformed_dataset[i]

    print(i, sample['image'].size(), sample['landmarks'].size())

    if i == 3:
        break
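# Note: on Windows, num_workers > 0 starts worker processes, so the code that
# iterates the DataLoader must run under if __name__ == '__main__':
dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)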

# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
    batch_size = len(images_batch)
    im_size = images_batch.size(2)
    grid_border_size = 2

    grid = utils.make_grid(images_batch)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))

    for i in range(batch_size):
        plt.scatter(landmarks_batch[i, :, 0].numpy() + i * im_size + (i + 1) * grid_border_size,
                    landmarks_batch[i, :, 1].numpy() + grid_border_size,
                    s=10, marker='.', c='r')

        plt.title('Batch from dataloader')

for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size(),
          sample_batched['landmarks'].size())

    # observe 4th batch and stop.
    if i_batch == 3:
        plt.figure()
        show_landmarks_batch(sample_batched)
        plt.axis('off')
        plt.ioff()
        plt.show()
        break

Transform classes (imported by the main script from Rescale.py)

import torch
from skimage import io, transform

import numpy as np

class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}

class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

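        # np.random.randint's upper bound is exclusive, so +1 also allows the
        # last valid offset and avoids an error when h == new_h or w == new_w
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)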

        image = image[top: top + new_h,
                      left: left + new_w]

        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
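To see the whole call chain without PyTorch, here is a simplified sketch. It assumes, as in the PyTorch tutorial, that the dataset's __getitem__ applies self.transform(sample); the DataLoader fetches each sample of a map-style dataset with dataset[idx]. Compose below is a simplified stand-in written for illustration, not the real torchvision implementation, and AddOne, Double and ToyDataset are hypothetical.

```python
class Compose:
    """Simplified stand-in for torchvision.transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)        # each step invokes that transform's __call__
        return sample


class AddOne:
    def __call__(self, sample):
        return {"value": sample["value"] + 1}


class Double:
    def __call__(self, sample):
        return {"value": sample["value"] * 2}


class ToyDataset:
    """Hypothetical dataset following the same __len__/__getitem__ pattern."""

    def __init__(self, values, transform=None):
        self.values = values
        self.transform = transform

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        sample = {"value": self.values[idx]}
        if self.transform:
            # Compose.__call__ -> AddOne.__call__ -> Double.__call__
            sample = self.transform(sample)
        return sample


ds = ToyDataset([1, 2, 3], transform=Compose([AddOne(), Double()]))
print(len(ds))                                    # 3
print([ds[i]["value"] for i in range(len(ds))])   # [4, 6, 8]
```

Construction (Compose([...]), ToyDataset(...)) only runs __init__; the __call__ methods run each time a sample is indexed.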