import torch
import torch.nn as nn
import torch.nn.functional as F
class LinearFC(nn.Module):
def __init__(self):
super(DropoutFC, self).__init__()
self.fc = nn.Linear(3, 2)
def forward(self, input):
out = self.fc(input)
return out
Net = LinearFC()
x = torch.randint(10, (2, 3)).float() # 随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据
Net.train()
output = Net(x)
print(output)
# train the Net
创建了一个最简单的LinearFC
模型,里面有一个线性函数nn.Linear(3, 2)
,线性变换公式为:
y
=
x
W
T
+
b
y=x W^T + b
y=xWT+b。
通过Debug,一步一步查看运行情况:
当前这一步可以看到模型给我们随机初始化了权重 W 2 × 3 W_{2 \times 3} W2×3和偏置 b 2 × 3 b_{2 \times 3} b2×3,为什么权重 W W W的shape是 2 × 3 2\times3 2×3,因为公式里需要转置。
x
x
x随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据。
可以看出使用模型算出来的output,与手动使用公式算出来的结果一致。
Net.train()的作用
当网络中有 dropout,Batch Normalization 的时候。训练的要记得 Net.train(), 测试 要记得 Net.eval()。
在训练模型时会在前面加上:
Net.train()
在测试模型时在前面使用:
model.eval()
同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。