SqueezeNet Code Walkthrough in PyTorch


Preface

Here we walk through a PyTorch implementation of SqueezeNet to deepen our understanding of the network architecture. The reference source code is available at:
https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py
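For comparison, torchvision also ships a ready-made SqueezeNet that can be loaded directly. A minimal sketch (random weights are used here because the pretrained-weights argument differs between torchvision releases):

import torch
import torchvision

# Build torchvision's reference SqueezeNet 1.0 with random weights.
# (Pass pretrained=True or weights=..., depending on your torchvision version, for ImageNet weights.)
official_net = torchvision.models.squeezenet1_0(num_classes=1000)
official_net.eval()

dummy = torch.randn(1, 3, 224, 224)   # one fake 224x224 RGB image
with torch.no_grad():
    print(official_net(dummy).shape)  # torch.Size([1, 1000])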


1. Model

1.1 Implementation of the Fire module

A Fire module first squeezes the input with 1x1 convolutions to reduce the channel count, then expands it with parallel 1x1 and 3x3 convolutions and concatenates the two outputs along the channel dimension.

import torch
import torch.nn as nn

class Fire(nn.Module):
    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        # squeeze layer: 1x1 convolution
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1, stride=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        # expand layer: two branches, 1x1 and 3x3
        # 1x1 branch
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1, stride=1)
        # 3x3 branch; padding=1 keeps the spatial size unchanged
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, stride=1, padding=1)
        self.expand_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        # concatenate the two expand branches along the channel dimension
        x = torch.cat([
            self.expand_activation(self.expand1x1(x)),
            self.expand_activation(self.expand3x3(x))], 1)
        return x
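A quick shape check, added here as a small sketch (not part of the original post): a Fire module's output channel count is expand1x1_planes + expand3x3_planes, and the spatial resolution is left unchanged.

# Sanity check for the first Fire module of SqueezeNet 1.0.
fire = Fire(96, 16, 64, 64)
x = torch.randn(1, 96, 55, 55)   # 55x55 feature map: what conv1 + maxpool produce from a 227x227 input
y = fire(x)
print(y.shape)                   # torch.Size([1, 128, 55, 55]) -- 64 + 64 concatenated channels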

1.2 Implementation of SqueezeNet

The SqueezeNet class stacks a stem convolution, eight Fire modules interleaved with max pooling, and a convolutional classifier (dropout, a 1x1 convolution with one filter per class, and global average pooling).

class SqueezeNet(nn.Module):
    def __init__(self, version=1.0, num_classes=10):
        # Only the SqueezeNet 1.0 layout is implemented here; `version` is kept
        # for signature compatibility with the torchvision reference but is unused.
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        self.feature = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=2),
            nn.ReLU(inplace=True),  # activation after the stem conv, as in the torchvision reference
            nn.MaxPool2d(kernel_size=3, stride=2),
            Fire(96, 16, 64, 64),
            Fire(128, 16, 64, 64),
            Fire(128, 32, 128, 128),
            nn.MaxPool2d(kernel_size=3, stride=2),
            Fire(256, 32, 128, 128),
            Fire(256, 48, 192, 192),
            Fire(384, 48, 192, 192),
            Fire(384, 64, 256, 256),
            nn.MaxPool2d(kernel_size=3, stride=2),
            Fire(512, 64, 256, 256),
        )

        # conv10 replaces a fully connected layer: one 1x1 filter per class
        conv10 = nn.Conv2d(512, num_classes, kernel_size=1, stride=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            conv10,
            nn.ReLU(inplace=True),  # activation after conv10, as in the torchvision reference
            nn.AvgPool2d(kernel_size=13)  # 13x13 global average pooling for a 227x227 input
        )

        # Optional weight initialization, following the torchvision reference:
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         if m is conv10:
        #             nn.init.normal_(m.weight, mean=0.0, std=0.01)
        #         else:
        #             nn.init.kaiming_uniform_(m.weight)
        #         if m.bias is not None:
        #             nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.feature(x)
        x = self.classifier(x)
        # flatten (N, num_classes, 1, 1) to (N, num_classes)
        return x.view(x.size(0), self.num_classes)
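As a quick sanity check (a sketch, not part of the original post), running a dummy 227x227 batch through the model should yield one logit per class:

net = SqueezeNet(num_classes=10)
dummy = torch.randn(2, 3, 227, 227)   # a batch of two fake RGB images
with torch.no_grad():
    logits = net(dummy)
print(logits.shape)                   # torch.Size([2, 10])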

2. Train

The script below trains the model on CIFAR-10 with SGD and a cross-entropy loss, evaluates it on the test split after every epoch, and saves the weights whenever the test accuracy improves.

import json

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader

from model import *

# device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data preprocessing
data_transform = {
    "train": torchvision.transforms.Compose([
        # random crop, then resize to 227x227
        torchvision.transforms.RandomResizedCrop(227),
        # random horizontal flip with probability 0.5
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        # convert to a Tensor
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]),
    "val": torchvision.transforms.Compose([
        torchvision.transforms.Resize((227, 227)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
}

# prepare the datasets
train_dataset = torchvision.datasets.CIFAR10('../dataset', train=True, transform=data_transform["train"], download=True)
test_dataset = torchvision.datasets.CIFAR10('../dataset', train=False, transform=data_transform["val"], download=True)
# wrap them in data loaders (shuffle the training data each epoch)
train_dataLoader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataLoader = DataLoader(test_dataset, batch_size=16)
# number of batches in the training loader
train_length = len(train_dataLoader)
# number of batches in the test loader
test_length = len(test_dataLoader)

# write the class-name -> index mapping to a json file
cla_dict = dict(train_dataset.class_to_idx)
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

# create the network
net = SqueezeNet()
net = net.to(device)
# loss function: cross-entropy
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
# path for saving the trained weights
save_path = './SqueezeNet.pth'
# best accuracy seen so far
best_acc = 0.0
# number of epochs
epoch = 10
for i in range(epoch):
    # training phase
    net.train()
    total_train_loss = 0.0
    train_step = 0
    for data in train_dataLoader:
        img, target = data
        img = img.to(device)
        target = target.to(device)
        output = net(img)
        loss = loss_fn(output, target)
        # optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss = total_train_loss + loss.item()
        train_step = train_step + 1
        print("\r[epoch {} training] processing [{}/{}]".format(i + 1, train_step, train_length), end="")
    print()
    print('training loss over epoch {}: {}'.format(i + 1, total_train_loss))

    # evaluation phase
    net.eval()
    total_test_loss = 0.0
    total_accuracy = 0.0
    with torch.no_grad():
        for data in test_dataLoader:
            img, target = data
            img = img.to(device)
            target = target.to(device)
            output = net(img)
            loss = loss_fn(output, target)

            total_test_loss = total_test_loss + loss.item()
            # count how many predictions match the labels
            total_accuracy = total_accuracy + (output.argmax(1) == target).sum().item()

        if total_accuracy > best_acc:
            best_acc = total_accuracy
            torch.save(net.state_dict(), save_path)
        print('loss over the whole test set: {}'.format(total_test_loss))
        print('accuracy over the whole test set: {}'.format(total_accuracy / len(test_dataset)))
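One optional extension that is not in the original script: decaying the learning rate over time often helps plain SGD converge. Below is a minimal, self-contained sketch of wiring PyTorch's built-in StepLR scheduler to an SGD optimizer; the step size and gamma are illustrative values, not tuned for this task, and the toy linear model only stands in for SqueezeNet.

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR

toy_model = nn.Linear(10, 2)                            # stand-in for SqueezeNet
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.01)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)   # multiply lr by 0.1 every 5 epochs

for epoch_idx in range(10):
    # ... the per-batch training loop from above would run here ...
    optimizer.step()                                    # normally called once per batch
    scheduler.step()                                    # advance the schedule once per epoch
    print(epoch_idx + 1, scheduler.get_last_lr())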


3. Test

The script below loads the saved weights, preprocesses a single image, and prints the predicted class name.

import json

import torch
import torchvision
from PIL import Image

from model import *

img_path = '../test_image/cat1.png'

# PNG files may contain an alpha channel, so force 3-channel RGB
img = Image.open(img_path).convert('RGB')

# use the same preprocessing as the validation transform during training
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((227, 227)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

img = transform(img)

# load the class-name -> index mapping written during training
json_path = './class_indices.json'
with open(json_path, 'r') as json_file:
    class_dict = json.load(json_file)

weight_path = './SqueezeNet.pth'

net = SqueezeNet()
net.load_state_dict(torch.load(weight_path, map_location='cpu'))

# add the batch dimension: (3, 227, 227) -> (1, 3, 227, 227)
img = img.unsqueeze(0)

net.eval()
with torch.no_grad():
    output = net(img)
result = output.argmax(1)
# print the predicted class name
for key, value in class_dict.items():
    if value == int(result):
        print("This picture is predicted to be: {}".format(key))


Summary

Note: this is a non-commercial article. If any of the code above infringes on your rights, please contact the author to have it removed (QQ: 2370154327).
