Pytorch自定义网络层

Pytorch、Tensoflow等许多深度学习框架集成了大量常见的网络层,为我们搭建神经网络提供了诸多便利。但在实际工作中,因为项目要求、研究需要或者发论文需要等等,大家一般都会需要自己发明一个现在在深度学习框架中还不存在的层。 在这些情况下,就必须构建自定义层。

博主在学习了沐神的动手学深度学习这本书之后,学到了许多东西。这里记录一下书中基于Pytorch实现简单自定义网络层的方法,仅供参考。

一、不带参数的层

首先,我们构造一个没有任何参数的自定义层,要构建它,只需继承基础层类并实现前向传播功能。

import torch
import torch.nn.functional as F
from torch import nn


class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, X):
        return X - X.mean()

输入一些数据,验证一下网络是否能正常工作:

layer = CenteredLayer()
print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))

输出结果如下:

tensor([-2., -1.,  0.,  1.,  2.])

运行正常,表明网络没有问题。

现在将我们自建的网络层作为组件合并到更复杂的模型中,并输入数据进行验证:

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
print(Y.mean())  # 因为模型参数较多,输出也较多,所以这里输出Y的均值,验证模型可运行即可

结果如下:

tensor(-5.5879e-09, grad_fn=<MeanBackward0>)

二、带参数的层

这里使用内置函数来创建参数,这些函数可以提供一些基本的管理功能,使用更加方便。

这里实现了一个简单的自定义的全连接层,大家可根据需要自行修改即可。

class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))
    def forward(self, X):
        linear = torch.matmul(X, self.weight.data) + self.bias.data
        return F.relu(linear)

接下来实例化类并访问其模型参数:

linear = MyLinear(5, 3)
print(linear.weight)

结果如下:

Parameter containing:
tensor([[-0.3708,  1.2196,  1.3658],
        [ 0.4914, -0.2487, -0.9602],
        [ 1.8458,  0.3016, -0.3956],
        [ 0.0616, -0.3942,  1.6172],
        [ 0.7839,  0.6693, -0.8890]], requires_grad=True)

而后输入一些数据,查看模型输出结果:

print(linear(torch.rand(2, 5)))
# 结果如下
tensor([[1.2394, 0.0000, 0.0000],
        [1.3514, 0.0968, 0.6667]])

我们还可以使用自定义层构建模型,使用方法与使用内置的全连接层相同。

net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
print(net(torch.rand(2, 64)))
# 结果如下
tensor([[4.1416],
        [0.2567]])

三、总结

  • 我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。

  • 在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。

  • 层可以有局部参数,这些参数可以通过内置函数创建。

四、参考

《动手学深度学习》 — 动手学深度学习 2.0.0-beta0 documentationicon-default.png?t=M4ADhttps://zh-v2.d2l.ai/

  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
PyTorch自定义网络模型结构图可以通过使用PyTorch内置的`torchsummary`模块来生成。这个模块可以帮助我们快速地展示模型的参数数量、每一层的输出形状等重要信息。 以下是一个简单的示例,展示了如何使用`torchsummary`模块来生成自定义网络模型的结构图: ``` python import torch import torch.nn as nn from torchsummary import summary class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.fc1 = nn.Linear(64 * 8 * 8, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = nn.functional.relu(x) x = self.pool(x) x = self.conv2(x) x = nn.functional.relu(x) x = self.pool(x) x = x.view(-1, 64 * 8 * 8) x = self.fc1(x) x = nn.functional.relu(x) x = self.fc2(x) return x model = MyModel() summary(model, input_size=(3, 32, 32)) ``` 运行以上代码,就可以得到如下的输出: ``` ---------------------------------------------------------------- Layer (type) Output Shape Param # ================================================================ Conv2d-1 [-1, 32, 32, 32] 896 MaxPool2d-2 [-1, 32, 16, 16] 0 Conv2d-3 [-1, 64, 16, 16] 18,496 MaxPool2d-4 [-1, 64, 8, 8] 0 Linear-5 [-1, 128] 524,416 Linear-6 [-1, 10] 1,290 ================================================================ Total params: 545,098 Trainable params: 545,098 Non-trainable params: 0 ---------------------------------------------------------------- Input size (MB): 0.01 Forward/backward pass size (MB): 0.75 Params size (MB): 2.08 Estimated Total Size (MB): 2.85 ---------------------------------------------------------------- ``` 可以看到,`summary`函数生成了一个包含每一层输出形状、参数数量等信息的表格,以及估计的模型大小。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值