PyTorch学习之nn.Module
import torch
import torch.nn as nn
import torch.nn.functional as F
官网的例子
class Mdeol(nn.Module):
    """Two stacked 5x5 convolutions, each followed by a ReLU.

    Adapted from the ``nn.Module`` example in the PyTorch documentation.
    ``conv1`` maps 1 input channel to 20, and ``cinv2`` maps 20 channels to
    20. Each unpadded 5x5 kernel shrinks both spatial dimensions by 4.
    """

    def __init__(self):
        super().__init__()
        # Assignment order matters: it fixes the order in which parameters are
        # registered and initialised (and so the RNG draws they consume).
        self.conv1 = nn.Conv2d(1, 20, 5)
        # NOTE: spelled `cinv2` (likely a typo for conv2). It is kept because
        # submodule attribute names become state_dict keys ("cinv2.weight").
        self.cinv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> cinv2 -> ReLU to ``x``."""
        hidden = F.relu(self.conv1(x))
        return F.relu(self.cinv2(hidden))
自己定义一个简单的模块（输出等于输入加 1）
class qingfeng(nn.Module):
    """Minimal custom module with no parameters: it adds 1 to its input."""

    def __init__(self):
        # Must run before any submodule/parameter assignment; this module has
        # none, but nn.Module still needs its internal registries set up.
        super(qingfeng, self).__init__()

    def forward(self, inputs):
        """Return ``inputs + 1``, computed element-wise for tensors."""
        return inputs + 1
对神经网络进行实例化
# Bind the instance to its own name so the class `qingfeng` stays available
# for creating further instances instead of being shadowed by this object.
model = qingfeng()
# All-float literal; the result is the same float32 tensor as mixing 4/5/6 ints.
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
# Calling the module goes through nn.Module.__call__, which invokes forward().
output = model(x)
print(output)
tensor([[2., 3., 4.],
[5., 6., 7.]])