不带参数的层 (Layers without parameters)
Python 3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)]
Type 'copyright', 'credits' or 'license' for more information
IPython 7.22.0 -- An enhanced Interactive Python. Type '?' for help.
PyDev console: using IPython 7.22.0
Python 3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)] on win32
import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the mean of its input."""

    def __init__(self):
        super().__init__()

    def forward(self, X):
        # Shift the whole tensor so its (global) mean becomes zero.
        return X - X.mean()
# Sanity check: the layer subtracts the mean (here 3) from each element.
layer = CenteredLayer()
layer(torch.FloatTensor([1, 2, 3, 4, 5]))
Out[4]: tensor([-2., -1., 0., 1., 2.])
# The custom layer composes with built-in modules inside nn.Sequential;
# the output mean is ~0 up to float round-off.
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
Y.mean()
Out[7]: tensor(6.9849e-10, grad_fn=<MeanBackward0>)
带参数的层 (Layers with parameters)
class MyLinear(nn.Module):
    """Fully connected layer with a ReLU activation: relu(X @ weight + bias).

    Parameters
    ----------
    in_units : int
        Number of input features.
    units : int
        Number of output features.
    """

    def __init__(self, in_units, units):
        super().__init__()
        # nn.Parameter registers the tensors with the module, so they show
        # up in .parameters() and are updated by optimizers.
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))

    def forward(self, X):
        # Bug fix: use the parameters directly, not .weight.data/.bias.data.
        # Going through .data detaches the op from the autograd graph, so
        # the parameters would never receive gradients and the layer could
        # not be trained.
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)
# Instantiate the layer with 5 inputs and 3 outputs; weight is a registered
# nn.Parameter initialized from N(0, 1), hence requires_grad=True below.
linear = MyLinear(5, 3)
linear.weight
Out[9]:
Parameter containing:
tensor([[ 0.6766, -0.1331, -0.1324],
[ 0.1136, 1.8299, 0.5271],
[ 0.4939, -1.7403, -0.9208],
[ 0.2670, 0.4728, -0.4918],
[ 0.1707, -1.2526, -1.1154]], requires_grad=True)
自定义层 (Custom layers)
class MyLinear(nn.Module):
    """Custom fully connected layer computing relu(X @ weight + bias).

    Args:
        in_units: number of input features.
        units: number of output features.
    """

    def __init__(self, in_units, units):
        super().__init__()
        # Wrapping in nn.Parameter registers the tensors so optimizers and
        # .parameters() see them.
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))

    def forward(self, X):
        # Bug fix: do not index through .data — that detaches the matmul
        # from the autograd graph and blocks gradient flow to weight/bias,
        # making the layer untrainable.
        return F.relu(torch.matmul(X, self.weight) + self.bias)
# Same demo repeated: construct a 5-in / 3-out layer and inspect its
# randomly initialized weight parameter (requires_grad=True).
linear = MyLinear(5, 3)
linear.weight
Out[9]:
Parameter containing:
tensor([[ 0.6766, -0.1331, -0.1324],
[ 0.1136, 1.8299, 0.5271],
[ 0.4939, -1.7403, -0.9208],
[ 0.2670, 0.4728, -0.4918],
[ 0.1707, -1.2526, -1.1154]], requires_grad=True)
使用自定义层进行正向传播计算 (Forward computation with the custom layer)
# Forward pass on a random (2, 5) batch; ReLU zeroes negative entries,
# hence the exact 0.0000 values in the output.
linear(torch.rand(2, 5))
Out[10]:
tensor([[2.7365, 2.7208, 0.0000],
[2.6007, 4.7692, 0.2059]])
使用自定义层构建模型 (Building a model from custom layers)
# Custom layers stack inside nn.Sequential just like built-in ones:
# 64 -> 8 -> 1 features for a batch of 2 random inputs.
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))
Out[11]:
tensor([[6.2113],
[4.6728]])