GoogLeNet-V3
开发环境
- python–3.7
- torch–1.8+cu101
- torchsummary
- torchvision–0.6.1+cu101
- PIL
- numpy
- opencv-python
- pillow
准备工作
GoogLeNet-V3预训练模型,下载地址:
https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
海洋渔业分类数据集,下载地址:
https://www.kaggle.com/c/the-nature-conservancy-fisheries-monitoring/data
项目代码结构
- data文件夹下存储了海洋渔业分类数据集、GoogLeNet预训练模型权重,以及用于测试demo的图片和imagenet配置文件。
- src文件夹下存储了GoogLeNet_V3推理demo程序和训练程序。
- tools存储通用文件:海洋渔业分类数据集构建文件和模型获取、标签平滑代码文件。
海洋渔业分类数据集构建程序
import os
import random
from PIL import Image
from torch.utils.data import Dataset
random.seed(1)  # module-level seed, applied once when this module is imported

# Class folder names of the dataset; a sample's label is the index of its
# parent folder's name in this list.
class_name = ["ALB", "BET", "DOL", "LAG", "NoF", "OTHER", "SHARK", "YFT"]


class NCFMDataSet(Dataset):
    """Dataset for the fish classification task.

    Every ``*.jpg`` file found (recursively) under ``data_dir`` is a sample;
    its label is the position of its parent folder's name in ``class_name``.
    The file list is sorted, shuffled with ``rng_seed`` and split into a
    training part (the first ``split_n`` fraction) and a validation part
    (the rest), so a "train" and a "valid" instance created with the same
    arguments never overlap.
    """

    def __init__(self, data_dir, mode="train", split_n=0.9, rng_seed=620, transform=None):
        """
        :param data_dir: root folder of the dataset (one sub-folder per class)
        :param mode: which part of the split to expose, "train" or "valid"
        :param split_n: fraction of the samples assigned to the training part
        :param rng_seed: seed used to shuffle the samples before splitting
        :param transform: optional callable applied to each PIL image
        :raises Exception: if ``mode`` is neither "train" nor "valid"
        :raises ValueError: if an image's parent folder is not in ``class_name``
        """
        self.mode = mode
        self.data_dir = data_dir
        self.rng_seed = rng_seed
        self.split_n = split_n
        # List of (image path, label) pairs; indexed by __getitem__.
        self.data_info = self._get_img_info()
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        path_img, label = self.data_info[index]
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(path_img) as img_file:
            img = img_file.convert('RGB')  # pixel values 0~255
        if self.transform is not None:
            img = self.transform(img)  # e.g. augmentation and conversion to tensor
        return img, label

    def __len__(self):
        """Return the number of samples; raise if the dataset is empty."""
        if len(self.data_info) == 0:
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))
        return len(self.data_info)

    def _get_img_info(self):
        """Collect, shuffle and split the samples; return ``[(path, label), ...]``."""
        # Sort first: os.walk order depends on the filesystem, and the seeded
        # shuffle below is only reproducible when it starts from a fixed order.
        img_path = sorted(
            os.path.join(root, name)
            for root, _dirs, files in os.walk(self.data_dir)
            for name in files
            if name.endswith(".jpg")
        )

        # A private generator produces the same order as seeding the global
        # one, without disturbing other users of the `random` module.
        random.Random(self.rng_seed).shuffle(img_path)

        img_labels = [class_name.index(os.path.basename(os.path.dirname(p))) for p in img_path]
        split_idx = int(len(img_labels) * self.split_n)

        if self.mode == "train":
            img_set = img_path[:split_idx]
            label_set = img_labels[:split_idx]
        elif self.mode == "valid":
            img_set = img_path[split_idx:]
            label_set = img_labels[split_idx:]
        else:
            raise Exception("self.mode 无法识别,仅支持(train, valid)")

        # The paths produced by os.walk already start with data_dir, so they
        # are used as they are (joining data_dir again breaks relative paths).
        return list(zip(img_set, label_set))
GoogLeNet_V3模型获取程序
def get_googlenet_v3(path_state_dict, device, vis_model=False):
    """
    Create an Inception-V3 model, optionally load weights, and move it to ``device``.

    The model is returned in eval mode; callers that train it must switch it
    back with ``model.train()``.

    :param path_state_dict: path of a state-dict file to load; a falsy value
        skips loading and keeps the freshly initialised weights
    :param device: device the model is moved to before it is returned
    :param vis_model: if True, print a layer summary (requires ``torchsummary``)
    :return: the model, in eval mode, on ``device``
    """
    # torchvision's own inception_v3(pretrained=True) turns `transform_input`
    # on for the Google checkpoint, which converts ImageNet-normalised inputs
    # to the normalisation those weights expect. Mirror that when loading weights.
    model = models.inception_v3(transform_input=bool(path_state_dict))

    if path_state_dict:
        # Load onto the CPU first so a checkpoint saved on a GPU machine can
        # still be read on a CPU-only one; model.to(device) below moves it.
        pretrained_state_dict = torch.load(path_state_dict, map_location="cpu")
        model.load_state_dict(pretrained_state_dict)
    model.eval()

    if vis_model:
        from torchsummary import summary
        summary(model, input_size=(3, 299, 299), device="cpu")

    model.to(device)
    return model
标签平滑程序
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with label smoothing.

    The loss mixes the usual negative log-likelihood of the target class
    (weight ``1 - eps``) with the mean negative log-probability over all
    classes and samples (weight ``eps``).
    """

    def __init__(self, eps=0.001):
        """
        :param eps: smoothing weight given to the uniform term
        """
        super().__init__()
        self.eps = eps

    def forward(self, x, target):
        """
        :param x: logits; softmax is taken over dim 1
        :param target: class indices, one per row of ``x``
        :return: scalar loss averaged over the batch
        """
        log_probs = f.log_softmax(x, dim=1)

        # Log-probability of the true class for every sample, then negated.
        true_class_log_prob = log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        nll_per_sample = -true_class_log_prob

        # Scalar: mean negative log-probability over every class and sample.
        uniform_term = -log_probs.mean()

        smoothed = (1 - self.eps) * nll_per_sample + self.eps * uniform_term
        return smoothed.mean()
GoogLeNet_V3模型推理demo程序
import os
import time
import json
import torch
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
from tools.common_tools import get_googlenet_v3
# Set the NLS_LANG environment variable for this process
# (presumably so that the Chinese class names display correctly -- TODO confirm it is needed).
os.environ['NLS_LANG'] = 'SIMPLIFIED CHINESE_CHINA.UTF8'
# Directory that contains this script; the data paths below are built relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def img_transform(img_rgb, transform=None):
    """
    Apply ``transform`` to an image so the model can consume it.

    :param img_rgb: the image to convert (a PIL image in this script)
    :param transform: callable applied to ``img_rgb``; mandatory despite the default
    :return: whatever ``transform`` returns (a tensor in this script)
    :raises ValueError: if ``transform`` is None
    """
    if transform is not None:
        return transform(img_rgb)
    raise ValueError("找不到transform!必须有transform对img进行处理")
def process_img(path_img):
    """
    Load an image file and turn it into a one-image batch for the model.

    :param path_img: path of the image file
    :return: ``(img_tensor, img_rgb)`` -- the normalised 1 x 3 x 299 x 299
        tensor on ``device`` and the untouched RGB PIL image (for display)
    """
    # ImageNet channel statistics used for normalisation.
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    inference_transform = transforms.Compose([
        # Resize before cropping (same 342 -> 299 ratio as the training
        # script); without it a large photo is judged from its centre patch only.
        transforms.Resize(342),
        transforms.CenterCrop((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    # path --> img; convert() returns a loaded copy, so the file can be closed.
    with Image.open(path_img) as img_file:
        img_rgb = img_file.convert('RGB')

    # img --> tensor
    img_tensor = img_transform(img_rgb, inference_transform)
    img_tensor = img_tensor.unsqueeze(0)  # chw --> bchw
    img_tensor = img_tensor.to(device)

    return img_tensor, img_rgb
def load_class_names(p_clsnames, p_clsnames_cn):
    """
    Load the class names in both languages.

    :param p_clsnames: path of a UTF-8 JSON file holding the class names
    :param p_clsnames_cn: path of a UTF-8 text file with one Chinese class
        name per line
    :return: ``(class_names, class_names_cn)`` -- the parsed JSON content and
        the list of lines with their line endings removed
    """
    # Be explicit about the encoding: the platform default is not always UTF-8.
    with open(p_clsnames, "r", encoding="utf-8") as f:
        class_names = json.load(f)
    with open(p_clsnames_cn, encoding='UTF-8') as f:
        # Drop the line terminator so a label does not carry a trailing newline;
        # empty lines are kept so that line number still equals class index.
        class_names_cn = [line.rstrip("\r\n") for line in f]
    return class_names, class_names_cn
if __name__ == "__main__":
    # Demo: classify one image with the pretrained model and show the top-5 labels.

    # config
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "inception_v3_google-1a9a5a14.pth")
    # alternative sample image: "Golden Retriever from baidu.jpg"
    path_img = os.path.join(BASE_DIR, "..", "data", "tiger cat.jpg")
    path_classnames = os.path.join(BASE_DIR, "..", "data", "imagenet1000.json")
    path_classnames_cn = os.path.join(BASE_DIR, "..", "data", "imagenet_classnames.txt")

    # load class names
    cls_n, cls_n_cn = load_class_names(path_classnames, path_classnames_cn)

    # 1/5 load img
    img_tensor, img_rgb = process_img(path_img)

    # 2/5 load model
    googlenet_v3_model = get_googlenet_v3(path_state_dict, device, True)

    # 3/5 inference  tensor --> vector
    with torch.no_grad():
        time_tic = time.time()
        outputs = googlenet_v3_model(img_tensor)
        time_toc = time.time()

    # 4/5 index to class names
    # .item()/.tolist() give plain Python ints; int() on a one-element NumPy
    # array is deprecated in recent NumPy releases.
    pred_idx = outputs.argmax(dim=1).item()
    top5_num = torch.topk(outputs, 5, dim=1).indices.squeeze(0).tolist()

    pred_str, pred_cn = cls_n[pred_idx], cls_n_cn[pred_idx]
    print("img: {} is: {}\n{}".format(os.path.basename(path_img), pred_str, pred_cn))
    print("time consuming:{:.2f}s".format(time_toc - time_tic))

    # 5/5 visualization
    plt.imshow(img_rgb)
    plt.title("predict:{}".format(pred_str))
    for rank, cls_idx in enumerate(top5_num, start=1):
        plt.text(5, 15 + (rank - 1) * 30, "top {}:{}".format(rank, cls_n[cls_idx]), bbox=dict(fc='yellow'))
    plt.show()
GoogLeNet_V3模型训练程序
import os
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.common_tools import get_googlenet_v3, LabelSmoothingCrossEntropy
from tools.my_dataset import NCFMDataSet
# Directory that contains this script; data paths are built relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if __name__ == "__main__":
    # Fine-tune Inception-V3 on the 8-class fish dataset and plot the loss curves.

    # config
    data_dir = os.path.join(BASE_DIR, "..", "..", "Data", "NCFM", "train")
    path_state_dict = os.path.join(BASE_DIR, "..", "data", "inception_v3_google-1a9a5a14.pth")
    # path_state_dict = False   # uncomment to train without pretrained weights
    num_classes = 8

    MAX_EPOCH = 3
    BATCH_SIZE = 8
    LR = 0.001
    log_interval = 1    # print training info every `log_interval` iterations
    val_interval = 1    # validate every `val_interval` epochs
    start_epoch = -1
    lr_decay_step = 1   # epochs between learning-rate decays

    # ============================ step 1/5 data ============================
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    train_transform = transforms.Compose([
        transforms.Resize(342),  # 299 / (224/256) = 342
        transforms.CenterCrop(299),
        # NOTE: the original pipeline followed this with RandomCrop(299), which
        # cannot move inside an image that is already 299x299 and was therefore
        # a no-op. To get random-crop augmentation, replace the CenterCrop
        # above with RandomCrop(299).
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    normalizes = transforms.Normalize(norm_mean, norm_std)
    # Validation uses ten crops per image; each sample becomes a stack of 10 tensors.
    valid_transform = transforms.Compose([
        transforms.Resize((342, 342)),
        transforms.TenCrop(299, vertical_flip=False),
        transforms.Lambda(lambda crops: torch.stack([normalizes(transforms.ToTensor()(crop)) for crop in crops])),
    ])

    # build the dataset instances
    train_data = NCFMDataSet(data_dir=data_dir, mode="train", transform=train_transform)
    valid_data = NCFMDataSet(data_dir=data_dir, mode="valid", transform=valid_transform)

    # build the data loaders
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, batch_size=4)

    # ============================ step 2/5 model ============================
    googlenet_v3_model = get_googlenet_v3(path_state_dict, device, False)

    # Replace both classification heads (main and auxiliary) with num_classes outputs.
    num_ftrs = googlenet_v3_model.fc.in_features
    googlenet_v3_model.fc = nn.Linear(num_ftrs, num_classes)
    num_ftrs_1 = googlenet_v3_model.AuxLogits.fc.in_features
    googlenet_v3_model.AuxLogits.fc = nn.Linear(num_ftrs_1, num_classes)

    googlenet_v3_model.to(device)

    # ============================ step 3/5 loss function ============================
    # criterion = nn.CrossEntropyLoss()
    criterion = LabelSmoothingCrossEntropy(eps=0.001)

    # ============================ step 4/5 optimizer ============================
    # flag = 1 trains the pretrained layers with a 10x smaller learning rate
    # than the freshly created heads; flag = 0 uses one rate for everything.
    flag = 0
    # flag = 1
    if flag:
        # Inception3 has no `.classifier`; its replaced heads are `fc` and `AuxLogits.fc`.
        head_params = (list(googlenet_v3_model.fc.parameters())
                       + list(googlenet_v3_model.AuxLogits.fc.parameters()))
        head_param_ids = {id(p) for p in head_params}  # identify parameters by object identity
        base_params = [p for p in googlenet_v3_model.parameters() if id(p) not in head_param_ids]
        optimizer = optim.SGD([
            {'params': base_params, 'lr': LR * 0.1},
            {'params': head_params, 'lr': LR}], momentum=0.9)
    else:
        optimizer = optim.SGD(googlenet_v3_model.parameters(), lr=LR, momentum=0.9)

    # learning-rate schedule: multiply by 0.1 every `lr_decay_step` epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(patience=5)

    # ============================ step 5/5 training ============================
    train_curve = list()
    valid_curve = list()

    for epoch in range(start_epoch + 1, MAX_EPOCH):

        loss_mean = 0.
        correct = 0.
        total = 0.

        googlenet_v3_model.train()
        for i, data in enumerate(train_loader):

            # 1. forward -- in train mode the model returns (main logits, auxiliary logits)
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = googlenet_v3_model(inputs)

            # 2. backward -- the auxiliary loss is added with weight 0.3
            optimizer.zero_grad()
            loss_main, aug_loss1 = criterion(outputs[0], labels), criterion(outputs[1], labels)
            loss = loss_main + (0.3 * aug_loss1)
            loss.backward()

            # 3. update weights
            optimizer.step()

            # classification statistics (main head only)
            predicted = outputs[0].argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # print training information
            loss_mean += loss_main.item()
            train_curve.append(loss_main.item())
            if (i + 1) % log_interval == 0:
                loss_mean = loss_mean / log_interval
                print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss_main:{:.4f} Acc:{:.2%} lr:{}".format(
                    epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total, scheduler.get_last_lr()))
                loss_mean = 0.

        scheduler.step()  # update the learning rate once per epoch

        # validate the model
        if (epoch + 1) % val_interval == 0:

            correct_val = 0.
            total_val = 0.
            loss_val = 0.
            googlenet_v3_model.eval()
            with torch.no_grad():
                for j, data in enumerate(valid_loader):
                    inputs, labels = data
                    inputs, labels = inputs.to(device), labels.to(device)

                    # Fold the ten crops into the batch dimension, then average
                    # the per-crop outputs of each image.
                    bs, ncrops, c, h, w = inputs.size()
                    outputs = googlenet_v3_model(inputs.view(-1, c, h, w))
                    outputs_avg = outputs.view(bs, ncrops, -1).mean(1)

                    loss = criterion(outputs_avg, labels)

                    predicted = outputs_avg.argmax(dim=1)
                    total_val += labels.size(0)
                    correct_val += (predicted == labels).sum().item()

                    loss_val += loss.item()

                loss_val_mean = loss_val / len(valid_loader)
                valid_curve.append(loss_val_mean)
                print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                    epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_mean, correct_val / total_val))
            googlenet_v3_model.train()

    train_x = range(len(train_curve))
    train_y = train_curve

    train_iters = len(train_loader)
    # valid_curve holds one value per validated epoch; convert those record
    # points to iteration numbers so both curves share the x axis.
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve

    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')

    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()
GoogLeNet_V3网络结构代码实现
import warnings
from collections import namedtuple
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# Return type of Inception3 when auxiliary logits are produced.
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}


class Inception3(nn.Module):
    """Inception-V3 network.

    In training mode with ``aux_logits`` enabled, ``forward`` returns an
    ``InceptionOutputs(logits, aux_logits)`` tuple; otherwise it returns the
    main logits tensor only. The shape comments in ``_forward`` assume an
    N x 3 x 299 x 299 input.
    """

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
                 inception_blocks=None, init_weights=True):
        """
        :param num_classes: number of outputs of the classifier(s)
        :param aux_logits: if True, build the auxiliary classifier branch
        :param transform_input: if True, re-normalise the input channels in
            ``_transform_input`` before the first convolution
        :param inception_blocks: optional list of exactly seven block classes,
            in the order conv block, InceptionA..E, InceptionAux
        :param init_weights: if True, initialise conv/linear weights from a
            truncated normal (needs SciPy) and batch-norm to weight 1 / bias 0
        """
        super(Inception3, self).__init__()
        if inception_blocks is None:
            inception_blocks = [
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux
            ]
        assert len(inception_blocks) == 7
        (conv_block, inception_a, inception_b, inception_c,
         inception_d, inception_e, inception_aux) = inception_blocks

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            # Imported here (once, not per module) so SciPy is only required
            # when random initialisation is actually requested.
            import scipy.stats as stats
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    # Modules may carry their own `stddev`; the default is 0.1.
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x):
        """Re-scale each of the three input channels when ``transform_input`` is set."""
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x):
        """Run the network; return ``(logits, aux)`` where ``aux`` may be None."""
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        # The auxiliary classifier only runs while training.
        aux_defined = self.training and self.aux_logits
        if aux_defined:
            aux = self.AuxLogits(x)
        else:
            aux = None
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x num_classes
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x, aux):
        # type: (Tensor, Optional[Tensor]) -> InceptionOutputs
        """Pick the eager-mode return value: a tuple while training with aux logits, else ``x``."""
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x

    def forward(self, x):
        """Apply the optional input transform, run the network and format the result."""
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)
class InceptionA(nn.Module):
    """Inception block with four parallel branches whose outputs are
    concatenated along the channel dimension: 1x1, 1x1 -> 5x5,
    1x1 -> 3x3 -> 3x3, and average-pool -> 1x1.

    Output channels: 64 + 64 + 96 + ``pool_features``; the spatial size is kept.
    """

    def __init__(self, in_channels, pool_features, conv_block=None):
        """
        :param in_channels: number of input channels
        :param pool_features: output channels of the pooling branch
        :param conv_block: factory for conv layers; defaults to BasicConv2d
        """
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # 1x1 branch
        self.branch1x1 = make_conv(in_channels, 64, kernel_size=1)

        # 5x5 branch
        self.branch5x5_1 = make_conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = make_conv(48, 64, kernel_size=5, padding=2)

        # double 3x3 branch
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, padding=1)

        # pooling branch
        self.branch_pool = make_conv(in_channels, pool_features, kernel_size=1)

    def _forward(self, x):
        """Return the four branch outputs as a list (not yet concatenated)."""
        out_1x1 = self.branch1x1(x)
        out_5x5 = self.branch5x5_2(self.branch5x5_1(x))
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)
        return [out_1x1, out_5x5, out_3x3dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionB(nn.Module):
    """Down-sampling Inception block: a stride-2 3x3 conv, a
    1x1 -> 3x3 -> stride-2 3x3 stack and a stride-2 max-pool run in parallel
    and are concatenated along the channel dimension.

    Output channels: 384 + 96 + ``in_channels``.
    """

    def __init__(self, in_channels, conv_block=None):
        """
        :param in_channels: number of input channels
        :param conv_block: factory for conv layers; defaults to BasicConv2d
        """
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # single strided 3x3 branch
        self.branch3x3 = make_conv(in_channels, 384, kernel_size=3, stride=2)

        # double 3x3 branch, last conv strided
        self.branch3x3dbl_1 = make_conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = make_conv(96, 96, kernel_size=3, stride=2)

    def _forward(self, x):
        """Return the three branch outputs as a list (not yet concatenated)."""
        out_3x3 = self.branch3x3(x)
        out_3x3dbl = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_3x3dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionC(nn.Module):
    """Inception block built from factorised 7x7 convolutions (1x7 and 7x1
    pairs). Four parallel branches are concatenated along the channel
    dimension; each contributes 192 channels and the spatial size is kept.
    """

    def __init__(self, in_channels, channels_7x7, conv_block=None):
        """
        :param in_channels: number of input channels
        :param channels_7x7: width of the intermediate layers in the 7x7 branches
        :param conv_block: factory for conv layers; defaults to BasicConv2d
        """
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block
        mid = channels_7x7

        # 1x1 branch
        self.branch1x1 = make_conv(in_channels, 192, kernel_size=1)

        # single factorised 7x7 branch
        self.branch7x7_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7_2 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = make_conv(mid, 192, kernel_size=(7, 1), padding=(3, 0))

        # double factorised 7x7 branch
        self.branch7x7dbl_1 = make_conv(in_channels, mid, kernel_size=1)
        self.branch7x7dbl_2 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = make_conv(mid, mid, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = make_conv(mid, mid, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = make_conv(mid, 192, kernel_size=(1, 7), padding=(0, 3))

        # pooling branch
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x):
        """Return the four branch outputs as a list (not yet concatenated)."""
        out_1x1 = self.branch1x1(x)

        out_7x7 = x
        for layer in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            out_7x7 = layer(out_7x7)

        out_7x7dbl = x
        for layer in (self.branch7x7dbl_1, self.branch7x7dbl_2, self.branch7x7dbl_3,
                      self.branch7x7dbl_4, self.branch7x7dbl_5):
            out_7x7dbl = layer(out_7x7dbl)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)

        return [out_1x1, out_7x7, out_7x7dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionD(nn.Module):
    """Down-sampling Inception block: a 1x1 -> stride-2 3x3 branch, a
    1x1 -> 1x7 -> 7x1 -> stride-2 3x3 branch and a stride-2 max-pool are
    concatenated along the channel dimension.

    Output channels: 320 + 192 + ``in_channels``.
    """

    def __init__(self, in_channels, conv_block=None):
        """
        :param in_channels: number of input channels
        :param conv_block: factory for conv layers; defaults to BasicConv2d
        """
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # 3x3 branch
        self.branch3x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = make_conv(192, 320, kernel_size=3, stride=2)

        # factorised 7x7 followed by 3x3 branch
        self.branch7x7x3_1 = make_conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = make_conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = make_conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = make_conv(192, 192, kernel_size=3, stride=2)

    def _forward(self, x):
        """Return the three branch outputs as a list (not yet concatenated)."""
        out_3x3 = self.branch3x3_2(self.branch3x3_1(x))

        out_7x7x3 = x
        for layer in (self.branch7x7x3_1, self.branch7x7x3_2,
                      self.branch7x7x3_3, self.branch7x7x3_4):
            out_7x7x3 = layer(out_7x7x3)

        out_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        return [out_3x3, out_7x7x3, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionE(nn.Module):
    """Inception block whose 3x3 branches end in a parallel 1x3 / 3x1 pair.

    Four branches are concatenated along the channel dimension:
    320 (1x1) + 768 (3x3) + 768 (double 3x3) + 192 (pool) = 2048 channels;
    the spatial size is kept.
    """

    def __init__(self, in_channels, conv_block=None):
        """
        :param in_channels: number of input channels
        :param conv_block: factory for conv layers; defaults to BasicConv2d
        """
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        # 1x1 branch
        self.branch1x1 = make_conv(in_channels, 320, kernel_size=1)

        # 3x3 branch, split into 1x3 and 3x1
        self.branch3x3_1 = make_conv(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # double 3x3 branch, second stage split into 1x3 and 3x1
        self.branch3x3dbl_1 = make_conv(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = make_conv(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = make_conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = make_conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        # pooling branch
        self.branch_pool = make_conv(in_channels, 192, kernel_size=1)

    def _forward(self, x):
        """Return the four branch outputs as a list (not yet concatenated)."""
        out_1x1 = self.branch1x1(x)

        stem_3x3 = self.branch3x3_1(x)
        out_3x3 = torch.cat(
            [self.branch3x3_2a(stem_3x3), self.branch3x3_2b(stem_3x3)], 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        out_3x3dbl = torch.cat(
            [self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)

        pooled = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        out_pool = self.branch_pool(pooled)

        return [out_1x1, out_3x3, out_3x3dbl, out_pool]

    def forward(self, x):
        """Concatenate the branch outputs along dim 1."""
        return torch.cat(self._forward(x), 1)
class InceptionAux(nn.Module):
    """Auxiliary classifier head: average-pool, 1x1 conv, 5x5 conv, global
    average-pool and a linear layer.

    ``conv1`` and ``fc`` are tagged with a ``stddev`` attribute, which
    Inception3 reads when it initialises weights.
    """

    def __init__(self, in_channels, num_classes, conv_block=None):
        """
        :param in_channels: number of input channels
        :param num_classes: number of outputs of the linear layer
        :param conv_block: factory for conv layers; defaults to BasicConv2d
        """
        super().__init__()
        make_conv = BasicConv2d if conv_block is None else conv_block

        self.conv0 = make_conv(in_channels, 128, kernel_size=1)

        self.conv1 = make_conv(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01

        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        """Return the auxiliary logits (shape comments assume a 17x17 input map)."""
        # N x 768 x 17 x 17
        pooled = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        features = self.conv1(self.conv0(pooled))
        # N x 768 x 1 x 1 -- global average pooling is a no-op at this size
        # but keeps the head working for larger feature maps
        features = F.adaptive_avg_pool2d(features, (1, 1))
        # N x 768
        flat = torch.flatten(features, 1)
        # N x num_classes
        return self.fc(flat)
class BasicConv2d(nn.Module):
    """Bias-free convolution followed by batch normalisation and ReLU."""

    def __init__(self, in_channels, out_channels, **kwargs):
        """
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param kwargs: forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...)
        """
        super().__init__()
        # No conv bias: the batch-norm layer that follows has its own shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        """Apply conv -> batch norm -> in-place ReLU."""
        normalised = self.bn(self.conv(x))
        return F.relu(normalised, inplace=True)