nn.Module
torch.nn是专门为深度学习所设计的模块,torch.nn的核心数据结构是nn.Module,它是一个抽象的概念,既可以表示神经网络中的某个层,又可以表示某个含有很多层的网络。在实际使用中最常用的做法是继承nn.Module,撰写自己的网络层
#现在先来实现一下多层感知机
from torch import nn
import torch as t
class Perception(nn.Module):
def __init__(self,in_feature,hidden_feature,out_feature):
nn.Module.__init__(self)
self.layer1 = nn.Linear(in_feature,hidden_feature)
self.layer2 = nn.Linear(hidden_feature,out_feature)
def forward(self,x):
x = self.layer1(x)
x = t.sigmoid(x)
x = self.layer2(x)
return x
percep = Perception(3,4,1)
inputs = t.randn(1,3)
outputs = percep(inputs)
print(outputs)
for name,param in percep.named_parameters():
print(name,param)
output:
tensor([[0.6794]], grad_fn=)
layer1.weight Parameter containing:
tensor([[-0.2399, 0.3114, -0.4838],
[-0.2509, 0.5123, 0.2901],