import torch
import torch.nn as nn
# Example: MNIST, input image shape (1, 28, 28) as (C, H, W).
class ConvBNReLu(nn.Module):
    """3x3 convolution followed by BatchNorm and ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size (H, W)
    of the input is preserved; only the channel count changes.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels (also the BatchNorm feature count).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # nn.Sequential takes modules as positional arguments, not a list.
        # Layer order (0: conv, 1: bn, 2: relu) is kept so state_dict keys
        # remain "conv.0.*" / "conv.1.*".
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply conv -> BN -> ReLU to a 4-D batch (N, in_ch, H, W)."""
        return self.conv(x)
class Net(nn.Module):
    """Small CNN classifier, by default sized for MNIST-style 1x28x28 images.

    Architecture: three ConvBNReLu stages (in_channels -> 10 -> 20 -> 40
    channels), 2x2 max-pooling after the first two stages, global average
    pooling, and a linear classification head.

    Args:
        in_channels: Number of input image channels (default 1).
        num_classes: Number of output classes (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = ConvBNReLu(in_channels, 10)
        self.conv2 = ConvBNReLu(10, 20)
        self.conv3 = ConvBNReLu(20, 40)
        # Parameter-free, so one instance is reused after conv1 and conv2;
        # each application halves H and W.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Pool to 1x1 so the head's input size is 40 regardless of H and W.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.cls_head = nn.Linear(40, num_classes)

    def forward(self, x):
        """Compute unnormalized class scores.

        Args:
            x: Image batch of shape (N, in_channels, H, W).

        Returns:
            Logits tensor of shape (N, num_classes).
        """
        x = self.maxpool(self.conv1(x))
        x = self.maxpool(self.conv2(x))
        x = self.conv3(x)
        # Tensor.flatten's keyword is start_dim (there is no `dim`); keep the
        # batch axis and collapse (40, 1, 1) -> 40.
        x = self.avgpool(x).flatten(start_dim=1)
        logits = self.cls_head(x)
        return logits
# Source: blog post "Hand-Written Algorithms Series: A Simple Neural Network"
# (【手撕算法系列】简单神经网络).