PyTorch 源码阅读与使用:AlexNet

官方源码:https://pytorch.org/docs/stable/_modules/torchvision/models/alexnet.html#alexnet 

1. torchvision.models.alexnet

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url


# Names exported by ``from <this module> import *``.
__all__ = ["AlexNet", "alexnet"]


# Architecture name -> URL of the pretrained checkpoint; read by
# ``alexnet(pretrained=True)`` below.
model_urls = dict(
    alexnet="https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth",
)



class AlexNet(nn.Module):  # subclass nn.Module: build layers in __init__, define data flow in forward
    """AlexNet image classifier (torchvision variant).

    Args:
        num_classes (int): number of outputs of the final linear layer.
        dropout (float): drop probability of the two ``nn.Dropout`` layers in
            the classifier. The default 0.5 is the same value ``nn.Dropout()``
            uses when called without arguments, so omitting this argument
            builds exactly the original architecture.

    Shape comments below assume a [b, 3, 227, 227] input and follow
    floor((w - f + 2p) / s) + 1 for every conv / max-pool layer.
    """

    def __init__(self, num_classes=1000, dropout=0.5):
        super().__init__()
        # NOTE: layer order (and thus Sequential indices such as features.0,
        # classifier.6) must stay fixed so pretrained state_dict keys match.
        self.features = nn.Sequential(  # input: [b, 3, 227, 227]
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # [b, 64, 56, 56]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # [b, 64, 27, 27]

            nn.Conv2d(64, 192, kernel_size=5, padding=2),  # [b, 192, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # [b, 192, 13, 13]

            nn.Conv2d(192, 384, kernel_size=3, padding=1),  # [b, 384, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),  # [b, 256, 13, 13]
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=3, padding=1),  # [b, 256, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # [b, 256, 6, 6]
        )
        # Adaptive pooling chooses its window so the output is always 6x6,
        # which lets input sizes other than 227x227 reach the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))  # [b, 256, 6, 6]
        self.classifier = nn.Sequential(
            # forward() flattens to [b, 256*6*6] before this point.
            nn.Dropout(p=dropout),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),

            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),  # two dropout + linear + relu stages

            nn.Linear(4096, num_classes),  # final per-class scores
        )

    def forward(self, x):
        """Map images ``x`` of shape [b, 3, H, W] to scores of shape [b, num_classes]."""
        x = self.features(x)  # stacked conv2d + relu + maxpool -> [b, 256, 6, 6] for 227x227
        x = self.avgpool(x)  # always [b, 256, 6, 6]
        # torch.flatten(input, start_dim=0, end_dim=-1); start_dim=1 keeps the
        # batch dimension and merges dims 1..-1 -> [b, 256*6*6].
        x = torch.flatten(x, 1)
        x = self.classifier(x)  # [b, num_classes]
        return x


def alexnet(pretrained=False, progress=True, **kwargs):
    r"""Construct the AlexNet architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): when True, fetch the checkpoint at
            ``model_urls['alexnet']`` and load it into the new model.
        progress (bool): when True, show a download progress bar on stderr.
        **kwargs: passed through unchanged to the ``AlexNet`` constructor.

    Returns:
        AlexNet: the constructed (optionally weight-loaded) model.
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net
    # NOTE(review): load_state_dict is strict by default, so kwargs that change
    # layer shapes (e.g. num_classes) presumably make this load fail — confirm.
    weights = load_state_dict_from_url(model_urls['alexnet'], progress=progress)
    net.load_state_dict(weights)
    return net

  

2. 自己写的测试用例(预测单张图片类别)

import cv2
import torch
import numpy as np
from torchvision import datasets, transforms as tf
import torchvision.models as models

# Load ImageNet-pretrained AlexNet; eval() disables the classifier's Dropout layers.
alexnet = models.alexnet(pretrained=True)
alexnet.eval()

img_path = '/home/zxq/PycharmProjects/data/cfar100/car.jpg'
# IMREAD_COLOR always yields a 3-channel uint8 image. IMREAD_UNCHANGED could
# return a 4-channel (alpha) or single-channel array, which the 3-channel
# Normalize below would reject.
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError('cannot read image: {}'.format(img_path))
# OpenCV stores pixels as BGR, but the pretrained weights and the mean/std
# below are defined for RGB input.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Bilinear resampling matches torchvision's default Resize and avoids the
# aliasing of nearest-neighbour interpolation.
resized_img = cv2.resize(img, (227, 227), interpolation=cv2.INTER_LINEAR)

normalize = tf.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
# ToTensor: HWC uint8 in [0, 255] -> CHW float in [0, 1]; then per-channel normalize.
transform = tf.Compose([tf.ToTensor(), normalize])

img_tr = transform(resized_img)
input_data = torch.unsqueeze(img_tr, 0)  # add batch dim -> [1, 3, 227, 227]
# Pure inference: skip building the autograd graph.
with torch.no_grad():
    output_tensor = alexnet(input_data)  # [1, 1000] class scores

# argmax over the class dimension; .item() yields a Python int directly
# (int() on a 1-D NumPy array is deprecated since NumPy 1.25).
category_index = int(output_tensor.argmax(dim=1).item())
print('预测出类别对应的索引: {}'.format(category_index))

输出

预测出类别对应的索引: 511

对应的标注类别

511 n03100240 convertible(敞篷车),属于汽车类别,预测正确(check)
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值