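Classic Neural Network Implementations: AlexNet

Preface

In the previous section we implemented LeNet in code. In this chapter we reproduce AlexNet and use several inference tools to compare inference latency.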

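Introduction to AlexNet

AlexNet, introduced in the paper "ImageNet Classification with Deep Convolutional Neural Networks", was one of the early papers to draw wide attention, and it helped shape the development of deep learning. The paper reports results on the ImageNet LSVRC-2010 data, and a variant of the model won the ILSVRC-2012 competition with a top-5 test error of 15.3%, far ahead of the runner-up's 26.2%, which caused a stir in both academia and industry. The overall structure of AlexNet resembles LeNet: both pass the input through convolutional layers and then fully connected layers to produce the final result, but the implementations differ considerably. AlexNet has five convolutional layers and three fully connected layers, and ends in a 1000-way softmax. The original model was split across two GPUs, which sped up training considerably.
Figure: AlexNet network architecture
As the architecture figure shows, the network is split into two symmetric halves, upper and lower, one per GPU, and the halves interact only at certain layers. To keep the implementation simple, this article implements the network as a single branch on one GPU, using the combined channel counts (96, 256, 384, 384, 256).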

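Code Implementation

Data

I have already preprocessed the dataset. To test with your own data, arrange it in the same format and swap it in.

Dataset link: https://pan.baidu.com/s/1PeH0zPGLQVmmQc-I1vkIlg?pwd=1234

Dataset

The data is a flower classification set with five classes, so we need a custom Dataset. The code is as follows: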

class MyDataset(Dataset):
    def __init__(self, filename, transform) -> None:
        self.filename = filename
        self.transform = transform
        self.image_list, self.label_list = self.operate_file()
    
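    def __getitem__(self, idx: Any) -> Any:
        # Force 3 channels so grayscale or RGBA files match the network input
        image = Image.open(self.image_list[idx]).convert('RGB')
        # Note: this random crop is also applied when the dataset is used for testing
        trans = transforms.RandomResizedCrop(227)
        image = trans(image)
        label = self.label_list[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image,label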
    
    def __len__(self):
        return len(self.image_list)

    def operate_file(self):
        dir_list = os.listdir(self.filename)
        img_list = []
        label_list = []
        label_dict = {'tulip':np.int64(0), 'sunflower':np.int64(1), 'rose':np.int64(2), 'dandelion':np.int64(3), 'daisy':np.int64(4)}
        for v in dir_list:
            # Skip unknown classes and stray files before listing the directory
            if v not in label_dict:
                continue
            dir_path = os.path.join(self.filename, v)
            if not os.path.isdir(dir_path):
                continue
            file_list = [os.path.join(dir_path, path) for path in os.listdir(dir_path)]
            img_list.extend(file_list)
            label_list.extend([label_dict[v]] * len(file_list))
        return img_list, label_list
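
To make the expected folder layout concrete, here is a minimal standard-library sketch of the same directory scan. The function name scan_image_folder and the throwaway directory tree are hypothetical and exist only for illustration; the class-to-index mapping is the one used by label_dict above.

```python
import os
import tempfile

# Same class-name-to-index mapping as label_dict in MyDataset.
LABELS = {'tulip': 0, 'sunflower': 1, 'rose': 2, 'dandelion': 3, 'daisy': 4}

def scan_image_folder(root, labels=LABELS):
    """Return parallel lists of file paths and integer labels.

    Hypothetical helper: expects root/<class_name>/<image files>.
    Entries that are not directories, or whose name is not a known
    class, are skipped before they are listed.
    """
    paths, targets = [], []
    for name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, name)
        if name not in labels or not os.path.isdir(class_dir):
            continue
        for fname in sorted(os.listdir(class_dir)):
            paths.append(os.path.join(class_dir, fname))
            targets.append(labels[name])
    return paths, targets

if __name__ == "__main__":
    # Build a tiny throwaway tree to show the expected layout.
    with tempfile.TemporaryDirectory() as root:
        for cls, count in (('rose', 2), ('daisy', 1), ('unknown', 1)):
            os.makedirs(os.path.join(root, cls))
            for i in range(count):
                open(os.path.join(root, cls, f'{i}.jpg'), 'w').close()
        paths, targets = scan_image_folder(root)
        print(len(paths), targets)
```

Your own data only needs one sub-folder per class, named after a key of the label mapping. Folders with any other name are ignored.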

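Building the Network

The main body of the network is straightforward to implement. For normalization, the original paper uses Local Response Normalization (LRN); since LRN saw little adoption in later work, this article uses Batch Normalization (BN) instead.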
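
For reference, the LRN that is being replaced can be sketched in a few lines of plain Python. This is an illustration of the formula as stated in the AlexNet paper, applied to the channel values at a single spatial position; the function name is hypothetical, and the defaults (n=5, k=2, alpha=1e-4, beta=0.75) are the constants reported in the paper. It is not the code used in this article's network.

```python
def local_response_norm(channels, n=5, k=2.0, alpha=1e-4, beta=0.75):
    """Apply LRN across channels at one spatial position.

    channels: list of activations, one value per channel.
    Each value is divided by (k + alpha * sum of squares over the
    n neighbouring channels) ** beta, with the window clipped at the
    first and last channel.
    """
    num = len(channels)
    out = []
    for i, a in enumerate(channels):
        lo = max(0, i - n // 2)
        hi = min(num - 1, i + n // 2)
        sq_sum = sum(channels[j] ** 2 for j in range(lo, hi + 1))
        out.append(a / (k + alpha * sq_sum) ** beta)
    return out

if __name__ == "__main__":
    print(local_response_norm([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))
```

PyTorch also ships nn.LocalResponseNorm. If you want to try it in place of BN, check its documentation for how its constants are scaled rather than assuming they match the paper's formula one to one.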

class AlexNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backboke = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2), 
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 4096),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            #nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(4096, 5)
            #nn.Softmax()
        )
    
    def forward(self, x):
        x = self.backboke(x)
        x = torch.flatten(x, 1)
        result = self.classifier(x)
        return result
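
The 9216 in nn.Linear(9216, 4096) is not arbitrary: it is the flattened size of the last feature map for a 227x227 input. The sketch below traces the spatial size through the layers above with the usual output-size formula for convolution and pooling. The helper names are hypothetical; the layer parameters are copied from the class.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or max-pool layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding, out_channels)
# out_channels=None marks a pooling layer, which keeps the channel count.
BACKBONE = [
    ('conv1', 11, 4, 0, 96),
    ('pool1', 3, 2, 0, None),
    ('conv2', 5, 1, 2, 256),
    ('pool2', 3, 2, 0, None),
    ('conv3', 3, 1, 1, 384),
    ('conv4', 3, 1, 1, 384),
    ('conv5', 3, 1, 1, 256),
    ('pool3', 3, 2, 0, None),
]

def flattened_features(input_size=227, in_channels=3):
    """Print each layer's output shape and return channels * height * width."""
    size, channels = input_size, in_channels
    for name, kernel, stride, padding, out_channels in BACKBONE:
        size = out_size(size, kernel, stride, padding)
        if out_channels is not None:
            channels = out_channels
        print(f'{name}: {channels} x {size} x {size}')
    return channels * size * size

if __name__ == "__main__":
    print(flattened_features(227))
```

The spatial size goes 227, 55, 27, 27, 13, 13, 13, 13, 6, so the classifier receives 256 * 6 * 6 = 9216 values. With a 224x224 crop the same layers would produce 256 * 5 * 5 = 6400 values and the first linear layer would no longer fit, which is why the Dataset crops to 227.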

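The complete code follows. PS: compared with the previous LeNet reproduction, this version adds model saving and loading, model validation, TensorBoard visualization, ONNX inference, and TensorRT inference.

Network code: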

import torch
import torch.nn as nn
from Dataset import MyDataset
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import onnxruntime
import onnx

from torch.utils.tensorboard import SummaryWriter

import datetime
import time

#torch.manual_seed(0)

class AlexNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.backboke = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
            
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2), 
        )
        self.classifier = nn.Sequential(
            nn.Linear(9216, 4096),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            #nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(4096, 5)
            #nn.Softmax()
        )
    
    def forward(self, x):
        x = self.backboke(x)
        x = torch.flatten(x, 1)
        result = self.classifier(x)
        return result



def train(model_path, data_path):
    lr = 0.001
    tag = "AlexNet_lr_" + str(lr)
    model = AlexNet()
    device = torch.device("cuda:0")
    model.to(device)
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)
    batch_size = 32
    epochs = 10
    train_set = MyDataset(data_path, transform=transforms.ToTensor())
    train_loader = DataLoader(train_set, batch_size, shuffle=True)
    tensorboard_log_dir = '$ YOUR LOG DIR' + tag
    writter = SummaryWriter(tensorboard_log_dir)
    # Renamed so the local variable no longer shadows the imported time module
    start_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    itr = 0
    for epoch in range(epochs):
        loss_temp = 0.0
        for j,(batch_data, batch_label) in enumerate(train_loader):
            batch_data = batch_data.cuda()
            batch_label = batch_label.cuda()
            optimizer.zero_grad()
            pred = model(batch_data)
            loss = loss_func(pred, batch_label)
            # .item() extracts a Python float so the autograd graph is not kept alive
            loss_temp += loss.item()
            loss.backward()
            optimizer.step()
            itr += 1
            writter.add_scalar('iter loss', loss.item(), itr)
        writter.add_scalar('epoch loss', loss_temp / len(train_loader), epoch+1)
        print('[%d] loss : %.3f ' % (epoch+1, loss_temp / len(train_loader)))
    
    writter.close()
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss_temp
    }, model_path)

def test(model_path, data_path):
    model = AlexNet()
    #device = torch.device("cpu")
    device = torch.device("cuda:0")
    model.to(device)
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    batch_size = 32
    correct = 0
    test_set = MyDataset(data_path, transform=transforms.ToTensor())
    test_loader = DataLoader(test_set, batch_size, shuffle=True)
    
    # Gradients are not needed for evaluation
    with torch.no_grad():
        for j,(batch_data, batch_label) in enumerate(test_loader):
            batch_data = batch_data.cuda()
            batch_label = batch_label.cuda()
            pred = model(batch_data)
            predicted = torch.max(pred.data, 1)[1]
            correct += (predicted == batch_label).sum()
    print('ckpt correct : ', correct)
    print('acc : %.2f %%' % (100 * int(correct.int()) / len(test_set)))

def export_onnx(model_path):
    model = AlexNet()
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    batch_size = 1
    input_shape = (3, 227, 227)
    x = torch.randn(batch_size, *input_shape)
    torch.onnx.export(model, x, '$Your ONNX MODEL PATH', opset_version=11, do_constant_folding=True, input_names=["input"], output_names=["output"])
    onnx_model = onnx.load('$Your ONNX MODEL PATH')
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))

def test_onnx(onnx_path, data_path):
    ort_session = onnxruntime.InferenceSession(onnx_path, providers=['CUDAExecutionProvider'])

    batch_size = 1
    correct = 0
    test_set = MyDataset(data_path, transform=transforms.ToTensor())
    test_loader = DataLoader(test_set, batch_size, shuffle=True)
    
    with torch.no_grad():
        for j,(batch_data, batch_label) in enumerate(test_loader):
            # ONNX Runtime takes NumPy arrays, so the tensors stay on the CPU
            # (calling .numpy() on a CUDA tensor raises an error)
            ort_inputs = {ort_session.get_inputs()[0].name:batch_data.numpy()}
            ort_outputs = ort_session.run(None, ort_inputs)[0]
            pred = torch.from_numpy(ort_outputs)
            predicted = torch.max(pred.data, 1)[1]
            correct += (predicted == batch_label).sum()
    print('onnx correct : ', correct)
    print('acc : %.2f %%' % (100 * int(correct.int()) / len(test_set)))

if __name__=="__main__":
    train('$ YOUR SAVE MODEL PATH', '$ YOUR TRAIN IMAGES PATH')
    test('$ YOUR SAVE MODEL PATH', '$ YOUR TEST IMAGES PATH')
    export_onnx('$ YOUR SAVE MODEL PATH')
    test_onnx('$Your ONNX MODEL PATH', '$ YOUR TEST IMAGES PATH')
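
One detail in the listing above is worth spelling out: the classifier ends in a plain Linear layer, with nn.Softmax commented out. That fits nn.CrossEntropyLoss, which applies log-softmax to the raw logits itself, and it does not change predictions, because softmax preserves the ordering of the logits. The plain-Python sketch below illustrates both points with made-up logits for the five classes. The function names are hypothetical, and it mirrors the math rather than the PyTorch implementation.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Cross-entropy of raw logits against a class index, via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def argmax(values):
    return max(range(len(values)), key=lambda i: values[i])

if __name__ == "__main__":
    logits = [2.0, -1.0, 0.5, 0.0, 1.0]  # illustrative values only
    probs = softmax(logits)
    # The predicted class is the same with or without softmax.
    print(argmax(logits) == argmax(probs))
    # The loss on logits equals -log of the softmax probability.
    print(abs(cross_entropy(logits, 2) + math.log(probs[2])) < 1e-12)
```

If you need probabilities at inference time, apply softmax to the model output outside the network instead of adding it back into the classifier.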

Data processing code:

from typing import Any
import torch
from torch.utils.data import Dataset
import os
import numpy as np
from PIL import Image
from torchvision.transforms import transforms

class MyDataset(Dataset):
    def __init__(self, filename, transform) -> None:
        self.filename = filename
        self.transform = transform
        self.image_list, self.label_list = self.operate_file()
    
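    def __getitem__(self, idx: Any) -> Any:
        # Force 3 channels so grayscale or RGBA files match the network input
        image = Image.open(self.image_list[idx]).convert('RGB')
        # Note: this random crop is also applied when the dataset is used for testing
        trans = transforms.RandomResizedCrop(227)
        image = trans(image)
        label = self.label_list[idx]
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(label))
        return image,label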
    
    def __len__(self):
        return len(self.image_list)

    def operate_file(self):
        dir_list = os.listdir(self.filename)
        img_list = []
        label_list = []
        label_dict = {'tulip':np.int64(0), 'sunflower':np.int64(1), 'rose':np.int64(2), 'dandelion':np.int64(3), 'daisy':np.int64(4)}
        for v in dir_list:
            # Skip unknown classes and stray files before listing the directory
            if v not in label_dict:
                continue
            dir_path = os.path.join(self.filename, v)
            if not os.path.isdir(dir_path):
                continue
            file_list = [os.path.join(dir_path, path) for path in os.listdir(dir_path)]
            img_list.extend(file_list)
            label_list.extend([label_dict[v]] * len(file_list))
        return img_list, label_list

TensorRT code:

from typing import Union, Optional, Sequence,Dict,Any

import torch
import tensorrt as trt

import time

class TRTWrapper(torch.nn.Module):
    def __init__(self,engine: Union[str, trt.ICudaEngine],
                 output_names: Optional[Sequence[str]] = None) -> None:
        super().__init__()
        self.engine = engine
        if isinstance(self.engine, str):
            with trt.Logger() as logger, trt.Runtime(logger) as runtime:
                with open(self.engine, mode='rb') as f:
                    engine_bytes = f.read()
                self.engine = runtime.deserialize_cuda_engine(engine_bytes)
        self.context = self.engine.create_execution_context()
        names = [_ for _ in self.engine]
        input_names = list(filter(self.engine.binding_is_input, names))
        self._input_names = input_names
        self._output_names = output_names

        if self._output_names is None:
            output_names = list(set(names) - set(input_names))
            self._output_names = output_names

    def forward(self, inputs: Dict[str, torch.Tensor]):
        assert self._input_names is not None
        assert self._output_names is not None
        bindings = [None] * (len(self._input_names) + len(self._output_names))
        profile_id = 0
        for input_name, input_tensor in inputs.items():
            # check if input shape is valid
            profile = self.engine.get_profile_shape(profile_id, input_name)
            assert input_tensor.dim() == len(
                profile[0]), 'Input dim is different from engine profile.'
            for s_min, s_input, s_max in zip(profile[0], input_tensor.shape,
                                             profile[2]):
                assert s_min <= s_input <= s_max, \
                    'Input shape should be between ' \
                    + f'{profile[0]} and {profile[2]}' \
                    + f' but get {tuple(input_tensor.shape)}.'
            idx = self.engine.get_binding_index(input_name)

            # All input tensors must be gpu variables
            assert 'cuda' in input_tensor.device.type
            input_tensor = input_tensor.contiguous()
            if input_tensor.dtype == torch.long:
                input_tensor = input_tensor.int()
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.contiguous().data_ptr()

        # create output tensors
        outputs = {}
        for output_name in self._output_names:
            idx = self.engine.get_binding_index(output_name)
            dtype = torch.float32
            shape = tuple(self.context.get_binding_shape(idx))

            device = torch.device('cuda')
            output = torch.empty(size=shape, dtype=dtype, device=device)
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()
        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)
        return outputs

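model = TRTWrapper('$ YOUR TRT ENGINE PATH', ['output'])
# Warm up so one-time initialization is excluded from the measurement
for i in range(0, 1000):
    output = model(dict(input = torch.randn(1, 3, 227, 227).cuda()))
# execute_async_v2 only enqueues work, so wait for the GPU before reading the clock
torch.cuda.synchronize()
t1 = time.time()
for i in range(0, 1000):
    output = model(dict(input = torch.randn(1, 3, 227, 227).cuda()))
torch.cuda.synchronize()
print('tensorrt cost time : ', time.time() - t1)
print(output)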

For the TensorRT part, replace the random input with your own data, following the DataLoader usage in the model test code above.
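
When comparing latency across PyTorch, ONNX Runtime, and TensorRT, it helps to time every backend the same way. The helper below is a small sketch of one way to do that. The name benchmark and its parameters are hypothetical and do not come from any of the libraries used here.

```python
import time

def benchmark(fn, warmup=10, iters=100, sync=None):
    """Return the average wall-clock seconds per call of fn().

    warmup: untimed calls, so one-time setup cost is excluded.
    sync:   optional callable run before starting and before stopping
            the clock, for backends that execute asynchronously.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / iters

if __name__ == "__main__":
    # Stand-in workload; replace the lambda with a real inference call.
    avg = benchmark(lambda: sum(range(10000)), warmup=5, iters=50)
    print('avg seconds per call:', avg)
```

For GPU backends, passing sync=torch.cuda.synchronize is one reasonable choice, since CUDA kernels are launched asynchronously. Preparing the input tensor once outside fn keeps data generation and host-to-device copies out of the measurement.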

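Results

Finally, let's look at the inference speed of our reproduced AlexNet under the different frameworks.
[Figure: inference-time comparison across frameworks; image not available]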
