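# Deep-learning network: a small CNN producing 10 outputs.
# Input images are 1 x 28 x 28 (channels x height x width); example batch_size = 64.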
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Sequential, Flatten, Linear
class Shen(nn.Module):
    """Small CNN: two conv + max-pool stages followed by two linear layers.

    With the default arguments the network takes input of shape
    ``(batch, 1, 28, 28)`` and returns raw (un-normalized) scores of shape
    ``(batch, 10)``. The defaults build exactly the same layers, in the same
    ``Sequential`` positions, as the original hard-coded version, so existing
    ``state_dict`` checkpoints still load.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output features of the final linear layer.
        image_size: Height and width of the (assumed square) input images.
            It is used only to size the first linear layer.

    Raises:
        ValueError: If ``image_size`` is smaller than 4. Two 2x max-pools
            would then reduce the feature map to 0 x 0.
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10,
                 image_size: int = 28):
        super().__init__()
        # Each MaxPool2d(2) floors the spatial size by half: s -> s//2 -> s//4.
        pooled = image_size // 4
        if pooled < 1:
            raise ValueError(
                f"image_size must be >= 4, got {image_size}")
        flat_features = 16 * pooled * pooled  # 16 * 7 * 7 = 784 for 28x28 input
        # NOTE(review): there is no non-linear activation (e.g. ReLU) between
        # layers, so the two Linear layers compose into a single linear map.
        # It is left unchanged because inserting modules would shift the
        # Sequential indices and break existing state_dict keys.
        self.model = Sequential(
            Conv2d(in_channels, 8, kernel_size=3, padding=1),  # 8 x s x s
            MaxPool2d(2),                                      # 8 x s/2 x s/2
            Conv2d(8, 16, kernel_size=5, padding=2),           # 16 x s/2 x s/2
            MaxPool2d(2),                                      # 16 x s/4 x s/4
            Flatten(),                                         # 16 * (s/4)^2
            Linear(flat_features, 16),
            Linear(16, num_classes),
        )

    def forward(self, x):
        """Run ``x`` (N, in_channels, image_size, image_size) through the model."""
        return self.model(x)
if __name__ == '__main__':
    # Smoke test: push a dummy batch of 64 single-channel 28x28 images
    # through the network and print the output shape.
    shen = Shen()
    dummy_input = torch.ones((64, 1, 28, 28))  # renamed: `input` shadowed the builtin
    with torch.no_grad():  # only the shape is inspected; no autograd graph needed
        output = shen(dummy_input)
    print(output.shape)