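GoogLeNet Network

Contents

1. Innovations

1.1 Introducing the Inception module

1.2 Dimensionality reduction with 1×1 convolutions

1.3 Two auxiliary classifiers

1.4 Dropping fully connected layers in favor of average pooling

2. Network structure

3. Key points

3.1 torch.cat

3.2 About self.training

3.3 About strict in load_state_dict

4. Code

4.1 model.py

4.2 train.py

4.3 predict.py

5. Results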


1. Innovations

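1.1 Introducing the Inception module

Purpose: fuse feature information from several receptive-field sizes (1×1, 3×3 and 5×5 convolutions plus 3×3 max pooling) computed in parallel.

Note: the feature maps produced by every branch must have the same height and width, because they are concatenated along the channel dimension.

The figure below is from: Going deeper with convolutions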
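The "same height and width" rule is satisfied in model.py by choosing the padding: with stride 1, a k×k window padded by (k-1)/2 leaves H and W unchanged. The pure-Python check below assumes the usual floor-mode output-size formula; the helper name out_size is made up for illustration, and the kernel/padding pairs are copied from the Inception class in section 4.1.

```python
def out_size(h, kernel, stride=1, padding=0):
    # Floor-mode output size of a convolution or pooling window along one dimension
    return (h + 2 * padding - kernel) // stride + 1

# (kernel_size, padding) of the spatial operation in each branch of Inception in model.py
branches = {
    "branch1 (1x1 conv)": (1, 0),
    "branch2 (3x3 conv)": (3, 1),
    "branch3 (5x5 conv)": (5, 2),
    "branch4 (3x3 max pool)": (3, 1),
}

for h in (28, 14, 7):  # spatial sizes at which the Inception blocks run for a 224x224 input
    sizes = {name: out_size(h, k, stride=1, padding=p) for name, (k, p) in branches.items()}
    print(h, sizes)    # every branch returns h unchanged
```

The 1×1 reduction convolutions in front of the 3×3 and 5×5 branches use padding 0 and therefore also keep H and W.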

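1.2 Dimensionality reduction with 1×1 convolutions

Input channels: 512 (biases are ignored in the counts below)

a. Without 1×1 dimensionality reduction

Use: 64 kernels of size 5×5

Parameters: 5×5×512×64 = 819,200

b. With 1×1 dimensionality reduction

Use: 24 kernels of size 1×1 (512 → 24 channels), followed by 64 kernels of size 5×5

Parameters: 1×1×512×24 + 5×5×24×64 = 12,288 + 38,400 = 50,688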
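The two parameter counts can be reproduced with a few lines of arithmetic. This is a minimal sketch that counts only convolution weights (bias=False by default); the function name conv_params is hypothetical.

```python
def conv_params(k, c_in, c_out, bias=False):
    """Number of weights in a k x k convolution from c_in to c_out channels."""
    return k * k * c_in * c_out + (c_out if bias else 0)

c_in = 512
# a. 64 kernels of size 5x5 applied directly to the 512-channel input
direct = conv_params(5, c_in, 64)
# b. 24 kernels of size 1x1 (512 -> 24), then 64 kernels of size 5x5 (24 -> 64)
reduced = conv_params(1, c_in, 24) + conv_params(5, 24, 64)

print(direct, reduced)             # 819200 50688
print(round(direct / reduced, 1))  # 16.2
```

Both variants output 64 channels, yet the reduced version needs roughly 16 times fewer weights.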

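1.3 Two auxiliary classifiers

GoogLeNet has three output layers, two of which are auxiliary classifiers.

In the paper Going deeper with convolutions, each auxiliary classifier consists of: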

  • An average pooling layer with 5×5 filter size and stride 3, resulting in an 4×4×512 output for the (4a), and 4×4×528 for the (4d) stage.
  • A 1×1 convolution with 128 filters for dimension reduction and rectified linear activation.
  • A fully connected layer with 1024 units and rectified linear activation.
  • A dropout layer with 70% ratio of dropped outputs.
  • A linear layer with softmax loss as the classifier (predicting the same 1000 classes as the main classifier, but removed at inference time).
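The sizes in the list above can be checked by hand: a 5×5 average pool with stride 3 turns the 14×14 maps of stages (4a) and (4d) into 4×4, and after the 128-filter 1×1 convolution the flattened vector has 128×4×4 = 2048 entries, which is the in_features of fc1 in InceptionAux. During training, the paper adds the auxiliary losses to the total loss with a discount weight of 0.3, which train.py also uses. The helper names pool_out and total_loss below are hypothetical.

```python
def pool_out(h, kernel, stride):
    # AvgPool2d without padding, floor mode (the PyTorch default)
    return (h - kernel) // stride + 1

h = 14                     # spatial size at stages (4a) and (4d) for a 224x224 input
side = pool_out(h, 5, 3)   # 5x5 average pooling, stride 3 -> 4
flat = 128 * side * side   # 128 filters after the 1x1 conv, then flatten -> 2048

def total_loss(main, aux1, aux2, aux_weight=0.3):
    """Training loss: the auxiliary losses are added with a discount weight."""
    return main + aux_weight * (aux1 + aux2)

print(side, flat)                           # 4 2048
print(round(total_loss(1.0, 2.0, 2.0), 2))  # 2.2
```

At inference time the auxiliary heads are simply not evaluated, so only the main output matters.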

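1.4 Dropping fully connected layers in favor of average pooling

Purpose: greatly reduces the number of model parameters. Instead of flattening the last feature map into large fully connected layers, average pooling reduces it to 1024×1×1, and only a single linear classifier (1024 → number of classes) remains.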
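To put a number on "greatly reduces", the sketch below compares a hypothetical head that flattens the 1024×7×7 output of inception5b into a 1024-unit fully connected layer against GoogLeNet's parameter-free average pooling followed by the 1000-class linear layer. The alternative head is an assumption for comparison only; it is not part of the network.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

c, h, w = 1024, 7, 7  # output of inception5b for a 224x224 input

# Hypothetical alternative: flatten 1024x7x7 and feed a 1024-unit fully connected layer
fc_head = linear_params(c * h * w, 1024)
# GoogLeNet: average pooling (no parameters) down to 1024x1x1, then the classifier
gap_head = 0 + linear_params(c, 1000)

print(fc_head)   # 51381248
print(gap_head)  # 1025000
```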

2. Network structure

There are too many Inception layers to show them all, so only a few are listed:
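Since not every layer is shown, the full stack can be traced instead. The sketch below copies the Inception arguments from model.py, checks that each block's in_channels equals the previous block's output (ch11 + ch33 + ch55 + pool_proj), and follows the spatial size for a 224×224 input. The ceil-mode formula is simplified: it ignores PyTorch's rule about a last window that would start in padding, which never applies here because these pools have no padding.

```python
import math

# (in_channels, ch11, ch33_reduce, ch33, ch55_reduce, ch55, pool_proj), copied from model.py
configs = {
    "3a": (192, 64, 96, 128, 16, 32, 32),
    "3b": (256, 128, 128, 192, 32, 96, 64),
    "4a": (480, 192, 96, 208, 16, 48, 64),
    "4b": (512, 160, 112, 224, 24, 64, 64),
    "4c": (512, 128, 128, 256, 24, 64, 64),
    "4d": (512, 112, 144, 288, 32, 64, 64),
    "4e": (528, 256, 160, 320, 32, 128, 128),
    "5a": (832, 256, 160, 320, 32, 128, 128),
    "5b": (832, 384, 192, 384, 48, 128, 128),
}

def inception_out_channels(cfg):
    # torch.cat along dim=1 adds up the channels of the four branches
    _, ch11, _, ch33, _, ch55, pool_proj = cfg
    return ch11 + ch33 + ch55 + pool_proj

def maxpool_ceil(h, kernel=3, stride=2):
    # MaxPool2d(kernel_size=3, stride=2, ceil_mode=True) without padding
    return math.ceil((h - kernel) / stride) + 1

h = (224 + 2 * 3 - 7) // 2 + 1  # conv1: 7x7, stride 2, padding 3 -> 112
h = maxpool_ceil(h)             # maxpool1 -> 56 (conv2 and conv3 keep 56)
h = maxpool_ceil(h)             # maxpool2 -> 28
channels = 192                  # output channels of conv3

stages = [["3a", "3b"], ["4a", "4b", "4c", "4d", "4e"], ["5a", "5b"]]
for i, names in enumerate(stages):
    for name in names:
        cfg = configs[name]
        assert cfg[0] == channels, f"inception{name} expects {cfg[0]}, gets {channels}"
        channels = inception_out_channels(cfg)
        print(f"inception{name}: {channels} x {h} x {h}")
    if i < len(stages) - 1:
        h = maxpool_ceil(h)     # maxpool3 (28 -> 14), maxpool4 (14 -> 7)
```

The trace ends at 1024×7×7, which AdaptiveAvgPool2d((1, 1)) reduces to the 1024 features expected by the final nn.Linear(1024, num_classes).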

3. Key points

3.1 torch.cat

import torch

a = torch.Tensor([1, 2, 3])
b = torch.Tensor([4, 5, 6])
c = [a, b]
print(torch.cat(c))
# tensor([1., 2., 3., 4., 5., 6.])
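With 1-D tensors, torch.cat joins along dim=0 by default. The Inception block instead concatenates 4-D feature maps (N, C, H, W) along dim=1, the channel dimension. The sketch below uses the branch widths of inception3a (64, 128, 32, 32) with random data, and shows that mismatched spatial sizes make torch.cat fail, which is why every branch must keep H and W.

```python
import torch

# Four branch outputs with the same N, H, W but different channel counts
n, h, w = 2, 28, 28
branches = [torch.randn(n, c, h, w) for c in (64, 128, 32, 32)]

out = torch.cat(branches, dim=1)  # concatenate along the channel dimension
print(out.shape)                  # torch.Size([2, 256, 28, 28])

# All dimensions except dim=1 must match, so a 14x14 map cannot join a 28x28 map
mismatch_failed = False
try:
    torch.cat([torch.randn(n, 64, 28, 28), torch.randn(n, 64, 14, 14)], dim=1)
except RuntimeError as err:
    mismatch_failed = True
    print("RuntimeError:", err)
```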

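3.2 About self.training

model.train() and model.eval() switch the model, and recursively all of its submodules, between training and evaluation mode.

In model.train() mode, self.training = True.

In model.eval() mode, self.training = False.

GoogLeNet's forward uses this flag to compute and return the auxiliary outputs only during training.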
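A minimal sketch of the same pattern on a toy model: TinyNet is a hypothetical stand-in for GoogLeNet that returns an extra "auxiliary" output only while self.training is True. It also shows that the flag is propagated to submodules such as Dropout.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical toy model with an extra auxiliary output used only in training."""
    def __init__(self):
        super().__init__()
        self.main = nn.Linear(4, 2)
        self.aux = nn.Linear(4, 2)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        out = self.main(self.dropout(x))
        if self.training:  # set by model.train() / model.eval()
            return out, self.aux(x)
        return out

net = TinyNet()
x = torch.randn(3, 4)

net.train()
print(net.training, net.dropout.training)  # True True
print(type(net(x)))                        # <class 'tuple'>

net.eval()
print(net.training, net.dropout.training)  # False False
print(type(net(x)))                        # <class 'torch.Tensor'>
```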

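3.3 About strict in load_state_dict

strict=True (the default): the keys must match exactly; every key the model expects must be in the state_dict and vice versa, otherwise a RuntimeError is raised.

strict=False: load the keys that match and skip the rest; missing or extra keys do not raise an error. (A parameter whose shape differs still raises an error in both modes.)

missing_keys and unexpected_keys: load_state_dict returns both lists, i.e. the keys the model expected but did not find (missing) and the keys in the state_dict that the model does not have (unexpected).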
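This is the situation predict.py runs into: the checkpoint was saved from a model with auxiliary classifiers, but the inference model has none. The sketch below reproduces it with two hypothetical toy classes, WithAux and WithoutAux, standing in for GoogLeNet(aux_use=True) and GoogLeNet(aux_use=False).

```python
import torch
import torch.nn as nn

class WithAux(nn.Module):
    # Hypothetical stand-in for GoogLeNet(aux_use=True)
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.aux = nn.Linear(4, 2)

class WithoutAux(nn.Module):
    # Hypothetical stand-in for GoogLeNet(aux_use=False)
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

state = WithAux().state_dict()
target = WithoutAux()

# strict=True (default): the extra aux.* keys cause a RuntimeError
strict_failed = False
try:
    target.load_state_dict(state)
except RuntimeError as err:
    strict_failed = True
    print("strict=True failed:", type(err).__name__)

# strict=False: matching keys are loaded, the others are only reported
result = target.load_state_dict(state, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['aux.weight', 'aux.bias']
```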

4. Code

4.1 model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class GoogLeNet(nn.Module):
    def __init__(self, num_classes=1000, aux_use=True, init_weight=False):
        super(GoogLeNet, self).__init__()
        self.aux_use = aux_use
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)  # ceil_mode=False (default) rounds the output size down; True rounds it up
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if self.aux_use:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # adaptive average pooling: the output size (H, W) is given directly
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)
        if init_weight:
            self._initialize_weights_()

    def forward(self, x):
        # N×3×224×224
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)
        x = self.inception4a(x)
        if self.training and self.aux_use:
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.training and self.aux_use:
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)
        x = self.inception5a(x)
        x = self.inception5b(x)
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.dropout(x)
        x = self.fc(x)
        if self.training and self.aux_use:
            return x, aux1, aux2
        return x

    def _initialize_weights_(self):
        for v in self.modules():
            if isinstance(v, nn.Conv2d):
                nn.init.xavier_uniform_(v.weight)
                if v.bias is not None:
                    nn.init.constant_(v.bias, 0)
            if isinstance(v, nn.Linear):
                nn.init.xavier_uniform_(v.weight)
                if v.bias is not None:
                    nn.init.constant_(v.bias, 0)


# set BasicConv2d class
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x


# set Inception class
class Inception(nn.Module):
    # the outputs of all branches must have the same height and width
    def __init__(self, in_channels, ch11, ch33_reduce, ch33, ch55_reduce, ch55, pool_proj):
        super(Inception, self).__init__()
        self.branch1 = BasicConv2d(in_channels, ch11, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch33_reduce, kernel_size=1),
            BasicConv2d(ch33_reduce, ch33, kernel_size=3, padding=1)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch55_reduce, kernel_size=1),
            BasicConv2d(ch55_reduce, ch55, kernel_size=5, padding=2)
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, dim=1)


# set InceptionAux class
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.averagePool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)  # output:[batch,128,4,4]
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # Input: Aux1(batch,512,14,14) Aux2(batch,528,14,14)
        x = self.averagePool(x)
        # output: Aux1(batch,512,4,4) Aux2(batch,528,4,4)
        x = self.conv(x)
        # output:Aux1、Aux2(batch,128,4,4)
        x = torch.flatten(x, start_dim=1)
        x = F.dropout(x, 0.5, training=self.training)
        # batch × 2048
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        # batch × 1024
        x = self.fc2(x)
        # batch × num_classes
        return x

4.2 train.py

import os
import sys

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import json
from model import GoogLeNet
import torch.optim as optim
from tqdm import tqdm

def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    }
    data_root = os.path.abspath(os.getcwd())
    image_path = os.path.join(data_root, 'data_set', 'flower_data')
    assert os.path.exists(image_path), 'file: {} does not exist!'.format(image_path)

    # set dataset
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'), transform=data_transform['train'])
    val_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'), transform=data_transform['val'])
    train_num = len(train_dataset)
    val_num = len(val_dataset)

    # write dict into file
    flower_list = train_dataset.class_to_idx
    class_dict = {idx: name for name, idx in flower_list.items()}  # index -> class name
    json_str = json.dumps(class_dict, indent=4)
    with open('./class_indices.json', 'w') as file:
        file.write(json_str)

    # set dataloader
    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
    print('using {} images for training, {} images for validation.'.format(train_num, val_num))

    net = GoogLeNet(num_classes=5, aux_use=True, init_weight=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './GoogLeNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        epoch_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            output, aux1_output, aux2_output = net(images.to(device))
            loss0 = loss_function(output, labels.to(device))
            loss1 = loss_function(aux1_output, labels.to(device))
            loss2 = loss_function(aux2_output, labels.to(device))
            loss = loss0 + 0.3 * loss1 + 0.3 * loss2
            loss.backward()
            optimizer.step()
            # print statistics
            epoch_loss += loss.item()
            train_bar.desc = 'train epoch[{}/{}] loss:{:.3f}'.format(epoch + 1, epochs, loss)

        # validate
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader)
            for step, data in enumerate(val_bar):
                val_images, val_labels = data
                outputs = net(val_images.to(device))
                predict_y = torch.argmax(outputs, dim=1)
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
        val_acc = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, epoch_loss / train_steps, val_acc))

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(net.state_dict(), save_path)
    print('Finished Training!')


if __name__ == '__main__':
    main()
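One detail connects train.py and predict.py: class_indices.json is written from a dict with integer keys, but JSON object keys are always strings, so the keys come back as "0", "1", ... when the file is read. That is why predict.py looks classes up with class_dict[str(predict_class)]. The class names below are a hypothetical example; ImageFolder assigns indices in sorted order of the sub-folder names.

```python
import json

# Hypothetical class_to_idx, as ImageFolder builds it from sorted sub-folder names
class_to_idx = {"daisy": 0, "dandelion": 1, "roses": 2, "sunflowers": 3, "tulips": 4}

# Invert to index -> class name, as train.py does before writing class_indices.json
idx_to_class = {idx: name for name, idx in class_to_idx.items()}
json_str = json.dumps(idx_to_class, indent=4)

# After a round trip through JSON the integer keys have become strings
loaded = json.loads(json_str)
print(list(loaded.keys()))  # ['0', '1', '2', '3', '4']
print(loaded[str(3)])       # sunflowers
```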

4.3 predict.py

import os
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
from model import GoogLeNet

def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    img_path = './sunflower.jpg'
    assert os.path.exists(img_path), 'file: {} does not exist!'.format(img_path)
    img = Image.open(img_path).convert('RGB')  # make sure the image has 3 channels
    plt.imshow(img)

    # [N,C,H,W]
    img = transform(img)
    img = torch.unsqueeze(img, dim=0)

    # read class_dict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), 'file: {} does not exist!'.format(json_path)
    with open(json_path, 'r') as file:
        class_dict = json.load(file)

    # create model
    net = GoogLeNet(num_classes=5, aux_use=False).to(device)

    # load model weights
    weight_path = './GoogLeNet.pth'
    assert os.path.exists(weight_path), 'file: {} does not exist!'.format(weight_path)
    # with strict=False, the weights of the auxiliary classifiers aux1 and aux2 end up in unexpected_keys
    missing_keys, unexpected_keys = net.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
    net.eval()
    with torch.no_grad():
        outputs = torch.squeeze(net(img.to(device))).cpu()
        predict = torch.softmax(outputs, dim=0)
        predict_class = torch.argmax(predict).item()

    print_res = 'class:{} probability:{:.3f}'.format(class_dict[str(predict_class)], predict[predict_class])
    plt.title(print_res)
    for i in range(len(predict)):
        print('class:{:10} probability:{:.3f}'.format(class_dict[str(i)], predict[i]))
    plt.show()


if __name__ == '__main__':
    main()
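The post-processing in predict.py can be tried in isolation. In this sketch the logits are made-up numbers standing in for net(img) on one image with 5 classes: squeeze drops the batch dimension, softmax over dim 0 turns the scores into probabilities, and .item() turns the argmax into a plain Python int that works both as a list index and in str(...) for the JSON lookup.

```python
import torch

# Hypothetical logits for one image and 5 classes, shape [1, 5] like net(img) in predict.py
logits = torch.tensor([[0.5, -1.0, 0.2, 3.0, 0.1]])

scores = torch.squeeze(logits)        # [1, 5] -> [5]
probs = torch.softmax(scores, dim=0)  # probabilities over the 5 classes
pred = torch.argmax(probs).item()     # plain Python int

print(pred)                           # 3
print(round(probs.sum().item(), 4))   # 1.0
```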

5. Results
