pyTorch 图像分类模型训练教程

pyTorch 图像识别教程

代码:
https://github.com/dwSun/classification-tutorial.git

这里以 TinyMind 《汉字书法识别》比赛数据为例,展示使用 Pytorch 进行图像数据分类模型训练的整个流程。

数据地址请参考:
https://www.tinymind.cn/competitions/41#property_23

或到这里下载:
自由练习赛数据下载地址:
训练集:链接: https://pan.baidu.com/s/1UxvN7nVpa0cuY1A-0B8gjg 密码: aujd

测试集: https://pan.baidu.com/s/1tzMYlrNY4XeMadipLCPzTw 密码: 4y9k

数据探索

请参考官方的数据说明

数据处理

竞赛中只有训练集 train 数据有准确的标签,因此这里只使用 train 数据即可,实际应用中,阶段 1、2 的榜单都需要使用。

数据下载

下载数据之后进行解压,得到 train 文件夹,里面有 100 个文件夹,每个文件夹名字即是各个汉字的标签。类似的数据集结构经常在分类任务中见到。可以使用下述命令验证一下每个文件夹下面文件的数量,看数据集是否符合竞赛数据描述:

for l in $(ls); do echo $l $(ls $l|wc -l); done

划分数据集

因为这里只使用了 train 集,因此我们需要对已有数据集进行划分,供模型训练的时候做验证使用,也就是 validation 集的构建。

一般认为,train 用来训练模型,validation 用来对模型进行验证以及超参数( hyper parameter)调整,test 用来做模型的最终验证,我们所谓模型的性能,一般也是指 test 集上模型的性能指标。但是实际项目中,一般只有 train 集,同时没有可靠的 test 集来验证模型,因此一般将 train 集划分出一部分作为 validation,同时将 validation 上的模型性能作为最终模型性能指标。

一般情况下,我们不严格区分 validation 和 test。

这里将每个文件夹下面随机50个文件拿出来做 validation。

export train=train
export val=validation

for d in $(ls $train); do
    mkdir -p $val/$d/
    for f in $(ls train/$d | shuf | head -n 50 ); do
        mv $train/$d/$f $val/$d/;
    done;
done

需要注意,这里的 validation 只间接通过超参数的调整参与了模型训练。因此有一定的数据浪费。

模型训练代码-数据部分

首先导入 pyTorch 看一下版本

import torch
import torchvision as tv

torch.__version__
'1.4.0'

训练模型的时候,模型内部全部都是数字,没有任何可读性,而且这些数字也需要人为给予一些实际的意义,这里将 100 个汉字作为模型输出数字的文字表述。

需要注意的是,因为模型训练往往是一个循环往复的过程,因此一个稳定的文字标签是很有必要的,这里利用相关 python 代码在首次运行的时候生成了一个标签文件,后续检测到这个标签文件,则直接调用即可。

import os

if os.path.exists("labels.txt"):
    with open("labels.txt") as inf:
        classes = [l.strip() for l in inf]
else:
    classes = os.listdir("worddata/train/")
    with open("labels.txt", "w") as of:
        of.write("\r\n".join(classes))

class_idx = {v: k for k, v in enumerate(classes)}
idx_class = dict(enumerate(classes))

pyTorch里面,classes有自己的组织方式,这里我们想要自定义,要做一下转换。

from PIL import Image

pth_classes = classes[:]
pth_classes.sort()
pth_classes_to_idx = {v: k for k, v in enumerate(pth_classes)}


def target_transform(pth_idx):
    return class_idx[pth_classes[pth_idx]]

pyTorch 中提供了直接从目录中读取数据并进行训练的 API 这里使用的API如下。

这里使用了两个数据集,分别代表 train、validation。

需要注意的是,由于 数据中,使用的图像数据集,其数值在(0, 255)之间。同时,pyTorch 用 pillow 来处理图像的加载,其图像的数据layout是(H,W,C),而 pyTorch用来训练的数据需要是(C,H,W)的,因此需要对数据做一些转换。另外,train 数据集做了一定的数据预处理(旋转、明暗度),用于进行数据增广,也做了数据打乱(shuffle),而 validation则不需要做类似的变换。

这里有一些地方需要注意一下,RandomRotation 我们使用了 expand 所以每次输出图像大小都不同,resize 操作要放在后面。pyTorch 中我没找到如何直接用灰度方式读取图像,对于汉字来说,色彩没有任何意义。因此这里用 Grayscale 来转换图像为灰度。ToTensor这个操作会转换数据的 layout,因此要放在最后面。

from multiprocessing import cpu_count

transform_train = tv.transforms.Compose(
    [
        # tv.transforms.RandomRotation((-15, 15), expand=True),
        tv.transforms.RandomRotation((-15, 15)),
        tv.transforms.Resize((128, 128)),
        tv.transforms.ColorJitter(brightness=0.5),
        tv.transforms.Grayscale(),
        tv.transforms.ToTensor(),
    ]
)
transform_val = tv.transforms.Compose(
    [
        tv.transforms.Resize((128, 128)),
        tv.transforms.Grayscale(),
        tv.transforms.ToTensor(),
    ]
)

img_gen_train = tv.datasets.ImageFolder(
    "worddata/train/", transform=transform_train, target_transform=target_transform
)


img_gen_val = tv.datasets.ImageFolder(
    "worddata/validation/", transform=transform_val, target_transform=target_transform
)

batch_size = 32

img_train = torch.utils.data.DataLoader(
    img_gen_train, batch_size=batch_size, shuffle=True, num_workers=cpu_count()
)
img_val = torch.utils.data.DataLoader(
    img_gen_val, batch_size=batch_size, num_workers=cpu_count()
)

到这里,这两个数据集就可以使用了,正式模型训练之前,我们可以先来看看这个数据集是怎么读取数据的,读取出来的数据又是设么样子的。

for imgs, labels in img_train:
    # img_train 只部分满足 generator 的语法,不能用 next 来获取数据
    break
imgs.shape, labels.shape
(torch.Size([32, 1, 128, 128]), torch.Size([32]))

可以看到数据是(batch, channel, height, width, height), 因为这里是灰度图像,因此 channel 是 1。

需要注意,pyTorch、mxnet使用的数据 layout 与Tensorflow 不同,因此数据也有一些不同的处理方式。

把图片打印出来看看,看看数据和标签之间是否匹配

import numpy as np
from matplotlib import pyplot as plt

plt.imshow(imgs[0, 0, :, :], cmap="gray")
classes[labels[0]]
'寒'

在这里插入图片描述

模型训练代码-模型构建

pyTorch 中使用静态图来构建模型,模型构建比较简单。这里演示的是使用 class 的方式构建模型,对于简单模型,还可以直接使用 Sequential 进行构建。

这里的复杂模型也是用 Sequential 的简单模型进行的叠加。

这里构建的是VGG模型,关于VGG模型的更多细节请参考 1409.1556。

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 模型有两个主要部分,特征提取层和分类器

        # 这里是特征提取层
        self.feature = torch.nn.Sequential()
        self.feature.add_module("conv1", self.conv(1, 64))
        self.feature.add_module("conv2", self.conv(64, 64, add_pooling=True))

        self.feature.add_module("conv3", self.conv(64, 128))
        self.feature.add_module("conv4", self.conv(128, 128, add_pooling=True))

        self.feature.add_module("conv5", self.conv(128, 256))
        self.feature.add_module("conv6", self.conv(256, 256))
        self.feature.add_module("conv7", self.conv(256, 256, add_pooling=True))

        self.feature.add_module("conv8", self.conv(256, 512))
        self.feature.add_module("conv9", self.conv(512, 512))
        self.feature.add_module("conv10", self.conv(512, 512, add_pooling=True))

        self.feature.add_module("conv11", self.conv(512, 512))
        self.feature.add_module("conv12", self.conv(512, 512))
        self.feature.add_module("conv13", self.conv(512, 512, add_pooling=True))

        self.feature.add_module("avg", torch.nn.AdaptiveAvgPool2d((1, 1)))
        self.feature.add_module("flatten", torch.nn.Flatten())

        self.feature.add_module("linear1", torch.nn.Linear(512, 4096))
        self.feature.add_module("act_linear_1", torch.nn.ReLU())
        self.feature.add_module("bn_linear_1", torch.nn.BatchNorm1d(4096))

        self.feature.add_module("linear2", torch.nn.Linear(4096, 4096))
        self.feature.add_module("act_linear_2", torch.nn.ReLU())
        self.feature.add_module("bn_linear_2", torch.nn.BatchNorm1d(4096))

        self.feature.add_module("dropout", torch.nn.Dropout())

        # 这个简单的机构是分类器
        self.pred = torch.nn.Linear(4096, 100)

    def conv(self, in_channels, out_channels, add_pooling=False):
        # 模型大量使用重复模块构建,
        # 这里将重复模块提取出来,简化模型构建过程
        model = torch.nn.Sequential()
        model.add_module(
            "conv", torch.nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )
        model.add_module("act_conv", torch.nn.ReLU())
        model.add_module("bn_conv", torch.nn.BatchNorm2d(out_channels))

        if add_pooling:
            model.add_module("pool", torch.nn.MaxPool2d((2, 2)))
        return model

    def forward(self, x):
        # call 用来定义模型各个结构之间的运算关系

        x = self.feature(x)
        return self.pred(x)

可以看到,这里必须指定网络输入输出,对比 TF 和 mxnet 不是很方便。

实例化一个模型看看:

model = MyModel()
model
MyModel(
  (feature): Sequential(
    (conv1): Sequential(
      (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv2): Sequential(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (conv3): Sequential(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv4): Sequential(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (conv5): Sequential(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv6): Sequential(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv7): Sequential(
      (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (conv8): Sequential(
      (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv9): Sequential(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv10): Sequential(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (conv11): Sequential(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv12): Sequential(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv13): Sequential(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (act_conv): ReLU()
      (bn_conv): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    )
    (avg): AdaptiveAvgPool2d(output_size=(1, 1))
    (flatten): Flatten()
    (linear1): Linear(in_features=512, out_features=4096, bias=True)
    (act_linear_1): ReLU()
    (bn_linear_1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (linear2): Linear(in_features=4096, out_features=4096, bias=True)
    (act_linear_2): ReLU()
    (bn_linear_2): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (pred): Linear(in_features=4096, out_features=100, bias=True)
)

模型训练代码-训练相关部分

要训练模型,我们还需要定义损失,优化器等。

loss_object = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())  # 优化器有些参数可以设置
import time  # 模型训练的过程中手动追踪一下模型的训练速度

因为模型整个训练过程一般是一个循环往复的过程,所以经常性的保存重启模型训练中间过程是有必要的。
这里我们一个ckpt保存了两份,便于中断模型的重新训练。

import os

gpu = 1

model.cuda(gpu)
if os.path.exists("model_ckpt.pth"):
    # 检查 checkpoint 是否存在
    # 如果存在,则加载 checkpoint

    net_state, optm_state = torch.load("model_ckpt.pth")

    model.load_state_dict(net_state)
    optimizer.load_state_dict(optm_state)

    # 这里是一个比较生硬的方式,其实还可以观察之前训练的过程,
    # 手动选择准确率最高的某次 checkpoint 进行加载。
    print("model lodaded")
EPOCHS = 40
for epoch in range(EPOCHS):

    train_loss = 0
    train_accuracy = 0
    train_samples = 0

    val_loss = 0
    val_accuracy = 0
    val_samples = 0

    start = time.time()
    for imgs, labels in img_train:
        imgs = imgs.cuda(gpu)
        labels = labels.cuda(gpu)

        preds = model(imgs)

        loss = loss_object(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_samples += imgs.shape[0]

        train_loss += loss.item()
        train_accuracy += (preds.argmax(dim=1) == labels).sum().item()

    train_samples_per_second = train_samples / (time.time() - start)

    start = time.time()
    for imgs, labels in img_val:
        imgs = imgs.cuda(gpu)
        labels = labels.cuda(gpu)
        model.eval()
        preds = model(imgs)
        model.train()
        val_loss += loss.item()
        val_accuracy += (preds.argmax(dim=1) == labels).sum().item()

        val_samples += imgs.shape[0]

    val_samples_per_second = val_samples / (time.time() - start)

    print(
        "Epoch {} Loss {}, Acc {}, Val Loss {}, Val Acc {}".format(
            epoch,
            train_loss * batch_size / train_samples,
            train_accuracy * 100 / train_samples,
            val_loss * batch_size / val_samples,
            val_accuracy * 100 / val_samples,
        )
    )
    print(
        "Speed train {}imgs/s val {}imgs/s".format(
            train_samples_per_second, val_samples_per_second
        )
    )

    torch.save((model.state_dict(), optimizer.state_dict()), "model_ckpt.pth")
    torch.save(
        (model.state_dict(), optimizer.state_dict()),
        "model_ckpt-{:04d}.pth".format(epoch),
    )

    # 每个 epoch 保存一下模型,需要注意每次
    # 保存要用一个不同的名字,不然会导致覆盖,
    # 同时还要关注一下磁盘空间占用,防止太多
    # chekcpoint 占满磁盘空间导致错误。
Epoch 0 Loss 6.379009783063616, Acc 0.96, Val Loss 7.499249600219726, Val Acc 1.06
Speed train 234.8145027420376imgs/s val 706.7103467600252imgs/s
Epoch 1 Loss 6.250997774396624, Acc 1.1485714285714286, Val Loss 8.049263702392578, Val Acc 1.12
Speed train 230.86756209036096imgs/s val 684.4524146021467imgs/s
Epoch 2 Loss 6.0144778276715956, Acc 1.16, Val Loss 5.387196655273438, Val Acc 1.22
Speed train 226.12469399959375imgs/s val 681.3837266883015imgs/s
Epoch 3 Loss 5.589597338867187, Acc 1.0057142857142858, Val Loss 5.907029174804688, Val Acc 1.1
Speed train 225.0556244090409imgs/s val 680.1839950846079imgs/s
Epoch 4 Loss 5.402270581054688, Acc 1.1714285714285715, Val Loss 5.1295126434326175, Val Acc 1.4
Speed train 224.72032368819234imgs/s val 682.6134263746877imgs/s
Epoch 5 Loss 5.175169513811384, Acc 1.062857142857143, Val Loss 5.386006506347656, Val Acc 0.78
Speed train 224.69030177626445imgs/s val 680.2019310561357imgs/s
Epoch 6 Loss 4.945824640328544, Acc 1.2485714285714287, Val Loss 5.106301385498047, Val Acc 1.7
Speed train 224.41665937455835imgs/s val 680.7435881896529imgs/s
Epoch 7 Loss 4.78519496547154, Acc 1.2714285714285714, Val Loss 5.001887857055664, Val Acc 1.46
Speed train 224.3077478615822imgs/s val 679.1943518703032imgs/s
Epoch 8 Loss 4.681244001116071, Acc 1.5542857142857143, Val Loss 4.678347979736328, Val Acc 1.76
Speed train 224.43681583771664imgs/s val 680.9421894383488imgs/s
Epoch 9 Loss 4.594511456734794, Acc 1.977142857142857, Val Loss 4.734268209838867, Val Acc 3.48
Speed train 224.43505593143354imgs/s val 679.8263336556521imgs/s
Epoch 10 Loss 4.564881538609096, Acc 2.2142857142857144, Val Loss 4.5007019622802735, Val Acc 3.5
Speed train 224.35177417457732imgs/s val 681.4680194348261imgs/s
Epoch 11 Loss 4.359355766732352, Acc 3.797142857142857, Val Loss 4.303946963500977, Val Acc 5.1
Speed train 224.22713261480806imgs/s val 678.4494364538398imgs/s
Epoch 12 Loss 4.05738628692627, Acc 6.651428571428571, Val Loss 3.5746582946777345, Val Acc 10.42
Speed train 224.18908188445624imgs/s val 679.7794406874162imgs/s
Epoch 13 Loss 3.7937849918910436, Acc 10.214285714285714, Val Loss 3.7133444229125976, Val Acc 7.04
Speed train 224.10400818300923imgs/s val 679.113765396056imgs/s
Epoch 14 Loss 3.2694146046229773, Acc 19.425714285714285, Val Loss 3.6981288192749022, Val Acc 30.78
Speed train 224.13154184979035imgs/s val 680.3901503258012imgs/s
Epoch 15 Loss 2.7287981418064664, Acc 31.591428571428573, Val Loss 2.7384859634399414, Val Acc 43.0
Speed train 224.08093870796063imgs/s val 680.9818793692156imgs/s
Epoch 16 Loss 2.4017765145438057, Acc 40.222857142857144, Val Loss 2.373513427734375, Val Acc 55.06
Speed train 224.0240886741711imgs/s val 680.9301175413847imgs/s
Epoch 17 Loss 1.9575243755885532, Acc 50.81428571428572, Val Loss 1.8042015686035155, Val Acc 60.3
Speed train 224.023773128244imgs/s val 679.3334877190623imgs/s
Epoch 18 Loss 1.8670056664603096, Acc 52.754285714285714, Val Loss 1.7974752388000488, Val Acc 59.12
Speed train 224.04512456698183imgs/s val 677.1390343683371imgs/s
Epoch 19 Loss 1.6107693487439836, Acc 58.48571428571429, Val Loss 2.0469212783813475, Val Acc 66.72
Speed train 223.87349316639512imgs/s val 678.0275851549168imgs/s
Epoch 20 Loss 1.7171708895547049, Acc 56.642857142857146, Val Loss 1.8279149505615235, Val Acc 65.96
Speed train 223.90746153613367imgs/s val 677.346826146328imgs/s
Epoch 21 Loss 1.2915482904706683, Acc 65.76285714285714, Val Loss 1.3189221771240234, Val Acc 68.62
Speed train 223.9696856948392imgs/s val 680.7548137514192imgs/s
Epoch 22 Loss 1.1914144684110368, Acc 68.43714285714286, Val Loss 1.0220409889221191, Val Acc 67.86
Speed train 223.9880397987904imgs/s val 679.2464221336535imgs/s
Epoch 23 Loss 1.0181893185751778, Acc 72.81428571428572, Val Loss 0.6417443874359131, Val Acc 75.94
Speed train 223.8553653733427imgs/s val 678.307898091706imgs/s
Epoch 24 Loss 0.9370736787523543, Acc 75.12, Val Loss 1.0853789276123047, Val Acc 76.04
Speed train 223.88541791918547imgs/s val 681.0767780628761imgs/s
Epoch 25 Loss 0.858675898034232, Acc 76.78, Val Loss 0.6966076656341553, Val Acc 76.14
Speed train 223.89225800022191imgs/s val 677.8778524201055imgs/s
Epoch 26 Loss 0.911681534739903, Acc 75.55142857142857, Val Loss 1.2748067726135255, Val Acc 75.28
Speed train 223.8182146384268imgs/s val 678.3320102990313imgs/s
Epoch 27 Loss 0.7263422344616481, Acc 80.27428571428571, Val Loss 0.8662283229827881, Val Acc 76.24
Speed train 223.84745302957168imgs/s val 677.5476529925469imgs/s
Epoch 28 Loss 0.7096801671164377, Acc 80.64, Val Loss 0.4879303056716919, Val Acc 78.4
Speed train 223.82907588526868imgs/s val 679.641487715212imgs/s
Epoch 29 Loss 0.8400143226759774, Acc 77.24857142857142, Val Loss 0.48099885005950926, Val Acc 76.28
Speed train 223.79857240610696imgs/s val 677.2004991093537imgs/s
Epoch 30 Loss 0.6340663018226623, Acc 82.75428571428571, Val Loss 0.3814028434753418, Val Acc 77.88
Speed train 223.69922632756038imgs/s val 677.8182801677248imgs/s
Epoch 31 Loss 0.6143715186391558, Acc 83.12571428571428, Val Loss 1.5937435668945312, Val Acc 55.9
Speed train 223.73651624919145imgs/s val 677.9854113416313imgs/s
Epoch 32 Loss 0.6921936396871294, Acc 80.92285714285714, Val Loss 0.6802982173919677, Val Acc 74.06
Speed train 223.6073489617921imgs/s val 675.2316887830606imgs/s
Epoch 33 Loss 0.6144891169275556, Acc 83.18, Val Loss 0.46930033054351805, Val Acc 76.34
Speed train 223.57683437615302imgs/s val 677.3064650659536imgs/s
Epoch 34 Loss 0.568616727393014, Acc 84.28, Val Loss 0.4940680891036987, Val Acc 79.08
Speed train 223.53035806536994imgs/s val 675.1831667819793imgs/s
Epoch 35 Loss 0.5646722382409232, Acc 84.30571428571429, Val Loss 0.4494327730178833, Val Acc 79.24
Speed train 223.53708898876783imgs/s val 676.0269683297067imgs/s
Epoch 36 Loss 0.977967550604684, Acc 74.1, Val Loss 0.8460039363861084, Val Acc 74.68
Speed train 223.70501802195884imgs/s val 675.9956110209276imgs/s
Epoch 37 Loss 0.7239568670545306, Acc 80.12857142857143, Val Loss 1.048443465423584, Val Acc 80.18
Speed train 223.7010776600049imgs/s val 677.2457684330305imgs/s
Epoch 38 Loss 0.5576571273531232, Acc 84.37714285714286, Val Loss 0.7641737712860107, Val Acc 78.96
Speed train 223.92617365625426imgs/s val 679.0357044899268imgs/s
Epoch 39 Loss 0.4953382140840803, Acc 86.30285714285715, Val Loss 0.9396348545074463, Val Acc 80.98
Speed train 223.80456065939208imgs/s val 677.9383775081458imgs/s

一些技巧

因为这里定义的模型比较大,同时训练的数据也比较多,每个 epoch 用时较长,因此,如果代码有 bug 的话,经过一次 epoch 再去 debug 效率比较低。

这种情况下,我们使用的数据生成过程又是自己手动指定数据数量的,因此可以尝试缩减模型规模,定义小一些的数据集来快速验证代码。在这个例子里,我们可以通过注释模型中的卷积和全连接层的代码来缩减模型尺寸,通过修改训练循环里面的数据数量来缩减数据数量。

训练的速度很慢

类似的网络结构和参数,TF里面 20epochs能达到90%的准确率,这里要40epochs才能到86%,应该是哪里有什么问题,我再看看怎么解决。


下面的代码属于另外一个文件,因此部分代码跟上面是重复的。

模型的使用代码

模型训练好了之后要实际应用。对于模型部署有很多成熟的方案,如 Nvidia 的 TensorRT, Intel 的 OpenVINO 等,都可以做模型的高效部署,这里限于篇幅不涉及相关内容。

在模型训练过程中,也可以使用使用框架提供的 API 做模型的简单部署以方便开发。

import torch
import torchvision as tv
import os
torch.__version__
'1.4.0'

首先要加载模型的标签用于展示,因为我们训练的时候就已经生成了标签文件,这里直接用写好的代码就可以。

if os.path.exists("labels.txt"):
    with open("labels.txt") as inf:
        classes = [l.strip() for l in inf]
else:
    classes = os.listdir("worddata/train/")
    with open("labels.txt", "w") as of:
        of.write("\r\n".join(classes))

接着是模型的定义,这里直接将训练中使用的模型代码拿来即可。

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 模型有两个主要部分,特征提取层和分类器

        # 这里是特征提取层
        self.feature = torch.nn.Sequential()
        self.feature.add_module("conv1", self.conv(1, 64))
        self.feature.add_module("conv2", self.conv(64, 64, add_pooling=True))

        self.feature.add_module("conv3", self.conv(64, 128))
        self.feature.add_module("conv4", self.conv(128, 128, add_pooling=True))

        self.feature.add_module("conv5", self.conv(128, 256))
        self.feature.add_module("conv6", self.conv(256, 256))
        self.feature.add_module("conv7", self.conv(256, 256, add_pooling=True))

        self.feature.add_module("conv8", self.conv(256, 512))
        self.feature.add_module("conv9", self.conv(512, 512))
        self.feature.add_module("conv10", self.conv(512, 512, add_pooling=True))

        self.feature.add_module("conv11", self.conv(512, 512))
        self.feature.add_module("conv12", self.conv(512, 512))
        self.feature.add_module("conv13", self.conv(512, 512, add_pooling=True))

        self.feature.add_module("avg", torch.nn.AdaptiveAvgPool2d((1, 1)))
        self.feature.add_module("flatten", torch.nn.Flatten())

        self.feature.add_module("linear1", torch.nn.Linear(512, 4096))
        self.feature.add_module("act_linear_1", torch.nn.ReLU())
        self.feature.add_module("bn_linear_1", torch.nn.BatchNorm1d(4096))

        self.feature.add_module("linear2", torch.nn.Linear(4096, 4096))
        self.feature.add_module("act_linear_2", torch.nn.ReLU())
        self.feature.add_module("bn_linear_2", torch.nn.BatchNorm1d(4096))

        self.feature.add_module("dropout", torch.nn.Dropout())

        # 这个简单的机构是分类器
        self.pred = torch.nn.Linear(4096, 100)

    def conv(self, in_channels, out_channels, add_pooling=False):
        # 模型大量使用重复模块构建,
        # 这里将重复模块提取出来,简化模型构建过程
        model = torch.nn.Sequential()
        model.add_module(
            "conv", torch.nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )
        model.add_module("act_conv", torch.nn.ReLU())
        model.add_module("bn_conv", torch.nn.BatchNorm2d(out_channels))

        if add_pooling:
            model.add_module("pool", torch.nn.MaxPool2d((2, 2)))
        return model

    def forward(self, x):
        # call 用来定义模型各个结构之间的运算关系

        x = self.feature(x)
        return self.pred(x)

有了模型的定义之后,我们可以加载训练好的模型,跟模型训练的时候类似,我们可以直接加载模型训练中的 checkpoint。

import os
model = MyModel().cuda()

if os.path.exists('ckpts_pth/model_ckpt.pth'):
    # 检查 checkpoint 是否存在
    # 如果存在,则加载 checkpoint

    net_state, optm_state = torch.load('ckpts_pth/model_ckpt.pth')

    model.load_state_dict(net_state)

    # 这里是一个比较生硬的方式,其实还可以观察之前训练的过程,
    # 手动选择准确率最高的某次 checkpoint 进行加载。
    print("model lodaded")
model lodaded

对于数据,我们需要直接处理图片,因此这里导入一些图片处理的库和数据处理的库

from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

直接打开某个图片

img = Image.open(
    "worddata/validation/从/116e891836204e4e67659d2b73a7e4780a37c301.jpg")

plt.imshow(img, cmap="gray")

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hNIynK8t-1587993165325)(output_11_1.png)]

需要注意,模型在训练的时候,我们对数据进行了一些处理,在模型使用的时候,我们要对数据做一样的处理,如果不做的话,模型最终的结果会出现不可预料的问题。

img = img.resize((128, 128))
img = np.array(img) / 255
img.shape
(128, 128)

模型对图片数据的运算其实很简单,一行代码就可以。

这里需要注意模型处理的数据是 4 维的,而上面的图片数据实际是 2 维的,因此要对数据进行维度的扩充。同时模型的输出是 2 维的,带 batch ,所以需要压缩一下维度。

model.eval()
pred = np.squeeze(
    model(torch.Tensor(img[np.newaxis, np.newaxis, :, :]).cuda()))
pred = torch.nn.functional.softmax(pred)
pred.argsort()[-5:]

print([pred[idx].item() for idx in pred.argsort()[-5:]])
print([classes[idx] for idx in pred.argsort()[-5:]])
[7.042408323165716e-11, 1.551086897810805e-10, 2.2588204917628474e-10, 4.854148372146483e-08, 1.0]
['遂', '夜', '御', '作', '从']


/home/dl/miniconda3/lib/python3.7/site-packages/ipykernel_launcher.py:4: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  after removing the cwd from sys.path.

这里只给出了 top5 的结果,可以看到,准确率还是不错的。

相关推荐
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页