前言
SqueezeNet 代码解读(PyTorch)
在这里我们分析 SqueezeNet 的 PyTorch 实现,以加深对网络架构的理解。源代码可见:
https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py
一、Model
1.1 Fire module的实现
class Fire(nn.Module):
    """SqueezeNet "Fire" module: a 1x1 squeeze conv feeding two parallel expand convs.

    Args:
        input: number of channels of the incoming feature map. (The name
            shadows the ``input`` builtin; it is kept so keyword callers keep
            working.)
        squeeze_planes: output channels of the 1x1 squeeze convolution.
        expand1x1_planes: output channels of the 1x1 expand branch.
        expand3x3_planes: output channels of the 3x3 expand branch.

    The result has ``expand1x1_planes + expand3x3_planes`` channels. Every
    conv uses stride 1 and the 3x3 branch uses padding 1, so the spatial size
    is unchanged.
    """

    def __init__(self, input, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super().__init__()
        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically; the attribute names are
        # the state_dict keys and must not change.
        # Squeeze stage: 1x1 conv that reduces the channel count.
        self.squeeze = nn.Conv2d(input, squeeze_planes, kernel_size=1, stride=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        # Expand stage, branch 1: 1x1 conv.
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1, stride=1)
        # Expand stage, branch 2: 3x3 conv, padding 1 keeps height/width.
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, stride=1, padding=1)
        # One stateless ReLU shared by both expand branches.
        self.expand_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.squeeze_activation(self.squeeze(x))
        branch_1x1 = self.expand_activation(self.expand1x1(squeezed))
        branch_3x3 = self.expand_activation(self.expand3x3(squeezed))
        # Join the branches along dim 1 (the channel axis for batched N,C,H,W input).
        return torch.cat((branch_1x1, branch_3x3), dim=1)
1.2 SqueezeNet 的实现
class SqueezeNet(nn.Module):
    """SqueezeNet classifier built from ``Fire`` modules.

    Args:
        version: ``1.0`` (default) for the original layout, or ``1.1`` for the
            lighter torchvision variant (3x3 stem, earlier pooling). The
            strings ``"1.0"``/``"1_0"``/``"1.1"``/``"1_1"`` are also accepted.
        num_classes: length of the score vector returned by ``forward``
            (10 by default, matching CIFAR-10).

    Raises:
        ValueError: if ``version`` is not one of the supported variants.

    ``forward`` returns raw (un-softmaxed) class scores of shape
    ``(batch, num_classes)``.
    """

    def __init__(self, version=1.0, num_classes=10):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        variant = str(version).replace('_', '.')
        if variant == '1.0':
            self.feature = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                # The reference design applies a ReLU to conv1's output; it
                # was missing here. NOTE: this shifts the indices of later
                # `feature` entries, so checkpoints saved from the old
                # definition no longer load (and would behave differently
                # anyway) - retrain.
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(512, 64, 256, 256),
            )
        elif variant == '1.1':
            self.feature = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            raise ValueError(
                "Unsupported SqueezeNet version {!r}: expected 1.0 or 1.1".format(version))

        # Final 1x1 conv maps the 512 feature channels to one channel per class.
        conv10 = nn.Conv2d(512, num_classes, kernel_size=1, stride=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            conv10,
            # Global average pooling. For a 227x227 input the 1.0 feature map
            # is 13x13 here, so this equals the former AvgPool2d(13), but it
            # also works for other input resolutions.
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Weight initialisation as in the torchvision reference: small normal
        # weights for the classifier conv, Kaiming-uniform elsewhere, and zero
        # biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is conv10:
                    nn.init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    nn.init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.feature(x)
        x = self.classifier(x)
        # (N, num_classes, 1, 1) -> (N, num_classes)
        return x.flatten(1)
二、Train
import json
import torch
from model import *
import torchvision
# 配置设备
from torch.utils.data import DataLoader
# Training hyper-parameters and output locations.
BATCH_SIZE = 16
LEARNING_RATE = 0.01
NUM_EPOCHS = 10
SAVE_PATH = './SqueezeNet.pth'
CLASS_INDEX_PATH = 'class_indices.json'


def build_transforms():
    """Return the preprocessing pipelines, keyed by "train" and "val"."""
    return {
        "train": torchvision.transforms.Compose([
            # Random crop of random size/aspect ratio, rescaled to 227x227.
            torchvision.transforms.RandomResizedCrop(227),
            # Horizontal flip with probability 0.5.
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            # Convert the image to a tensor.
            torchvision.transforms.ToTensor(),
            # Per channel: (x - 0.5) / 0.5.
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        "val": torchvision.transforms.Compose([
            torchvision.transforms.Resize((227, 227)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    }


def train_one_epoch(net, loader, loss_fn, optimizer, device, epoch_index):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss.

    Switches ``net`` to training mode (Dropout active) and prints a progress
    line per batch. The per-sample average assumes ``loss_fn`` averages over
    the batch (the default ``reduction='mean'``).
    """
    net.train()
    num_batches = len(loader)
    running_loss = 0.0
    seen = 0
    for step, (img, target) in enumerate(loader, start=1):
        img = img.to(device)
        target = target.to(device)
        loss = loss_fn(net(img), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Use .item(): accumulating the loss tensor itself would keep every
        # iteration's autograd graph alive and steadily grow memory.
        running_loss += loss.item() * target.size(0)
        seen += target.size(0)
        print("\r[第{}轮训练] processing [{}/{}]".format(epoch_index + 1, step, num_batches), end="")
    print()
    return running_loss / seen if seen else 0.0


def evaluate(net, loader, loss_fn, device):
    """Return ``(mean per-sample loss, accuracy)`` of ``net`` over ``loader``.

    Switches ``net`` to eval mode so Dropout is disabled; ``train_one_epoch``
    switches it back. Accuracy is correct predictions / number of samples.
    """
    net.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for img, target in loader:
            img = img.to(device)
            target = target.to(device)
            output = net(img)
            total_loss += loss_fn(output, target).item() * target.size(0)
            correct += (output.argmax(1) == target).sum().item()
            total += target.size(0)
    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total


def main():
    """Train SqueezeNet on CIFAR-10 and save the weights with the best test accuracy."""
    # Use the GPU when available, otherwise the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    data_transform = build_transforms()
    train_dataset = torchvision.datasets.CIFAR10(
        '../dataset', train=True, transform=data_transform["train"], download=True)
    test_dataset = torchvision.datasets.CIFAR10(
        '../dataset', train=False, transform=data_transform["val"], download=True)
    # Reshuffle the training data each epoch; evaluation order does not matter.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Save the class-name -> index mapping for the inference script.
    with open(CLASS_INDEX_PATH, 'w', encoding='utf-8') as json_file:
        json.dump(train_dataset.class_to_idx, json_file, indent=4)

    net = SqueezeNet().to(device)
    # Cross-entropy loss on the raw class scores.
    loss_fn = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)

    best_acc = 0.0
    for epoch_index in range(NUM_EPOCHS):
        train_loss = train_one_epoch(net, train_loader, loss_fn, optimizer, device, epoch_index)
        test_loss, test_acc = evaluate(net, test_loader, loss_fn, device)
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(net.state_dict(), SAVE_PATH)
        print('训练集上的平均Loss:{}'.format(train_loss))
        print('整体测试集上的Loss:{}'.format(test_loss))
        print('整体测试集上的正确率:{}'.format(test_acc))


if __name__ == "__main__":
    main()
三、Test
import json
import torch
import torchvision
from PIL import Image
from torch import nn
from model import *
# Inputs for the single-image prediction below.
IMG_PATH = '../test_image/cat1.png'
JSON_PATH = './class_indices.json'
WEIGHT_PATH = './SqueezeNet.pth'


def load_image(img_path):
    """Load an image and preprocess it like the training "val" pipeline.

    Returns a float tensor of shape (1, 3, 227, 227).
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((227, 227)),
        torchvision.transforms.ToTensor(),
        # Must match the Normalize step used on the training/validation data.
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    with Image.open(img_path) as img:
        # PNGs can be RGBA or palette-based; the network expects 3 channels.
        rgb = img.convert('RGB')
    # Add the batch dimension: (3, H, W) -> (1, 3, H, W).
    return transform(rgb).unsqueeze(0)


def class_name_for_index(class_dict, index):
    """Return the name mapped to ``index`` in a name -> index dict, or None."""
    for name, idx in class_dict.items():
        if idx == index:
            return name
    return None


def main():
    """Predict the class of IMG_PATH with the trained SqueezeNet weights."""
    img = load_image(IMG_PATH)
    with open(JSON_PATH, 'r', encoding='utf-8') as json_file:
        class_dict = json.load(json_file)
    net = SqueezeNet()
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
    net.load_state_dict(torch.load(WEIGHT_PATH, map_location='cpu'))
    net.eval()
    with torch.no_grad():
        output = net(img)
    result = int(output.argmax(1))
    # Show the prediction.
    label = class_name_for_index(class_dict, result)
    if label is not None:
        print("This picture is of an {}".format(label))


if __name__ == "__main__":
    main()
总结
注:该文章为非盈利文章,以上代码如有侵权请联系删除,小编的qq:2370154327