以 CIFAR-10 为例（输入为 3×32×32 的彩色图像，共 10 个类别）：
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear


class Model(nn.Module):
    """Small CNN classifier for CIFAR-10-sized images.

    There are three stages of 5x5 convolution followed by 2x2 max-pooling.
    Padding 2 keeps the spatial size through each convolution, and each
    pooling halves it. A (N, 3, 32, 32) batch therefore becomes
    (N, 64, 4, 4). That is flattened to 1024 features, and two linear
    layers map it to 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(3, 32, 5, padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32, 32, 5, padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32, 64, 5, padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()  # flatten: (N, 64, 4, 4) -> (N, 1024)
        # 1024 = 64 channels * 4 * 4, which is what remains of a 32x32
        # input after three 2x poolings. A different input resolution
        # would need a different in_features here.
        self.Linear1 = Linear(1024, 64)
        self.Linear2 = Linear(64, 10)

    def forward(self, x):
        """Map an (N, 3, 32, 32) batch to an (N, 10) output tensor."""
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.Linear1(x)
        x = self.Linear2(x)
        return x


if __name__ == "__main__":
    model = Model()
    print(model)
    # Sanity check: a dummy batch of 64 images should give output shape (64, 10).
    dummy_input = torch.ones((64, 3, 32, 32))
    output = model(dummy_input)
    print(output.shape)
如果觉得在 `forward` 中逐层调用过于繁琐，可以用 `nn.Sequential()` 把这些层按顺序打包成一个模块，这样 `forward` 里只需调用一次即可。