五、肺癌检测-数据集训练 training.py model.py

时间:2022-11-05 14:53:32

上一篇文章中已经通过将dsets.py实现将数据集封装加载,之后就可以通过建立了模型并编写training脚本实现模型的训练了。这一篇文章主要是对《pytorch深度学习实战》第11章内容做的笔记。

一、目标

1、建立简单的卷积神经网络

2、编写训练函数

3、编写训练日志(训练和验证过程的loss,accuracy等)数据结构

4、使用tensorboard可视化训练信息。

二、要点说明

1. 对函数使用通用的系统进程级别的调用

原书代码的【code/p2_run_everything.ipynb】的cell2中,定义了一个通用的系统进程方式的调用方法。通过这种方法可以调用所有脚本中的函数。但个人认为还是挺麻烦的,一点都不人性化。

def run(app, *argv):
    argv = list(argv)
    argv.insert(0, '--num-workers=4')  # <1> 使用4个核
    log.info("Running: {}({!r}).main()".format(app, argv))
    
    app_cls = importstr(*app.rsplit('.', 1))  # <2>    # 动态加载库
    app_cls(argv).main()    # 调用app类的main函数
    
    log.info("Finished: {}.{!r}).main()".format(app, argv))

使用示例:从p2ch11文件夹的training.py文件中importLunaTrainingApp类并调用其main函数,函数的输入参数是epochs=1。

run('p2ch11.training.LunaTrainingApp', '--epochs=1')

其中:

1.1 importstr函数

函数是为了实现动态调用各个库和库函数。类似于from 【pkg_name】 import 【func_name】的作用。通过importstr可以实现动态加载函数,而不用调用前用import声明。

1.2 rsplit函数

 函数用法:list = str.rsplit(sep, maxsplit)。可参考下面的文章。简单而言就是对字符【str】按照【sep】分隔符进行拆分,从字符右侧开始拆分,一共拆分【maxsplit】次。返回的是拆分结果是一个list。

Python实用语法之rsplit_明 总 有的博客-CSDN博客_python rsplit

2. 模型建立

书中在11章用的是简单的卷积堆叠+线性层的神经网络结果,没任何特别之处。其中线性层由于只是简单2分类(结节是否为肿瘤),所以只用了一个线性层。卷积和池化用的是3维的卷积和池化。

2.1 多GPU设置

多GPU训练可通过nn.DataParallel(model)或DistributedParallel函数实现,前者较为简单,一般用在单机多卡场景,后者配置较为复杂,一般用在多台计算机的多卡场景。

2.2 优化器

一般开始训练时可以先尝试使用带动量的SGD,lr=0.001,momentum=0.9,不行再换其他优化器,如Adam。

2.3 模型输入尺寸

在上一篇文章中的ct类介绍中,width_irc参数定义了每个在irc坐标系的尺寸大小。也是数据集输入到模型的input_size。

2.4 模型信息

使用torchinfo库或者torchsummary库的summary函数都可以打印模型的参数信息。具体方法如下:

from p2ch11.model import LunaModel
import torchinfo    # 安装命令conda install torchinfo

model = LunaModel()
torchinfo.summary(model, (1, 32, 48, 48), batch_dim=0,
                  col_names = ('input_size', 'output_size', 'num_params', 'kernel_size', 'mult_adds'), verbose = 1)

运行结果,即模型信息如下:

=====================================================================================================================================================================
Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape              Mult-Adds
=====================================================================================================================================================================
LunaModel                                [1, 1, 32, 48, 48]        [1, 2]                    --                        --                        --
├─BatchNorm3d: 1-1                       [1, 1, 32, 48, 48]        [1, 1, 32, 48, 48]        2                         --                        2
├─LunaBlock: 1-2                         [1, 1, 32, 48, 48]        [1, 8, 16, 24, 24]        --                        --                        --
│    └─Conv3d: 2-1                       [1, 1, 32, 48, 48]        [1, 8, 32, 48, 48]        224                       [3, 3, 3]                 16,515,072
│    └─ReLU: 2-2                         [1, 8, 32, 48, 48]        [1, 8, 32, 48, 48]        --                        --                        --
│    └─Conv3d: 2-3                       [1, 8, 32, 48, 48]        [1, 8, 32, 48, 48]        1,736                     [3, 3, 3]                 127,991,808
│    └─ReLU: 2-4                         [1, 8, 32, 48, 48]        [1, 8, 32, 48, 48]        --                        --                        --
│    └─MaxPool3d: 2-5                    [1, 8, 32, 48, 48]        [1, 8, 16, 24, 24]        --                        2                         --
├─LunaBlock: 1-3                         [1, 8, 16, 24, 24]        [1, 16, 8, 12, 12]        --                        --                        --
│    └─Conv3d: 2-6                       [1, 8, 16, 24, 24]        [1, 16, 16, 24, 24]       3,472                     [3, 3, 3]                 31,997,952
│    └─ReLU: 2-7                         [1, 16, 16, 24, 24]       [1, 16, 16, 24, 24]       --                        --                        --
│    └─Conv3d: 2-8                       [1, 16, 16, 24, 24]       [1, 16, 16, 24, 24]       6,928                     [3, 3, 3]                 63,848,448
│    └─ReLU: 2-9                         [1, 16, 16, 24, 24]       [1, 16, 16, 24, 24]       --                        --                        --
│    └─MaxPool3d: 2-10                   [1, 16, 16, 24, 24]       [1, 16, 8, 12, 12]        --                        2                         --
├─LunaBlock: 1-4                         [1, 16, 8, 12, 12]        [1, 32, 4, 6, 6]          --                        --                        --
│    └─Conv3d: 2-11                      [1, 16, 8, 12, 12]        [1, 32, 8, 12, 12]        13,856                    [3, 3, 3]                 15,962,112
│    └─ReLU: 2-12                        [1, 32, 8, 12, 12]        [1, 32, 8, 12, 12]        --                        --                        --
│    └─Conv3d: 2-13                      [1, 32, 8, 12, 12]        [1, 32, 8, 12, 12]        27,680                    [3, 3, 3]                 31,887,360
│    └─ReLU: 2-14                        [1, 32, 8, 12, 12]        [1, 32, 8, 12, 12]        --                        --                        --
│    └─MaxPool3d: 2-15                   [1, 32, 8, 12, 12]        [1, 32, 4, 6, 6]          --                        2                         --
├─LunaBlock: 1-5                         [1, 32, 4, 6, 6]          [1, 64, 2, 3, 3]          --                        --                        --
│    └─Conv3d: 2-16                      [1, 32, 4, 6, 6]          [1, 64, 4, 6, 6]          55,360                    [3, 3, 3]                 7,971,840
│    └─ReLU: 2-17                        [1, 64, 4, 6, 6]          [1, 64, 4, 6, 6]          --                        --                        --
│    └─Conv3d: 2-18                      [1, 64, 4, 6, 6]          [1, 64, 4, 6, 6]          110,656                   [3, 3, 3]                 15,934,464
│    └─ReLU: 2-19                        [1, 64, 4, 6, 6]          [1, 64, 4, 6, 6]          --                        --                        --
│    └─MaxPool3d: 2-20                   [1, 64, 4, 6, 6]          [1, 64, 2, 3, 3]          --                        2                         --
├─Linear: 1-6                            [1, 1152]                 [1, 2]                    2,306                     --                        2,306
├─Softmax: 1-7                           [1, 2]                    [1, 2]                    --                        --                        --
=====================================================================================================================================================================
Total params: 222,220
Trainable params: 222,220
Non-trainable params: 0
Total mult-adds (M): 312.11
=====================================================================================================================================================================
Input size (MB): 0.29
Forward/backward pass size (MB): 13.12
Params size (MB): 0.89
Estimated Total Size (MB): 14.31
=====================================================================================================================================================================

Process finished with exit code 0

3. 初始化

训练开始前,需要对权重进行初始化,初始化方法是通用的,具体参照书中代码【model.py】的_init_weights函数。

def _init_weights(self):
    for m in self.modules():
        if type(m) in {
            nn.Linear,
            nn.Conv3d,
            nn.Conv2d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        }:
            nn.init.kaiming_normal_(
                m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
            )
            if m.bias is not None:
                fan_in, fan_out = \
                    nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                bound = 1 / math.sqrt(fan_out)
                nn.init.normal_(m.bias, -bound, bound)

三、关键函数

关键函数不多,基本都是torch的内置函数,剩下的都是一些方便查看训练结果的日志功能的函数。具体还是得看代码。

代码

1. 网络模型 model.py

代码如下:

import math

from torch import nn as nn

from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)


class LunaModel(nn.Module):
    def __init__(self, in_channels=1, conv_channels=8):
        super().__init__()

        self.tail_batchnorm = nn.BatchNorm3d(1)

        self.block1 = LunaBlock(in_channels, conv_channels)
        self.block2 = LunaBlock(conv_channels, conv_channels * 2)
        self.block3 = LunaBlock(conv_channels * 2, conv_channels * 4)
        self.block4 = LunaBlock(conv_channels * 4, conv_channels * 8)

        self.head_linear = nn.Linear(1152, 2)
        self.head_softmax = nn.Softmax(dim=1)

        self._init_weights()

    # see also https://github.com/pytorch/pytorch/issues/18182
    def _init_weights(self):
        for m in self.modules():
            if type(m) in {
                nn.Linear,
                nn.Conv3d,
                nn.Conv2d,
                nn.ConvTranspose2d,
                nn.ConvTranspose3d,
            }:
                nn.init.kaiming_normal_(
                    m.weight.data, a=0, mode='fan_out', nonlinearity='relu',
                )
                if m.bias is not None:
                    fan_in, fan_out = \
                        nn.init._calculate_fan_in_and_fan_out(m.weight.data)
                    bound = 1 / math.sqrt(fan_out)
                    nn.init.normal_(m.bias, -bound, bound)



    def forward(self, input_batch):
        bn_output = self.tail_batchnorm(input_batch)

        block_out = self.block1(bn_output)
        block_out = self.block2(block_out)
        block_out = self.block3(block_out)
        block_out = self.block4(block_out)

        conv_flat = block_out.view(
            block_out.size(0),
            -1,
        )
        linear_output = self.head_linear(conv_flat)

        return linear_output, self.head_softmax(linear_output)


class LunaBlock(nn.Module):
    def __init__(self, in_channels, conv_channels):
        super().__init__()

        self.conv1 = nn.Conv3d(
            in_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            conv_channels, conv_channels, kernel_size=3, padding=1, bias=True,
        )
        self.relu2 = nn.ReLU(inplace=True)

        self.maxpool = nn.MaxPool3d(2, 2)

    def forward(self, input_batch):
        block_out = self.conv1(input_batch)
        block_out = self.relu1(block_out)
        block_out = self.conv2(block_out)
        block_out = self.relu2(block_out)

        return self.maxpool(block_out)

后续再继续写