# 在 nn.Sequential 中嵌套 OrderedDict 组织网络，以对层进行命名
import torch
import torch.nn as nn
from collections import OrderedDict
class OrderedDictCNN(nn.Module):
    """Small CNN classifier whose layers are named through an ``OrderedDict``.

    The whole network is one ``nn.Sequential`` built from an ``OrderedDict``,
    so every layer can be reached by name (for example ``self.model.conv1``).

    Args:
        input_size: Spatial size of the expected input images, either an
            ``int`` (square input) or a ``(height, width)`` pair. It is only
            used to size ``fc1``. The default ``(896, 896)`` yields
            ``128 * 112 * 112`` input features for ``fc1``, which is the value
            that used to be hard-coded.
        num_classes: Number of output units of ``fc2``.

    Raises:
        ValueError: If ``input_size`` is too small to leave at least one
            pixel after the down-sampling layers.

    Note:
        With the default ``input_size`` the weight matrix of ``fc1`` has
        ``128 * 112 * 112 * 1000`` entries (about 1.6 billion). Pass a smaller
        ``input_size`` (e.g. ``(224, 224)``) for a lighter model.
    """

    def __init__(self, *, input_size=(896, 896), num_classes: int = 10):
        super().__init__()
        # Computed before any layer is created so parameter initialisation
        # happens in the same order as the layers are listed below.
        flat_features = self._flattened_features(input_size)
        self.model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3, out_channels=64,
                                kernel_size=7, stride=2, padding=3)),
            ('bn1', nn.BatchNorm2d(64)),
            ('relu1', nn.ReLU(inplace=True)),
            ('maxpool1', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
            ('conv2', nn.Conv2d(in_channels=64, out_channels=128,
                                kernel_size=3, stride=1, padding=1)),
            ('bn2', nn.BatchNorm2d(128)),
            ('relu2', nn.ReLU(inplace=True)),
            ('maxpool2', nn.MaxPool2d(kernel_size=2, stride=2, padding=0)),
            ('flatten', nn.Flatten()),
            ('fc1', nn.Linear(flat_features, 1000)),
            ('relu3', nn.ReLU(inplace=True)),
            ('fc2', nn.Linear(1000, num_classes)),
        ]))

    @staticmethod
    def _flattened_features(input_size) -> int:
        """Return the number of features ``flatten`` emits for ``input_size``."""
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size

        def out_size(size, kernel, stride, padding):
            # Output-size formula shared by Conv2d and MaxPool2d
            # (dilation 1, floor rounding).
            return (size + 2 * padding - kernel) // stride + 1

        def reduced(size):
            size = out_size(size, kernel=7, stride=2, padding=3)  # conv1
            size = out_size(size, kernel=3, stride=2, padding=1)  # maxpool1
            size = out_size(size, kernel=3, stride=1, padding=1)  # conv2
            return out_size(size, kernel=2, stride=2, padding=0)  # maxpool2

        out_h, out_w = reduced(height), reduced(width)
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"input_size {input_size!r} is too small for this network"
            )
        return 128 * out_h * out_w  # 128 = out_channels of conv2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input batch, expected to be shaped ``(N, 3, H, W)`` with
                ``(H, W)`` compatible with the ``input_size`` given at
                construction.

        Returns:
            Output of ``fc2``, shaped ``(N, num_classes)``.
        """
        return self.model(x)
# 使用多个 nn.Sequential 组织网络
import torch.nn as nn
class SimpleCNN(nn.Module):
    """Small CNN classifier split into three ``nn.Sequential`` stages.

    The stages are ``stem`` (conv + batch norm + ReLU), ``feature_extraction``
    (conv + batch norm + ReLU + max-pool) and ``fc`` (flatten + two linear
    layers), applied in that order by :meth:`forward`.

    Args:
        input_size: Spatial size of the expected input images, either an
            ``int`` (square input) or a ``(height, width)`` pair. It is only
            used to size the first linear layer of ``fc``. The default
            ``(448, 448)`` yields ``128 * 112 * 112`` input features, which is
            the value that used to be hard-coded.
        num_classes: Number of output units of the last linear layer.

    Raises:
        ValueError: If ``input_size`` is too small to leave at least one
            pixel after the down-sampling layers.

    Note:
        With the default ``input_size`` the first linear layer has
        ``128 * 112 * 112 * 1000`` weights (about 1.6 billion). Pass a smaller
        ``input_size`` (e.g. ``(224, 224)``) for a lighter model.
    """

    def __init__(self, *, input_size=(448, 448), num_classes: int = 10):
        super().__init__()
        # Computed before any layer is created so parameter initialisation
        # happens in the same order as the stages are defined below.
        flat_features = self._flattened_features(input_size)
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, num_classes),
        )

    @staticmethod
    def _flattened_features(input_size) -> int:
        """Return the number of features ``nn.Flatten`` emits for ``input_size``."""
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size

        def out_size(size, kernel, stride, padding):
            # Output-size formula shared by Conv2d and MaxPool2d
            # (dilation 1, floor rounding).
            return (size + 2 * padding - kernel) // stride + 1

        def reduced(size):
            size = out_size(size, kernel=7, stride=2, padding=3)  # stem conv
            size = out_size(size, kernel=3, stride=1, padding=1)  # second conv
            return out_size(size, kernel=2, stride=2, padding=0)  # max-pool

        out_h, out_w = reduced(height), reduced(width)
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"input_size {input_size!r} is too small for this network"
            )
        return 128 * out_h * out_w  # 128 = out_channels of the second conv

    def forward(self, x):
        """Run the three stages in sequence.

        Args:
            x: Input batch, expected to be shaped ``(N, 3, H, W)`` with
                ``(H, W)`` compatible with the ``input_size`` given at
                construction.

        Returns:
            Output of the last linear layer, shaped ``(N, num_classes)``.
        """
        features = self.feature_extraction(self.stem(x))
        return self.fc(features)
# 使用单个 nn.Sequential 组织网络
import torch
import torch.nn as nn
class SequentialCNN(nn.Module):
    """Small CNN classifier kept in a single, unnamed ``nn.Sequential``.

    All layers live in ``self.model`` and are addressed by position
    (``self.model[0]`` is the first convolution, ``self.model[11]`` the last
    linear layer).

    Args:
        input_size: Spatial size of the expected input images, either an
            ``int`` (square input) or a ``(height, width)`` pair. It is only
            used to size the first linear layer. The default ``(896, 896)``
            yields ``128 * 112 * 112`` input features, which is the value that
            used to be hard-coded.
        num_classes: Number of output units of the last linear layer.

    Raises:
        ValueError: If ``input_size`` is too small to leave at least one
            pixel after the down-sampling layers.

    Note:
        With the default ``input_size`` the first linear layer has
        ``128 * 112 * 112 * 1000`` weights (about 1.6 billion). Pass a smaller
        ``input_size`` (e.g. ``(224, 224)``) for a lighter model.
    """

    def __init__(self, *, input_size=(896, 896), num_classes: int = 10):
        super().__init__()
        # Computed before any layer is created so parameter initialisation
        # happens in the same order as the layers are listed below.
        flat_features = self._flattened_features(input_size)
        layers = [
            nn.Conv2d(in_channels=3, out_channels=64,
                      kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(in_channels=64, out_channels=128,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
            nn.Flatten(),
            nn.Linear(flat_features, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, num_classes),
        ]
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _flattened_features(input_size) -> int:
        """Return the number of features ``nn.Flatten`` emits for ``input_size``."""
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = input_size

        def out_size(size, kernel, stride, padding):
            # Output-size formula shared by Conv2d and MaxPool2d
            # (dilation 1, floor rounding).
            return (size + 2 * padding - kernel) // stride + 1

        def reduced(size):
            size = out_size(size, kernel=7, stride=2, padding=3)  # first conv
            size = out_size(size, kernel=3, stride=2, padding=1)  # first pool
            size = out_size(size, kernel=3, stride=1, padding=1)  # second conv
            return out_size(size, kernel=2, stride=2, padding=0)  # second pool

        out_h, out_w = reduced(height), reduced(width)
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                f"input_size {input_size!r} is too small for this network"
            )
        return 128 * out_h * out_w  # 128 = out_channels of the second conv

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on a batch of images.

        Args:
            x: Input batch, expected to be shaped ``(N, 3, H, W)`` with
                ``(H, W)`` compatible with the ``input_size`` given at
                construction.

        Returns:
            Output of the last linear layer, shaped ``(N, num_classes)``.
        """
        return self.model(x)