·
前言
根据之前学习到的pytorch一整个系列的流程,可以自己去写一个深度学习模型,并且进行一系列完整的测试。
本文采用 Fashion-MNIST 数据集进行训练,手写一个卷积网络模型,并在测试集上评估其准确率(原文此处的具体准确率数值缺失)。代码结构如下图所示:
一、数据集
数据集下载:
# Training-time augmentation: random horizontal flips and random rotations.
trans = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),  # flip horizontally with probability 0.5
    torchvision.transforms.RandomRotation(degrees=30)  # rotate by a random angle in [-30, 30] degrees
])
# The test set must not be randomly augmented: otherwise every evaluation sees
# differently distorted images and the reported test metrics are noisy and biased.
test_trans = torchvision.transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=False
)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=test_trans, download=False
)
这里使用的是 torchvision 内置的 FashionMNIST 数据集;对于自定义数据集,需要自行进行相应的预处理,可以参考我之前关于 Dataset 和 DataLoader 的博客。
# Shuffle the training data every epoch. The order of the test data does not
# affect aggregate metrics, so keep it fixed for reproducible evaluation.
train_dataload = data.DataLoader(mnist_train, batch_size=128, shuffle=True)
test_dataload = data.DataLoader(mnist_test, batch_size=128, shuffle=False)
二、模型
模型共 6 层:4 个卷积层和 2 个全连接层。model.py 文件如下:
from torch import nn
import torch
class model(nn.Module):
    """Small CNN classifier: 4 convolution layers + 2 fully connected layers, 10 outputs.

    Takes single-channel images. The ``128 * 8 * 8`` flatten size matches
    28x28 inputs (e.g. Fashion-MNIST): the 1x1 conv with padding=1 grows
    28 -> 30, the first pool gives 15, and the padded second pool gives 8.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 64 channels, spatial size halved by max-pooling.
        stage1 = [
            nn.Conv2d(1, 64, kernel_size=1, padding=1),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        # Stage 2: 64 -> 128 channels; padding=1 on the pool turns 15x15 into 8x8.
        stage2 = [
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        ]
        # Classifier head producing one score per class.
        head = [
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(256, 10),
        ]
        # A single Sequential named ``model1`` keeps the layer indices, and so the
        # state_dict keys (``model1.<idx>.*``), identical to the flat layout.
        self.model1 = nn.Sequential(*stage1, *stage2, *head)

    def forward(self, x):
        """Return raw class scores (no softmax) of shape (batch, 10)."""
        return self.model1(x)
三、模型训练
模型训练代码 train.py 如下:
import torch
from torch.utils import data
import torchvision
import matplotlib.pyplot as plt
from model import model
def _train_one_epoch(net, dataloader, loss_fn, optimizer, device, step, log_interval):
    """Run one training pass over ``dataloader``.

    Returns ``(loss_sum, step)``: ``loss_sum`` is the sum of per-sample losses
    over the epoch and ``step`` is the running count of training samples seen,
    carried across epochs by the caller.
    """
    net.train()
    loss_sum = 0.0
    for batch_idx, (imgs, targets) in enumerate(dataloader, start=1):
        imgs, targets = imgs.to(device), targets.to(device)
        optimizer.zero_grad()
        output = net(imgs)
        # Assumes the default reduction="mean": the loss is already a batch average.
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()
        batch_size = targets.size(0)
        # Re-weight the batch mean by the batch size so that dividing by the
        # dataset size gives a true per-sample mean (the last batch may be smaller).
        loss_sum += loss.item() * batch_size
        step += batch_size
        if batch_idx % log_interval == 0:
            print("第{}次train,loss:{}".format(step, loss.item()))
    return loss_sum, step


def _evaluate(net, dataloader, loss_fn, device):
    """Return ``(loss_sum, correct)`` over ``dataloader`` without tracking gradients.

    ``loss_sum`` is the sum of per-sample losses; ``correct`` is a Python int
    counting the samples whose arg-max prediction equals the target.
    """
    net.eval()
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for imgs, targets in dataloader:
            imgs, targets = imgs.to(device), targets.to(device)
            output = net(imgs)
            loss_sum += loss_fn(output, targets).item() * targets.size(0)
            # .item() moves the count to a Python int so the history lists never
            # hold (possibly CUDA) tensors, which matplotlib cannot plot.
            correct += (output.argmax(1) == targets).sum().item()
    return loss_sum, correct


def _plot_history(epochs, train_loss, test_loss, test_accuracy):
    """Plot per-epoch train loss, test loss and test accuracy on one figure."""
    epoch_axis = range(1, epochs + 1)
    plt.xlabel("epoch")
    plt.ylabel("val")
    plt.plot(epoch_axis, train_loss,
             epoch_axis, test_loss,
             epoch_axis, test_accuracy)
    plt.legend(["train_loss", "test_loss", "test_accuracy"])
    plt.show()


def train_model(train_dataloader, test_dataloader, train_size, test_size, epochs,
                log_interval=8):
    """Train a fresh ``model`` with Adam (lr=0.001), evaluating after every epoch.

    Args:
        train_dataloader: iterable of ``(imgs, targets)`` training batches.
        test_dataloader: iterable of ``(imgs, targets)`` evaluation batches.
        train_size: number of training samples; used to average the train loss.
        test_size: number of test samples; used to average test loss and accuracy.
        epochs: number of passes over the training data.
        log_interval: print the current batch loss every this many training
            batches (8 batches x 128 samples matches the original 1024-sample cadence).

    The model after the last epoch is saved with ``torch.save`` as
    ``fashion_mnistmodel{epochs - 1}.pth``, then the loss/accuracy curves are shown.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = model().to(device)
    optimer = torch.optim.Adam(net.parameters(), lr=0.001)
    loss_fn = torch.nn.CrossEntropyLoss().to(device)
    train_loss = []
    test_loss = []
    test_accuracy = []
    step = 0  # running number of training samples seen, across all epochs
    for epoch in range(epochs):
        print("-------------------第{}轮训练开始------------------".format(epoch + 1))
        train_loss_sum, step = _train_one_epoch(
            net, train_dataloader, loss_fn, optimer, device, step, log_interval
        )
        train_loss.append(train_loss_sum / train_size)

        test_loss_sum, correct = _evaluate(net, test_dataloader, loss_fn, device)
        epoch_test_loss = test_loss_sum / test_size
        epoch_test_accuracy = correct / test_size
        test_loss.append(epoch_test_loss)
        test_accuracy.append(epoch_test_accuracy)
        print("test集,loss:{},accuracy:{}".format(epoch_test_loss, epoch_test_accuracy))

        if epoch + 1 == epochs:  # only the final model is saved
            torch.save(net, "fashion_mnistmodel{}.pth".format(epoch))
    _plot_history(epochs, train_loss, test_loss, test_accuracy)
if __name__ == '__main__':
    # Training-time augmentation: random horizontal flips and random rotations.
    trans = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),  # flip horizontally with probability 0.5
        torchvision.transforms.RandomRotation(degrees=30)  # rotate by a random angle in [-30, 30] degrees
    ])
    # Evaluate on the real, un-augmented test images so metrics are stable.
    test_trans = torchvision.transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=False
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=test_trans, download=False
    )
    train_dataload = data.DataLoader(mnist_train, batch_size=128, shuffle=True)
    # Test order does not affect aggregate metrics; keep it deterministic.
    test_dataload = data.DataLoader(mnist_test, batch_size=128, shuffle=False)
    train_model(train_dataload, test_dataload, len(mnist_train), len(mnist_test), 35)
训练结束后绘制的训练损失、测试损失和测试准确率曲线如下所示:
四、模型测试
模型预测输出的是类别索引,需要将其映射为对应的文本标签,代码如下:
def get_fashion_mnist_labels(labels):
    """Map Fashion-MNIST class indices to their English text labels.

    Each element of ``labels`` is passed through ``int()`` (so tensor elements
    and floats work) and used to index the fixed list of 10 class names.
    Returns a new list of strings in the same order as ``labels``.
    """
    class_names = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    ]
    return [class_names[idx] for idx in map(int, labels)]
可以使用plt来绘制一下fashion_mnist的图像,代码如下:
def show_images(imgs, num_rows, num_cols, titles=None):
    """Plot a list of images on a ``num_rows`` x ``num_cols`` grid and show it.

    Args:
        imgs: iterable of images; ``squeeze(0)`` is applied to each before
            plotting (drops a leading size-1 channel dim for torch tensors —
            assumes tensor inputs).
        num_rows, num_cols: grid shape passed to ``plt.subplot``; the grid must
            have at least as many cells as there are images.
        titles: optional sequence of per-image titles. When ``None`` (the
            default) no titles are drawn instead of failing on ``None[i]``.
    """
    for i, x in enumerate(imgs):
        # Cell i + 1 of the grid (subplot indices are 1-based).
        plt.subplot(num_rows, num_cols, i + 1)
        plt.imshow(x.squeeze(0))
        if titles is not None:
            plt.title(titles[i])
        # Hide axis ticks; they carry no information for image thumbnails.
        plt.xticks([])
        plt.yticks([])
    plt.show()
# Take one shuffled batch of 18 training images and show them on a 3x6 grid,
# each titled with its text label.
x, y = next(iter(data.DataLoader(mnist_train, shuffle=True, batch_size=18)))
label_names = get_fashion_mnist_labels(y)
show_images(x.reshape(18, 28, 28), num_rows=3, num_cols=6, titles=label_names)
总结
本文实现了fashion_mnist的数据集模型搭建和测试,因为只是记录一下pytorch搭建网络的一个大体过程,所以很多地方没有进行解释和注释,有不懂的欢迎大家在评论区或私信对我提问。