本文以 MNIST 数据集为基础,使用 AlexNet 完成图像分类任务,其余代码详见 dl学习8,本章节只给出模型结构。注意:代码中的类名为 `Lenet`,但实际实现的是 AlexNet 结构;模型要求输入为 1×224×224 的单通道图像,而 MNIST 原图为 28×28,需先放大到 224×224。
class reshape_data(nn.Module):
    """Reshape a batch of (possibly flattened) images to ``(N, C, H, W)``.

    The batch dimension is inferred (``-1``), so any input whose element
    count is a multiple of ``channels * image_size * image_size`` is
    accepted; otherwise ``reshape`` raises ``RuntimeError``. The defaults
    produce the single-channel 224 x 224 layout that ``Lenet``'s first
    convolution (``in_channels=1``) and its ``nn.Linear(6400, ...)``
    classifier are sized for.

    Args:
        channels: number of channels in the output tensor (default 1).
        image_size: side length of the square output image (default 224).
    """

    def __init__(self, *, channels: int = 1, image_size: int = 224):
        super().__init__()
        self.channels = channels
        self.image_size = image_size

    def forward(self, X):
        # ``reshape`` rather than ``view``: ``view`` raises RuntimeError when
        # the input's strides are incompatible with the target shape (e.g.
        # after a transpose/permute), whereas ``reshape`` returns a view when
        # possible and copies only when it must.
        return X.reshape(-1, self.channels, self.image_size, self.image_size)
class Lenet(nn.Module):
    """AlexNet-style 10-class classifier for single-channel images.

    Despite the class name, the layer stack is AlexNet-shaped, not LeNet.
    ``forward`` first passes the input through ``reshape_data`` and then
    through ``self.net``. For a 1 x 224 x 224 input the spatial size evolves
    224 -> 54 -> 26 -> 26 -> 12 -> 12 -> 5, so the flattened feature vector
    has 256 * 5 * 5 = 6400 entries, matching the first ``nn.Linear``.
    The final layer emits 10 raw scores (no softmax is applied here).

    NOTE(review): classic AlexNet uses max pooling and 3x3 kernels in the
    last three convolutions; this model uses ``AvgPool2d`` and 5x5 kernels
    (padding 2, so spatial size is unchanged). Confirm this is intentional.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(c_in, c_out, k, **conv_kwargs):
            # One convolution followed by its activation.
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=k, **conv_kwargs),
                nn.ReLU(),
            ]

        def fc_relu_dropout(d_in, d_out):
            # One hidden fully-connected layer with activation and dropout.
            return [nn.Linear(d_in, d_out), nn.ReLU(), nn.Dropout(.5)]

        # Layers are created in the same order as a flat listing so that the
        # Sequential indices (net.0 ... net.20) and parameter init order match.
        stack = []
        stack += conv_relu(1, 96, 11, padding=1, stride=4)    # 224 -> 54
        stack.append(nn.AvgPool2d(3, 2))                      # 54 -> 26
        stack += conv_relu(96, 256, 5, padding=2)             # 26 -> 26
        stack.append(nn.AvgPool2d(3, 2))                      # 26 -> 12
        for c_in, c_out in ((256, 384), (384, 384), (384, 256)):
            stack += conv_relu(c_in, c_out, 5, padding=2)     # 12 -> 12
        stack.append(nn.AvgPool2d(3, 2))                      # 12 -> 5
        stack.append(nn.Flatten())                            # 256*5*5 = 6400
        stack += fc_relu_dropout(6400, 4096)
        stack += fc_relu_dropout(4096, 4096)
        stack.append(nn.Linear(4096, 10))

        self.net = nn.Sequential(*stack)
        self.reshape_data = reshape_data()

    def forward(self, X):
        return self.net(self.reshape_data(X))