# Level 1: GoogLeNet
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicConv2d(nn.Module):
    """Basic convolution block: Conv2d (bias-free) -> BatchNorm2d -> ReLU.

    The conv carries no bias because the following BatchNorm has its own
    learnable shift, which would make a conv bias redundant.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        **kwargs: forwarded to nn.Conv2d (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        # eps=0.001 matches the torchvision GoogLeNet implementation.
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        # NOTE: requires `import torch.nn.functional as F` at module top;
        # the original file used F without importing it (NameError at runtime).
        return F.relu(x, inplace=True)
class Inception(nn.Module):
    """GoogLeNet Inception module: four parallel branches, channel-concatenated.

    Output channel count is ch1x1 + ch3x3 + ch5x5 + pool_proj; the spatial
    size is preserved by every branch (stride 1, matching padding).

    Args:
        in_channels: channels of the input feature map (e.g. 192 for 3a).
        ch1x1: output channels of the 1x1 branch.
        ch3x3red, ch3x3: 1x1 reduction / 3x3 output channels of branch 2.
        ch5x5red, ch5x5: 1x1 reduction / output channels of branch 3.
        pool_proj: output channels of the 1x1 projection after max-pool.
        conv_block: conv+BN block factory; defaults to BasicConv2d.
    """

    __constants__ = ['branch2', 'branch3', 'branch4']

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj,
                 conv_block=None):
        super(Inception, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        # BUG FIX: the original ignored `conv_block` and hard-coded BasicConv2d
        # in every branch; all branches now honor the injected factory.
        # Branch 1: plain 1x1 conv.
        self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)
        # Branch 2: 1x1 reduction followed by 3x3 conv (padding keeps size).
        self.branch2 = nn.Sequential(
            conv_block(in_channels, ch3x3red, kernel_size=1),
            conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1),
        )
        # Branch 3: 1x1 reduction followed by a 3x3 conv. The paper says 5x5
        # here, but torchvision deliberately uses 3x3 (known quirk, kept for
        # pretrained-checkpoint compatibility) — preserved on purpose.
        self.branch3 = nn.Sequential(
            conv_block(in_channels, ch5x5red, kernel_size=1),
            conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1),
        )
        # Branch 4: 3x3 max-pool (stride 1, size-preserving) then 1x1 projection.
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            conv_block(in_channels, pool_proj, kernel_size=1),
        )

    def _forward(self, x):
        # Run the four branches independently on the same input.
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        return [branch1, branch2, branch3, branch4]

    def forward(self, x):
        outputs = self._forward(x)
        # Concatenate along the channel dimension:
        # ch1x1 + ch3x3 + ch5x5 + pool_proj channels out.
        return torch.cat(outputs, 1)
# Build an Inception-3a-style module (192 input channels) and print its layout.
# BUG FIX: the original assigned the instance to the name `Inception`,
# shadowing and destroying the class binding for any later code.
inception_3a = Inception(192, 64, 96, 128, 16, 32, 32)
print(inception_3a)