【pytorch】自定义层

深度学习的一个魅力在于神经网络中各式各样的层,例如 全连接层、卷积层、池化层与循环层。虽然PyTorch提供了大量常用的层,但有时候我们依然希望自定义层。

这篇文章介绍如何使用Module来自定义层,从而可以被重复调用。


不含模型参数的自定义层

我们先介绍如何定义一个不含参数的自定义层。事实上,创建自定义层 与 使用 Module类 构造模型类似。

下面的 CenteredLayer 类通过继承 Module类 自定义了一个将输入减掉均值后输出的层,并将层的计算定义在了forward函数里。(这个层里不含模型参数)

import torch
from torch import nn

class CenteredLayer(nn.Module):
    def __init__(self):
        super(CenteredLayer, self).__init__()

    def forward(self, x):
        x -= torch.mean(x, dim=0)
        return x - x.mean()


net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
y = net(torch.rand(4, 8))
print(y.shape)   # torch.Size([4, 128])
print(y.mean().item())   # 1.2655618775170296e-09

含模型参数的自定义层

之前的文章已经介绍了模型参数。 Parameter类 其实是Tensor的子类,如果一个 Tensor 是 Parameter,那么它会自动被添加到模型的参数列表里。

所以在自定义含模型参数的层时,我们需要将参数定义成Parameter。
除了直接定义成Parameter类外,还可以使用 ParameterListParameterDict 分别定义参数的列表和字典。

1)ParameterList

ParameterList : 能接收一个 Parameter 实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用 append 和 extend 在列表后面新增参数。

import torch
from torch import nn

class MyDense(nn.Module):
    def __init__(self):
        super(MyDense, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4, 4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        for i in range(len(self.params)):
            x = torch.mm(x, self.params[i])
        return x
        
net = MyDense()
print(net)

# MyDense(
#   (params): ParameterList(
#       (0): Parameter containing: [torch.FloatTensor of size 4x4]
#       (1): Parameter containing: [torch.FloatTensor of size 4x4]
#       (2): Parameter containing: [torch.FloatTensor of size 4x4]
#       (3): Parameter containing: [torch.FloatTensor of size 4x1]
#   )
# )

2)ParameterDict

import torch
from torch import nn

class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
                'linear1': nn.Parameter(torch.randn(4, 4)),
                'linear2': nn.Parameter(torch.randn(4, 1))
        })
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增

    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

net = MyDictDense()
print(net)

# MyDictDense(
#   (params): ParameterDict(
#       (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
#       (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
#       (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
#   )
# )
  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Enzo 想砸电脑

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值