SqueezeNet代码解读pytorch


前言

SqueezeNet代码解读pytorch
在这里我们分析SqeezeNet的PyTorch实现,以加深对网络架构的理解。源代码可见:
https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py


一、Model

1.1 Fire module的实现

class Fire(nn.Module):
    def __init__(self,input,squeeze_planes,expand1x1_planes,expand3x3_planes):
        super(Fire, self).__init__()
        # squeeze层 kernel=1x1
        self.squeeze=nn.Conv2d(input,squeeze_planes,kernel_size=1,stride=1)
        self.squeeze_activation=nn.ReLU(inplace=True)
        # expand 层 ,1x1,3x3两部分
        # 1x1,kernel=1x1
        self.expand1x1=nn.Conv2d(squeeze_planes,expand1x1_planes,kernel_size=1,stride=1)
        # 3x3,kernel=3x3,padding=1
        self.expand3x3=nn.Conv2d(squeeze_planes,expand3x3_planes,kernel_size=3,stride=1,padding=1)
        self.expand_activation=nn.ReLU(inplace=True)

    def forward(self,x):
        x=self.squeeze_activation(self.squeeze(x))
        # 拼接expand的两部分
        x=torch.cat([
            self.expand_activation(self.expand1x1(x)),
            self.expand_activation(self.expand3x3(x))],1)
        return x

1.1 Squeeze 的实现

class SqueezeNet(nn.Module):
    def __init__(self,version=1.0,num_classes=10):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        self.feature=nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=7, stride=2),
            nn.MaxPool2d(kernel_size=3,stride=2),
            Fire(96,16,64,64),
            Fire(128,16,64,64),
            Fire(128,32,128,128),
            nn.MaxPool2d(kernel_size=3,stride=2),
            Fire(256,32,128,128),
            Fire(256,48,192,192),
            Fire(384,48,192,192),
            Fire(384,64,256,256),
            nn.MaxPool2d(kernel_size=3,stride=2),
            Fire(512,64,256,256),
        )

        conv10=nn.Conv2d(512,num_classes, kernel_size=1, stride=1)
        self.classifier=nn.Sequential(
            nn.Dropout(p=0.5),
            conv10,
            nn.AvgPool2d(kernel_size=13)
        )

        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         if m is conv10:
        #             init.normal(m.weight.data, mean=0.0, std=0.01)
        #         else:
        #             init.kaiming_uniform(m.weight.data)
        #         if m.bias is not None:
        #             m.bias.data.zero_()
    def forward(self,x):
        x=self.feature(x)
        x=self.classifier(x)
        return x.view(x.size(0), self.num_classes)

二、Train

import json

import torch
from model import *
import torchvision
# 配置设备
from torch.utils.data import DataLoader

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 数据处理
data_transform={
    "train":torchvision.transforms.Compose([
        # 随机裁剪,再缩放成 227×227
        torchvision.transforms.RandomResizedCrop(227),
        # 水平方向随机翻转,概率为 0.5, 即一半的概率翻转, 一半的概率不翻转
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        # 将数据转换为Tensor类型
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ]),
    "val":torchvision.transforms.Compose([
        torchvision.transforms.Resize((227,227)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ])
}
# 准备数据集
train_dataset=torchvision.datasets.CIFAR10('../dataset',train=True,transform=data_transform["train"],download=True)
test_dataset=torchvision.datasets.CIFAR10('../dataset',train=False,transform=data_transform["val"],download=True)
# 加载数据集
train_dataLoader=DataLoader(train_dataset,batch_size=16)
test_dataLoader=DataLoader(test_dataset,batch_size=16)
# 训练集的长度
train_length=len(train_dataLoader)
# 测试集的长度
test_length=len(test_dataLoader)

cifar10_classes=train_dataset.class_to_idx
cla_dic=dict((key,val) for key,val in cifar10_classes.items())
# 将 cla_dict 写入 json 文件中
json_str=json.dumps(cla_dic,indent=4)
with open('class_indices.json','w') as json_file:
    json_file.write(json_str)
# 创建网络
net=SqueezeNet()
net=net.to(device)
# 损失函数,交叉验证集
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.to(device)
# 优化器
learning_rate=0.01
optimizer=torch.optim.SGD(net.parameters(),lr=learning_rate)
# 训练参数保存路径
save_path='./SqueezeNet.pth'
# 训练过程中最高准确率
best_acc=0.0
# 训练总损失
total_train_loss=0.0
# 训练次数
epoch=10
for i in range(epoch):
    train_step=0
    for data in train_dataLoader:
        img,target=data
        img=img.to(device)
        target=target.to(device)
        output=net(img)
        # output.view(-1, 1, 32)
        # print(target.shape)
        # print(output.shape)

        loss=loss_fn(output,target)
        # 模型优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss=total_train_loss+loss
        train_step=train_step+1
        print("\r[第{}轮训练] processing [{}/{}]".format(i + 1, train_step, train_length), end="")
    print()
    # 测试步骤
    total_test_loss=0.0
    # 计算准确率
    total_accuracy=0.0
    with torch.no_grad():
        for data in test_dataLoader:
            img,target=data
            img=img.to(device)
            target=target.to(device)
            output=net(img)
            accuracy=(output.argmax(1)==target).sum()

            total_test_loss=total_test_loss+loss
            total_accuracy=total_accuracy+accuracy

        if total_accuracy>best_acc:
            best_acc=total_accuracy
            torch.save(net.state_dict(),save_path)
        print(total_accuracy)
        print('整体测试集上的Loss:{}'.format(total_test_loss))
        print('整体测试集上的正确率:{}'.format(total_accuracy / test_length))


三、Test

import json

import torch
import torchvision
from PIL import Image
from torch import nn
from model import *

img_path='../test_image/cat1.png'

img=Image.open(img_path)

transform=torchvision.transforms.Compose([
    torchvision.transforms.Resize((227,227)),
    torchvision.transforms.ToTensor()
])

img=transform(img)

json_path='./class_indices.json'
with open(json_path,'r') as json_filr:
    class_dict=json.load(json_filr)

weight_path='./SqueezeNet.pth'

net=SqueezeNet()

net.load_state_dict(torch.load(weight_path))

img=torch.reshape(img,[1,3,227,227])

net.eval()
with torch.no_grad():
    output=net(img)
result=output.argmax(1)
# 显示预测结果
for key,value in class_dict.items():
    # print(type(key),type(value))
    if value == int(result):
        print("This picture is of an {}".format(key))


总结

注:该文章为非盈利文章,以上代码如有侵权请联系删除,小编的qq:2370154327

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值