import torch.nn as nn class alexnet(nn.Module): def __init__(self): super(alexnet,self).__init__() self.conv1 = nn.Sequential( nn.Conv2d(1,96,kernel_size=3,stride=1,padding=1), nn.ReLU(inplace=True) ,nn.MaxPool2d(kernel_size=3,stride=2) ) self.conv2 = nn.Sequential( nn.Conv2d(96,256,kernel_size=3,stride=1,padding=1),nn.ReLU(inplace=True) ,nn.MaxPool2d(kernel_size=3,stride=2) ) self.conv3 = nn.Sequential( nn.Conv2d(256,384,kernel_size=3,stride=1,padding=1),nn.ReLU(inplace=True) ) self.conv4 = nn.Sequential( nn.Conv2d(384,384,kernel_size=3,stride=1,padding=1) ,nn.ReLU(inplace=True) ) self.conv5 = nn.Sequential( nn.Conv2d(384,256,kernel_size=3,stride=1,padding=1),nn.ReLU(inplace=True) ,nn.MaxPool2d(kernel_size=3,stride=2) ) self.fc1 = nn.Sequential( nn.Linear(256*2*2,50),nn.ReLU(inplace=True) ) self.fc2 = nn.Sequential( nn.Linear(50,50),nn.ReLU(inplace=True) ) self.fc3 = nn.Linear(50, 10) def forward(self,pic): pic = self.conv1(pic) pic = self.conv2(pic) pic = self.conv3(pic) pic = self.conv4(pic) pic = self.conv5(pic) pic = pic.view(pic.size()[0], -1) pic = self.fc1(pic) pic = self.fc2(pic) pic = self.fc3(pic) return pic
AlexNet
最新推荐文章于 2022-06-07 09:51:08 发布