# 1. 建立LeNet5主干网络 (Build the LeNet5 backbone network)
import torch
from torch import nn
# 定义网络模型
class LeNet5(nn.Module):
    """LeNet-5 style CNN for small images (channel widths wider than the 1998 original).

    Layer stack:
        c1 conv 5x5 (pad 2) -> sigmoid -> s2 avgpool 2
        c3 conv 5x5         -> sigmoid -> s4 avgpool 2
        c5 conv 5x5         -> sigmoid -> flatten
        f6 linear           -> sigmoid -> output linear (logits)

    The layer sizes assume 28x28 inputs: the spatial size goes
    28 -> 28 (c1) -> 14 (s2) -> 10 (c3) -> 5 (s4) -> 1 (c5), so c5 emits a
    256x1x1 map that flattens to exactly ``f6.in_features`` (256). Other
    spatial sizes give a different flattened length and fail at f6 (or earlier).

    Args:
        in_channels: number of input image channels (default 1, grayscale).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        # Attribute names are kept unchanged so existing state_dicts still load.
        self.c1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=5, padding=2)
        # Single shared, parameter-free activation module reused after every layer.
        self.Sigmoid = nn.Sigmoid()
        self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5)
        self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5)
        self.flatten = nn.Flatten()
        self.f6 = nn.Linear(in_features=256, out_features=512)
        self.output = nn.Linear(in_features=512, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to unnormalized class scores (logits).

        Args:
            x: input batch, presumably shaped (N, in_channels, 28, 28);
                see the class docstring for why 28x28 is required.

        Returns:
            Tensor of shape (N, num_classes).
        """
        x = self.s2(self.Sigmoid(self.c1(x)))
        x = self.s4(self.Sigmoid(self.c3(x)))
        # The activations after c5 and f6 are essential. Without them
        # c5 -> f6 -> output is a chain of affine maps that collapses into a
        # single linear layer, and the extra depth contributes nothing.
        x = self.Sigmoid(self.c5(x))
        x = self.flatten(x)
        x = self.Sigmoid(self.f6(x))
        return self.output(x)
if __name__ == "__main__":
    # Smoke test: push one random single-channel 28x28 image through the
    # network and show the output shape, which should be one logit per class.
    model = LeNet5()
    model.eval()
    x = torch.rand([1, 1, 28, 28])
    # No training here, so skip building the autograd graph.
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 10])