五、肺癌检测-数据集训练 training.py model.py

上一篇文章中已经通过将dsets.py实现将数据集封装加载,之后就可以通过建立了模型并编写training脚本实现模型的训练了。这一篇文章主要是对《pytorch深度学习实战》第11章内容做的笔记。

一、目标

1、建立简单的卷积神经网络

2、编写训练函数

3、编写训练日志(训练和验证过程的loss,accuracy等)数据结构

4、使用tensorboard可视化训练信息。

二、要点说明

1. 对函数使用通用的系统进程级别的调用

原书代码的【code/p2_run_everything.ipynb】的cell2中,定义了一个通用的系统进程方式的调用方法。通过这种方法可以调用所有脚本中的函数。但个人认为还是挺麻烦的,一点都不人性化。建议不要把精力花在这部分代码上,知道代码是在干嘛就行。

def run(app, *argv):
    argv = list(argv)
    argv.insert(0, '--num-workers=4')  # <1> 使用4个核
    log.info("Running: {}({!r}).main()".format(app, argv))
    
    app_cls = importstr(*app.rsplit('.', 1))  # <2>    # 动态加载库
    app_cls(argv).main()    # 调用app类的main函数
    
    log.info("Finished: {}.{!r}).main()".format(app, argv))

使用示例:从p2ch11文件夹的training.py文件中importLunaTrainingApp类并调用其main函数,函数的输入参数是epochs=1。

run('p2ch11.training.LunaTrainingApp', '--epochs=1')

其中:

1.1 importstr函数

函数是为了实现动态调用各个库和库函数。类似于from 【pkg_name】 import 【func_name】的作用。通过importstr可以实现动态加载函数,而不用调用前用import声明。

1.2 rsplit函数

 函数用法:list = str.rsplit(sep, maxsplit)。可参考下面的文章。简单而言就是对字符【str】按照【sep】分隔符进行拆分,从字符右侧开始拆分,一共拆分【maxsplit】次。返回的是拆分结果是一个list。

Python实用语法之rsplit_明 总 有的博客-CSDN博客_python rsplit

1.3 argparse库

在原书代码的【prepcache.py】文件中,使用了argparse库。argparse库是用来解决使用命令行执行函数时,让命令行能够解析我们输入的参数名称和参数值的问题。定义了参数解释器后,我们在命令行执行函数时,就可以像使用conda命令一样,用类似【conda --user xxx】一样的方式来执行函数了。

argparse库的具体用法可以参考以下文章:

argparse.ArgumentParser()的用法_无尽的沉默的博客-CSDN博客_argparse.argumentparser

简单用法如下:

import argparse

parser = argparse.ArgumentParser()      # 创建一个参数解释器
parser.add_argument("--arg1", type=int, help="一个整数", default=1)  # 通过 --argName方式声明参数,为int类型
parser.add_argument("--arg2", type=int, help="一个整数", default=2)  # 通过 --argName方式声明参数,为int类型

args = parser.parse_args()      # 解析参数

print("arg1 = {0}".format(args.arg1))
print("arg2 = {0}".format(args.arg2))

 使用命令行运行结果如下:

(pytorch) E:\CT\code>python test2.py --arg1 1 --arg2 2
arg1 = 1
arg2 = 2

1.4  @classmethod修饰器

在原书代码的【prepcache.py】文件中,使用了@classmethod修饰器,这样就可以不实例化对象直接调用类内的函数。

2. 模型建立

书中在11章用的是简单的卷积堆叠+线性层的神经网络结果,没任何特别之处。其中线性层由于只是简单2分类(结节是否为肿瘤),所以只用了一个线性层。卷积和池化用的是3维的卷积和池化。

2.1 多GPU设置

多GPU训练可通过nn.DataParallel(model)或DistributedParallel函数实现,前者较为简单,一般用在单机多卡场景,后者配置较为复杂,一般用在多台计算机的多卡场景。

2.2 优化器

一般开始训练时可以先尝试使用带动量的SGD,lr=0.001,momentum=0.9,不行再换其他优化器,如Adam。

2.3 模型输入尺寸

在上一篇文章中的ct类介绍中,width_irc参数定义了每个在irc坐标系的尺寸大小。也是数据集输入到模型的input_size。

2.4 模型信息

使用torchinfo库或者torchsummary库的summary函数都可以打印模型的参数信息。具体方法如下:

from p2ch11.model import LunaModel
import torchinfo    # 安装命令conda install torchinfo

model = LunaModel()
torchinfo.summary(model, (1, 32, 48, 48), batch_dim=0,
                  col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 1)

运行结果,即模型信息如下:

=====================================================================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
=====================================================================================================================================================================
LunaModel                                [1, 1, 32, 48, 48]        [1, 2]                    --                        --                        --
├─BatchNorm3d: 1-1                       [1, 1, 32, 48, 48]        [1, 1, 32, 48, 48]        2                         --                        2
├─LunaBlock: 1-2                         [1, 1, 32, 48, 48]        [1, 8, 16, 24, 24]        --                        --                        --
│    └─Conv3d: 2-1                       [1, 1, 32, 48, 48]        [1, 8, 32, 48, 48]        224                       [3, 3, 3]                 16,515,072
│    └─ReLU: 2-2                         [1, 8, 32, 48, 48]        [1, 8, 32, 48, 48]        --                        --                        --
│    └─Conv3d: 2-3                       [1, 8, 32, 48, 48]        [1, 8, 32, 48, 48]        1,736                     [3, 3, 3]                 127,991,808
│    └─ReLU: 2-4                         [1, 8, 32, 48, 48]        [1, 8, 32, 48, 48]        --                        --                        --
│    └─MaxPool3d: 2-5                    [1, 8, 32, 48, 48]        [1, 8, 16, 24, 24]        --                        2                         --
├─LunaBlock: 1-3                         [1, 8, 16, 24, 24]        [1, 16, 8, 12, 12]        --                        --                        --
│    └─Conv3d: 2-6                       [1, 8, 16, 24, 24]        [1, 16, 16, 24, 24]       3,472                     [3, 3, 3]                 31,997,952
│    └─ReLU: 2-7                         [1, 16, 16, 24, 24]       [1, 16, 16, 24, 24]       --                        --                        --
│    └─Conv3d: 2-8                       [1, 16, 16, 24, 24]       [1, 16, 16, 24, 24]       6,928                     [3, 3, 3]                 63,848,448
│    └─ReLU: 2-9                         [1, 16, 16, 24, 24]       [1, 16, 16, 24, 24]       --                        --                        --
│    └─MaxPool3d: 2-10                   [1, 16, 16, 24, 24]       [1, 16, 8, 12, 12]        --                        2                         --
├─LunaBlock: 1-4                         [1, 16, 8, 12, 12]        [1, 32, 4, 6, 6]          --                        --                        --
│    └─Conv3d: 2-11                      [1, 16, 8, 12, 12]        [1, 32, 8, 12, 12]        13,856                    [3, 3, 3]                 15,962,112
│    └─ReLU: 2-12                        [1, 32, 8, 12, 12]        [1, 32, 8, 12, 12]        --                        --                        --
│    └─Conv3d: 2-13                      [1, 32, 8, 12, 12]        [1, 32, 8, 12, 12]        27,680                    [3, 3, 3]                 31,887,360
│    └─ReLU: 2-14                        [1, 32, 8, 12, 12]        [1, 32, 8, 12, 12]        --                        --                        --
│    └─MaxPool3d: 2-15                   [1, 32, 8, 12, 12]        [1, 32, 4, 6, 6]          --                        2                         --
├─LunaBlock: 1-5                         [1, 32, 4, 6, 6]          [1, 64, 2, 3, 3]          --                        --                        --
│    └─Conv3d: 2-16                      [1, 32, 4, 6, 6]          [1, 64, 4, 6, 6]          55,360                    [3, 3, 3]                 7,971,840
│    └─ReLU: 2-17                        [1, 64, 4, 6, 6]          [1, 64, 4, 6, 6]          --                        --                        --
│    └─Conv3d: 2-18                      [1, 64, 4, 6, 6]          [1, 64, 4, 6, 6]          110,656                   [3, 3, 3]                 15,934,464
│    └─ReLU: 2-19                        [1, 64, 4, 6, 6]          [1, 64, 4, 6, 6]          --                        --                        --
│    └─MaxPool3d: 2-20                   [1, 64, 4, 6, 6]          [1, 64, 2, 3, 3]          --                        2                         --
├─Linear: 1-6                            [1, 1152]                 [1, 2]                    2,306                     --                        2,306
├─Softmax: 1-7                           [1, 2]                    [1, 2]                    --                        --                        --
=====================================================================================================================================================================
Total params: 222,220
Trainable params: 222,220
Non-trainable params: 0
Total mult-adds (M): 312.11
=====================================================================================================================================================================
Input size (MB): 0.29
Forward/backward pass size (MB): 13.12
Params size (MB): 0.89
Estimated Total Size (MB): 14.31
=====================================================================================================================================================================

Process finished with exit code 0

3. 初始化

训练开始前,需要对权重进行初始化,初始化方法是通用的,具体参照书中代码【model.py】的_init_weights函数。

def _init_weights(self):
    for m in self.modules():
        if type(m) in {
            nn.Linear,
            nn.Conv3d,
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        }:
            nn.init.kaiming_normal_(
                m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
            )
            if m.bias is not None:
                fan_in, fan_out = \
                    nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                bound = 1 / math.sqrt(fan_out)
                nn.init.normal_(m.bias, -bound, bound)

4. 代码运行时间预计

原书代码中,定义了enumerateWithEstimate函数来预计运行完某段代码所需的运行时间。其中关键是利用了yield关键字,使enumerateWithEstimate一次次的迭代加载数据集。关于yield的用法可参考下面的文章。

python中yield的用法详解——最简单,最清晰的解释_冯爽朗的博客-CSDN博客_python yield

总的来说,声明为yield关键子的函数func,调用时类似断点执行:

1.首次执行时,代码执行到yield关键字右侧部分代码,并返回右侧部分代码的结果,类似return。yield之后的代码不在执行。

2. 用next函数再次调用函数func时,函数func继续从yield之后的代码开始执行,直到碰到下一个yield;如果函数后续没有别的yield关键字,则函数运行到末尾后返回函数开头重新运行,直至碰到yield。

3. 每次用next函数调用func时,不断重复第2点的执行方式。

5. 提高数据加载速度

原书中,作者通过diskacache库,将第一次加载的数据集缓存到磁盘中,下次训练或者验证再加载数据的时候,可直接在磁盘缓存中加载,可节省极大部分数据加载和预处理的时间。具体diskache库用法可参考下面的文章:

https://blog.csdn.net/wxyczhyza/article/details/127773721

6. tensorboard

pytorch1.2之后已集成tensorboard,直接在util库调用即可。

from torch.utils.tensorboard import SummaryWriter       # 调用tensorboard的SummaryWriter,用于记录训练性能

writer = SummaryWriter(file_path)           # 实例化时指明记录文件的路径
writer.add_scalar(tag, y_value, x_value)    # 添加标量
# writer.add_histogram()                    # 添加直方图
# writer.add_image()                        # 添加图像
writer.close()                              # 关闭文件引用

三、代码

原书代码可根据下面文章的代码链接下载,这里贴下我自己注释过的代码吧:

1. 网络模型 model.py

代码如下:

import math

from torch import nn as nn

from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class LunaModel(nn.Module):
    def __init__(self, in_channels=1, conv_channels=8):
        super().__init__()

        self.tail_batchnorm = nn.BatchNorm3d(1)

        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        self.head_linear = nn.Linear(1152, 2)
        self.head_softmax = nn.Softmax(dim=1)

        self._init_weights()

    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.normal_(m.bias, -bound, bound)



    def forward(self, input_batch):
        bn_output = self.tail_batchnorm(input_batch)

        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)

        conv_flat = block_out.view(
            block_out.size(0),
            -1,
        )
        linear_output = self.head_linear(conv_flat)

        return linear_output, self.head_softmax(linear_output)


class LunaBlock(nn.Module):
    def __init__(self, in_channels, conv_channels):
        super().__init__()

        self.conv1 = nn.Conv3d(
            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu2 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(2, 2)

    def forward(self, input_batch):
        block_out = self.conv1(input_batch)
        block_out = self.relu1(block_out)
        block_out = self.conv2(block_out)
        block_out = self.relu2(block_out)

        return self.maxpool(block_out)

2.  enumerateWithEstimate函数

函数位置:util\util.py

函数主要用了yield关键字,使enumerateWithEstimate函数变为一个迭代器生成器,不断的迭代加载数据集,并根据每次迭代的时间来预估加载完整个数据集所需要的总时间。

# 函数实现预估加载完整个迭代器所需要的时间。具体原理:
# step1:使用yield关键字,每次加载一部分数据集,统计这部分数据集的平均单个数据集的使用时间delta_t = 花费的时间/该部分数据集样本数
# step2:根据迭代器长度,预估加载整个数据集所花时间 t_dataset = delta_t * 数据集长度
def enumerateWithEstimate(
        iter,           # 数据集的一个迭代器。函数目的就是统计加载完整个数据集所需要的时间。
        desc_str,       # 打印log的时候的说明文本。自己随便定义就行。
        start_ndx=0,    # 开始统计前跳过的统计此时。比如start_ndx=3,则意思是第1,2次统计不打印,第三次开始打印。
        print_ndx=4,    # 相邻两次打印日志的统计次数间隔print_ndx = print_ndx * backoff,缺省的初始值为4
        backoff=None,   # 相邻两次打印日志的统计次数间隔的倍数。print_ndx = print_ndx * backoff
        iter_len=None,  # 迭代器的长度,不指定时,iter_len = len(iter)
):
    """
    In terms of behavior, `enumerateWithEstimate` is almost identical
    to the standard `enumerate` (the differences are things like how
    our function returns a generator, while `enumerate` returns a
    specialized `<enumerate object at 0x...>`).

    However, the side effects (logging, specifically) are what make the
    function interesting.

    :param iter: `iter` is the iterable that will be passed into
        `enumerate`. Required.

    :param desc_str: This is a human-readable string that describes
        what the loop is doing. The value is arbitrary, but should be
        kept reasonably short. Things like `"epoch 4 training"` or
        `"deleting temp files"` or similar would all make sense.

    :param start_ndx: This parameter defines how many iterations of the
        loop should be skipped before timing actually starts. Skipping
        a few iterations can be useful if there are startup costs like
        caching that are only paid early on, resulting in a skewed
        average when those early iterations dominate the average time
        per iteration.

        NOTE: Using `start_ndx` to skip some iterations makes the time
        spent performing those iterations not be included in the
        displayed duration. Please account for this if you use the
        displayed duration for anything formal.

        This parameter defaults to `0`.

    :param print_ndx: determines which loop interation that the timing
        logging will start on. The intent is that we don't start
        logging until we've given the loop a few iterations to let the
        average time-per-iteration a chance to stablize a bit. We
        require that `print_ndx` not be less than `start_ndx` times
        `backoff`, since `start_ndx` greater than `0` implies that the
        early N iterations are unstable from a timing perspective.

        `print_ndx` defaults to `4`.

    :param backoff: This is used to how many iterations to skip before
        logging again. Frequent logging is less interesting later on,
        so by default we double the gap between logging messages each
        time after the first.

        `backoff` defaults to `2` unless iter_len is > 1000, in which
        case it defaults to `4`.

    :param iter_len: Since we need to know the number of items to
        estimate when the loop will finish, that can be provided by
        passing in a value for `iter_len`. If a value isn't provided,
        then it will be set by using the value of `len(iter)`.

    :return:
    """
    if iter_len is None:
        iter_len = len(iter)

    if backoff is None:
        backoff = 2
        while backoff ** 7 < iter_len:
            backoff *= 2

    assert backoff >= 2
    while print_ndx < start_ndx * backoff:
        print_ndx *= backoff

    log.warning("{} ----/{}, starting".format(
        desc_str,
        iter_len,
    ))
    start_ts = time.time()
    for (current_ndx, item) in enumerate(iter):
        yield (current_ndx, item)
        if current_ndx == print_ndx:
            # ... <1> step1:计算若干隔数据集加载时间;step2:平均得到每个数据集加载时间;step3:乘以数据集长度得到预计加载所有数据的时间
            duration_sec = ((time.time() - start_ts)
                            / (current_ndx - start_ndx + 1)
                            * (iter_len-start_ndx)
                            )

            done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
            done_td = datetime.timedelta(seconds=duration_sec)

            log.info("{} {:-4}/{}, done at {}, {}".format(
                desc_str,
                current_ndx,
                iter_len,
                str(done_dt).rsplit('.', 1)[0],     # 运行了current_ndx次后,预估的加载完整个数据集后的系统时间
                str(done_td).rsplit('.', 1)[0],     # 运行了current_ndx次后,预估的加载完整个数据集所需要的秒数
            ))

            print_ndx *= backoff

        if current_ndx + 1 == start_ndx:
            start_ts = time.time()

    log.warning("{} ----/{}, done at {}".format(
        desc_str,
        iter_len,
        str(datetime.datetime.now()).rsplit('.', 1)[0],
    ))

3. prepcahe.py

这个脚本用来尝试加载整个数据集,测试加载数据集所需要的时间。核心时调用enumerateWithEstimate函数。

import argparse     # 参数解释器
import sys

import numpy as np

import torch.nn as nn
from torch.autograd import Variable
from torch.optim import SGD
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import LunaDataset
from util.logconf import logging
from .model import LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)


class LunaPrepCacheApp:
    @classmethod
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()      # 命令行参数修饰器
        parser.add_argument('--batch-size',     # 添加参数
            help='Batch size to use for training',
            default=1024,
            type=int,
        )
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=8,
            type=int,
        )

        self.cli_args = parser.parse_args(sys_argv)     # 解释参数

    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        self.prep_dl = DataLoader(
            LunaDataset(
                sortby_str='series_uid',
            ),
            batch_size=self.cli_args.batch_size,
            num_workers=self.cli_args.num_workers,
        )

        batch_iter = enumerateWithEstimate(     # 尝试加载数据集,预估加载整个数据集所需时间
            self.prep_dl,
            "Stuffing cache",
            start_ndx=self.prep_dl.num_workers,
        )
        for _ in batch_iter:
            pass


if __name__ == '__main__':
    LunaPrepCacheApp().main()   # 对类的__init__函数使用了@classmethod修饰器,所以可以不需要实例化,直接调用类内函数

在jupyter运行方法可参考原书代码的【p2_run_everything.ipynb】的【chapter11-cell2】。具体运行方法:

step1:加载相关库和函数

step2:使用命令行形式调用LunaPrepCacheApp函数。

 运行结果:

从下图可见,数据集中一个551065个样本,每个batch有1024个样本,一共539个batch,加载16个batch后,推算出加载完所有batch的时间要6个小时05分。

4. training.py

注释了部分代码,其中部分关于tensorboard的代码注释放到第六篇文章的笔记。训练结果及代码如下:

import argparse
import datetime
import os
import sys

import numpy as np

from torch.utils.tensorboard import SummaryWriter

import torch
import torch.nn as nn
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader

from util.util import enumerateWithEstimate
from .dsets import LunaDataset
from util.logconf import logging
from .model import LunaModel

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Used for computeBatchLoss and logMetrics to index into metrics_t/metrics_a
# 将每个样本在训练时候的label、预测值、loss存在了一个矩阵,用于打印结果和tensorboard上显示
# 矩阵第一行为label,第二行为预测值,第三行为loss值,每一列为一个样本
METRICS_LABEL_NDX=0     # label的行索引
METRICS_PRED_NDX=1      # 预测值行索引
METRICS_LOSS_NDX=2      # loss值行索引
METRICS_SIZE = 3    # 矩阵行数

class LunaTrainingApp:
    def __init__(self, sys_argv=None):
        if sys_argv is None:
            sys_argv = sys.argv[1:]

        parser = argparse.ArgumentParser()
        parser.add_argument('--num-workers',
            help='Number of worker processes for background data loading',
            default=6,      # 使用的CPU核心数,我用的i5-12490f为6核
            type=int,
        )
        parser.add_argument('--batch-size',
            help='Batch size to use for training',
            default=24,     # 每个batch样本数
            type=int,
        )
        parser.add_argument('--epochs',
            help='Number of epochs to train for',
            default=1,      # 训练的代数
            type=int,
        )

        parser.add_argument('--tb-prefix',
            default='p2ch11',
            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
        )

        parser.add_argument('comment',
            help="Comment suffix for Tensorboard run.",
            nargs='?',
            default='dwlpt',
        )
        self.cli_args = parser.parse_args(sys_argv)
        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')

        self.trn_writer = None
        self.val_writer = None
        self.totalTrainingSamples_count = 0

        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")

        self.model = self.initModel()       # 将模型搬到cuda
        self.optimizer = self.initOptimizer()       # 定义优化器

    def initModel(self):
        model = LunaModel()
        if self.use_cuda:
            log.info("Using CUDA; {} devices.".format(torch.cuda.device_count()))
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model)      # 如果有多个gpu,分配多给GPU训练
            model = model.to(self.device)
        return model

    def initOptimizer(self):
        # 一般第一次训练用SGD看看效果,再选择其他优化器。比较常用参数为lr=0.001,momentum=0.99
        return SGD(self.model.parameters(), lr=0.001, momentum=0.99)        
        # return Adam(self.model.parameters())

    def initTrainDl(self):
        # 由于LunaDataset的getCtRawCandidate被diskcache修饰,所以第一次加载数据集时,需要从文件读取数据,
        # 同时数据处理后会缓存到磁盘,速度较慢;第二次开始,会直接从缓存加载,速度会较快。
        train_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=False,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        train_dl = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return train_dl

    def initValDl(self):
        val_ds = LunaDataset(
            val_stride=10,
            isValSet_bool=True,
        )

        batch_size = self.cli_args.batch_size
        if self.use_cuda:
            batch_size *= torch.cuda.device_count()

        val_dl = DataLoader(
            val_ds,
            batch_size=batch_size,
            num_workers=self.cli_args.num_workers,
            pin_memory=self.use_cuda,
        )

        return val_dl

    def initTensorboardWriters(self):
        if self.trn_writer is None:
            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)

            self.trn_writer = SummaryWriter(
                log_dir=log_dir + '-trn_cls-' + self.cli_args.comment)
            self.val_writer = SummaryWriter(
                log_dir=log_dir + '-val_cls-' + self.cli_args.comment)


    def main(self):
        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))

        train_dl = self.initTrainDl()
        val_dl = self.initValDl()

        for epoch_ndx in range(1, self.cli_args.epochs + 1):

            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                epoch_ndx,
                self.cli_args.epochs,
                len(train_dl),
                len(val_dl),
                self.cli_args.batch_size,
                (torch.cuda.device_count() if self.use_cuda else 1),
            ))

            trnMetrics_t = self.doTraining(epoch_ndx, train_dl)
            self.logMetrics(epoch_ndx, 'trn', trnMetrics_t)

            valMetrics_t = self.doValidation(epoch_ndx, val_dl)
            self.logMetrics(epoch_ndx, 'val', valMetrics_t)

        if hasattr(self, 'trn_writer'):
            self.trn_writer.close()
            self.val_writer.close()


    def doTraining(self, epoch_ndx, train_dl):
        self.model.train()
        trnMetrics_g = torch.zeros(
            METRICS_SIZE,
            len(train_dl.dataset),
            device=self.device,
        )

        # batch_iter = enumerateWithEstimate(
        #     train_dl,
        #     "E{} Training".format(epoch_ndx),
        #     start_ndx=train_dl.num_workers,
        # )
        for batch_ndx, batch_tup in enumerate(train_dl):
            self.optimizer.zero_grad()

            loss_var = self.computeBatchLoss(
                batch_ndx,
                batch_tup,
                train_dl.batch_size,
                trnMetrics_g
            )

            loss_var.backward()
            self.optimizer.step()

            # # This is for adding the model graph to TensorBoard.
            # if epoch_ndx == 1 and batch_ndx == 0:
            #     with torch.no_grad():
            #         model = LunaModel()
            #         self.trn_writer.add_graph(model, batch_tup[0], verbose=True)
            #         self.trn_writer.close()

        self.totalTrainingSamples_count += len(train_dl.dataset)

        return trnMetrics_g.to('cpu')


    def doValidation(self, epoch_ndx, val_dl):
        with torch.no_grad():
            self.model.eval()
            valMetrics_g = torch.zeros(
                METRICS_SIZE,
                len(val_dl.dataset),
                device=self.device,
            )

            batch_iter = enumerateWithEstimate(
                val_dl,
                "E{} Validation ".format(epoch_ndx),
                start_ndx=val_dl.num_workers,
            )
            for batch_ndx, batch_tup in batch_iter:
                self.computeBatchLoss(
                    batch_ndx, batch_tup, val_dl.batch_size, valMetrics_g)

        return valMetrics_g.to('cpu')



    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_g):
        input_t, label_t, _series_list, _center_list = batch_tup

        input_g = input_t.to(self.device, non_blocking=True)
        label_g = label_t.to(self.device, non_blocking=True)

        logits_g, probability_g = self.model(input_g)

        loss_func = nn.CrossEntropyLoss(reduction='none')   # reduction=none时,将每个样本的loss返回
        loss_g = loss_func(
            logits_g,
            label_g[:,1],
        )
        start_ndx = batch_ndx * batch_size
        end_ndx = start_ndx + label_t.size(0)

        # 将训练结果存到矩阵
        metrics_g[METRICS_LABEL_NDX, start_ndx:end_ndx] = \
            label_g[:,1].detach()
        metrics_g[METRICS_PRED_NDX, start_ndx:end_ndx] = \
            probability_g[:,1].detach()
        metrics_g[METRICS_LOSS_NDX, start_ndx:end_ndx] = \
            loss_g.detach()

        return loss_g.mean()


    def logMetrics(
            self,
            epoch_ndx,
            mode_str,
            metrics_t,
            classificationThreshold=0.5,
    ):
        self.initTensorboardWriters()
        log.info("E{} {}".format(
            epoch_ndx,
            type(self).__name__,
        ))

        negLabel_mask = metrics_t[METRICS_LABEL_NDX] <= classificationThreshold
        negPred_mask = metrics_t[METRICS_PRED_NDX] <= classificationThreshold

        posLabel_mask = ~negLabel_mask
        posPred_mask = ~negPred_mask

        neg_count = int(negLabel_mask.sum())
        pos_count = int(posLabel_mask.sum())

        neg_correct = int((negLabel_mask & negPred_mask).sum())
        pos_correct = int((posLabel_mask & posPred_mask).sum())

        metrics_dict = {}
        metrics_dict['loss/all'] = \
            metrics_t[METRICS_LOSS_NDX].mean()
        metrics_dict['loss/neg'] = \
            metrics_t[METRICS_LOSS_NDX, negLabel_mask].mean()
        metrics_dict['loss/pos'] = \
            metrics_t[METRICS_LOSS_NDX, posLabel_mask].mean()

        metrics_dict['correct/all'] = (pos_correct + neg_correct) \
            / np.float32(metrics_t.shape[1]) * 100
        metrics_dict['correct/neg'] = neg_correct / np.float32(neg_count) * 100
        metrics_dict['correct/pos'] = pos_correct / np.float32(pos_count) * 100

        log.info(
            ("E{} {:8} {loss/all:.4f} loss, "
                 + "{correct/all:-5.1f}% correct, "
            ).format(
                epoch_ndx,
                mode_str,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/neg:.4f} loss, "
                 + "{correct/neg:-5.1f}% correct ({neg_correct:} of {neg_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_neg',
                neg_correct=neg_correct,
                neg_count=neg_count,
                **metrics_dict,
            )
        )
        log.info(
            ("E{} {:8} {loss/pos:.4f} loss, "
                 + "{correct/pos:-5.1f}% correct ({pos_correct:} of {pos_count:})"
            ).format(
                epoch_ndx,
                mode_str + '_pos',
                pos_correct=pos_correct,
                pos_count=pos_count,
                **metrics_dict,
            )
        )

        writer = getattr(self, mode_str + '_writer')

        for key, value in metrics_dict.items():
            writer.add_scalar(key, value, self.totalTrainingSamples_count)

        writer.add_pr_curve(
            'pr',
            metrics_t[METRICS_LABEL_NDX],
            metrics_t[METRICS_PRED_NDX],
            self.totalTrainingSamples_count,
        )

        bins = [x/50.0 for x in range(51)]

        negHist_mask = negLabel_mask & (metrics_t[METRICS_PRED_NDX] > 0.01)
        posHist_mask = posLabel_mask & (metrics_t[METRICS_PRED_NDX] < 0.99)

        if negHist_mask.any():
            writer.add_histogram(
                'is_neg',
                metrics_t[METRICS_PRED_NDX, negHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )
        if posHist_mask.any():
            writer.add_histogram(
                'is_pos',
                metrics_t[METRICS_PRED_NDX, posHist_mask],
                self.totalTrainingSamples_count,
                bins=bins,
            )

        # score = 1 \
        #     + metrics_dict['pr/f1_score'] \
        #     - metrics_dict['loss/mal'] * 0.01 \
        #     - metrics_dict['loss/all'] * 0.0001
        #
        # return score

    # def logModelMetrics(self, model):
    #     writer = getattr(self, 'trn_writer')
    #
    #     model = getattr(model, 'module', model)
    #
    #     for name, param in model.named_parameters():
    #         if param.requires_grad:
    #             min_data = float(param.data.min())
    #             max_data = float(param.data.max())
    #             max_extent = max(abs(min_data), abs(max_data))
    #
    #             # bins = [x/50*max_extent for x in range(-50, 51)]
    #
    #             try:
    #                 writer.add_histogram(
    #                     name.rsplit('.', 1)[-1] + '/' + name,
    #                     param.data.cpu().numpy(),
    #                     # metrics_a[METRICS_PRED_NDX, negHist_mask],
    #                     self.totalTrainingSamples_count,
    #                     # bins=bins,
    #                 )
    #             except Exception as e:
    #                 log.error([min_data, max_data])
    #                 raise


if __name__ == '__main__':
    LunaTrainingApp().main()

  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值