类的特殊方法汇总(__init__, forward, __getitem__)

在python中定义的类中,往往有__init__()方法,与模型有关的类还会有forward()方法,以及其他特殊方法,如__getitem__()方法,这些方法什么时候会被执行呢?我们通过下面的例子说明:

__init__()初始化方法和forward()方法
import torch
import torch.nn as nn


class SimpleNet(nn.Module):
    # input_dim, hidden_dim, output_dim = 10, 20, 5
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(SimpleNet, self).__init__()
        # 输入层:使用nn.Linear()pytorch框架中的全连接层
        self.fc1 = nn.Linear(input_dim, hidden_dim)  
        self.relu = nn.ReLU()  # 激活层:nn.ReLU()使用Relu激活函数
        self.fc2 = nn.Linear(hidden_dim, output_dim)  # 输出层:使用nn.Linear()全连接层,

    def forward(self, x):
        print(f"输入张量x.shape:{x.shape}")  # torch.Size([1, 10])
        out = self.fc1(x)
        print(f"经过输入全连接层后的shape为:{out.shape}")  # torch.Size([1, 20])
        out = self.relu(out)
        print(f"经过激活层的shape变为:{out.shape}")  # torch.Size([1, 20]) 激活层不改变shape
        out = self.fc2(out)
        print(f"经过输出全连接层的shape变为:{out.shape}")  # torch.Size([1, 5]) 这是由SimpleNet类中的默认方法__init__()定义的
        return out


if __name__ == '__main__':
    # SimpleNet类的实例化
    net = SimpleNet(10, 20, 5)

    # 构造一个随机的输入张量,大小为 [batch_size, input_dim],这里令 batch_size=1
    input_tensor = torch.randn(1, 10)

    # 调用类的实例: 将创建的输入张量传入网络中,得到输出张量
    output_tensor = net(input_tensor)

    # 打印输出张量的形状
    print(output_tensor.shape)  # torch.Size([1, 5])

    # 打印模型的参数量
    print(
        "Model has {} parameters".format(sum(y.numel() for y in net.parameters()))
    )

在这个例子中,我们可以看到:

__init__()初始化方法在类的实例化时自动执行;

forward()方法在调用类的实例时自动执行;

那__getitem__()怎么用呢?

我们在用 for…in… 迭代对象时,如果对象没有实现 iter next 迭代器协议,Python的解释器就会去寻找__getitem__ 来迭代对象,如果连__getitem__ 都没有定义,这解释器就会报对象不是迭代器的错误:

TypeError: 'Animal' object is not iterable

而__getitem__()方法可以让对象实现迭代功能,这样就可以使用for…in… 来迭代对象了,并且可以定义迭代对象的数据元组组成:

class AISTPPDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        ######
    ):
     ########################### 其他代码

    def __getitem__(self, idx):
        filename_ = self.data["filenames"][idx]
        feature = torch.from_numpy(np.load(filename_))
        return self.data["pose"][idx], feature, filename_, self.data["wavs"][idx] 

train_dataset = AISTPPDataset(
                data_path=opt.data_path,
                backup_path=opt.processed_data_dir,
                train=True,
                force_reload=opt.force_reload,
            )
train_data_loader = DataLoader(
            train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=min(int(num_cpus * 0.75), 32),
            pin_memory=True,
            drop_last=True,
        )
for step, (x, cond, filename, wavnames) in enumerate(
   train_data_loader
           ):  # 这里迭代训练数据的数据元组定义与类aist++的__getitem__()相对应!

参考博客:Python.__getitem__方法-CSDN博客

  • 11
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值