from torch import nn
import torch
from torch.nn import functional as F
class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutional layers followed by a
    three-layer fully-connected classifier.

    Expects input of shape (N, 3, 224, 224) and returns raw class
    logits of shape (N, num_class).
    """

    def __init__(self, num_class: int = 1000):
        super().__init__()
        # Input: 3 x 224 x 224
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),  # -> 96 x 54 x 54
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> 96 x 26 x 26
            nn.BatchNorm2d(96),
            nn.Conv2d(96, 256, kernel_size=5, padding=2),           # -> 256 x 26 x 26
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> 256 x 12 x 12
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 384, kernel_size=3, padding=1),          # -> 384 x 12 x 12
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),          # -> 384 x 12 x 12
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),          # -> 256 x 12 x 12
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # -> 256 x 5 x 5
            nn.Flatten(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(256 * 5 * 5, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            # Bug fix: the final layer was hard-coded to 1000 outputs, so the
            # num_class constructor argument was silently ignored. Use it here;
            # the default of 1000 keeps existing callers unchanged.
            nn.Linear(4096, num_class),
        )

    def forward(self, x):
        """Compute logits for a batch x of shape (N, 3, 224, 224)."""
        x = self.features(x)
        # Bug fix: the original forward had no return statement, so the
        # model always returned None. Return the classifier output.
        return self.classifier(x)
# Network structure (printed model repr, kept for reference):
# AlexNet( (features): Sequential( (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(4, 4)) (1): ReLU() (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (4): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (5): ReLU() (6): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (8): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (9): ReLU() (10): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (11): ReLU() (12): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (13): ReLU() (14): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False) (15): Flatten(start_dim=1, end_dim=-1) ) (classifier): Sequential( (0): Linear(in_features=6400, out_features=4096, bias=True) (1): ReLU() (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=4096, out_features=4096, bias=True) (4): ReLU() (5): Dropout(p=0.5, inplace=False) (6): Linear(in_features=4096, out_features=1000, bias=True) ) )