Deep Learning Day 27: J6 ResNeXt-50 Hands-On Walkthrough

Date: 2024-01-24 09:23:25
  • This post is a study-log entry from the 365-day deep learning training camp
  • Original author: K同学啊
  • Source: K同学's study circle

Table of Contents

  • Preface
  • 1 My environment
  • 2 Implementing ResNeXt-50 in PyTorch
    • 2.1 Preparation
      • 2.1.1 Importing libraries
      • 2.1.2 Setting up the GPU (use the GPU if available, otherwise the CPU)
      • 2.1.3 Importing the data
      • 2.1.4 Visualizing the data
      • 2.1.5 Image transforms
      • 2.1.6 Splitting the dataset
      • 2.1.7 Loading the data
      • 2.1.8 Inspecting the data
    • 2.2 Building the ResNeXt-50 model
    • 2.3 Training the model
      • 2.3.1 Setting hyperparameters
      • 2.3.2 Writing the training function
      • 2.3.3 Writing the test function
      • 2.3.4 Training
    • 2.4 Visualizing the results
    • 2.5 Predicting a specified image
    • 2.6 Model evaluation
  • 3 Implementing ResNeXt-50 in TensorFlow
    • 3.1 Importing libraries
    • 3.2 Setting up the GPU (skip this step if using a CPU)
    • 3.3 Importing the data
    • 3.4 Inspecting the data
    • 3.5 Loading the data
    • 3.6 Re-checking the data
    • 3.7 Configuring the dataset
    • 3.8 Visualizing the data
    • 3.9 Building the ResNeXt-50 network
    • 3.10 Compiling the model
    • 3.11 Training the model
    • 3.12 Model evaluation
    • 3.13 Image prediction
  • 4 Key concepts
    • 4.1 ResNeXt-50 in detail
    • 4.2 ResNeXt-50 vs. ResNet-50V2 and DenseNet
      • 4.2.1 Network structure
      • 4.2.2 Accuracy and computational cost
      • 4.2.3 Applicability
  • 5 Summary


Preface

Keywords: ResNeXt-50 implemented in PyTorch, ResNeXt-50 implemented in TensorFlow, ResNeXt-50 explained

1 My environment

  • OS: Windows 11
  • Language: Python 3.8.6
  • IDE: PyCharm 2020.2.3
  • Deep learning environment:
    torch == 1.9.1+cu111
    torchvision == 0.10.1+cu111
    TensorFlow 2.10.1
  • GPU: NVIDIA GeForce RTX 4070

2 Implementing ResNeXt-50 in PyTorch

2.1 Preparation

2.1.1 Importing libraries


import torch
import torch.nn as nn
import time
import copy
from torchvision import transforms, datasets
from pathlib import Path
from PIL import Image
import torchsummary as summary
import torch.nn.functional as F
from collections import OrderedDict
import re
import torch.utils.model_zoo as model_zoo
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False  # render minus signs correctly
plt.rcParams['figure.dpi'] = 100  # figure resolution
import warnings

warnings.filterwarnings('ignore')  # suppress warnings that need not be printed

2.1.2 Setting up the GPU (use the GPU if available, otherwise the CPU)

"""Preparation: set up the GPU"""
# use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))

Output

Using cuda device

2.1.3 Importing the data

'''Preparation: import the data'''
data_dir = r"D:\DeepLearning\data\monkeypox_recognition"
data_dir = Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [path.name for path in data_paths]  # subfolder names are the class labels
print(classeNames)

Output

['Monkeypox', 'Others']

2.1.4 Visualizing the data

'''Preparation: visualize the data'''
subfolder = Path(data_dir) / "Monkeypox"
image_files = list(p.resolve() for p in subfolder.glob('*') if p.suffix in [".jpg", ".png", ".jpeg"])
plt.figure(figsize=(10, 6))
for i in range(len(image_files[:12])):
    image_file = image_files[i]
    ax = plt.subplot(3, 4, i + 1)
    img = Image.open(str(image_file))
    plt.imshow(img)
    plt.axis("off")
# show the images
plt.tight_layout()
plt.show()

[Figure: sample Monkeypox images in a 3×4 grid]

2.1.5 Image transforms

'''Preparation: image transforms'''
total_datadir = data_dir

# for more on transforms.Compose see: https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # resize every input image to a uniform size
    transforms.ToTensor(),  # convert a PIL Image or numpy.ndarray to a tensor scaled to [0, 1]
    transforms.Normalize(  # standardize each channel toward a normal (Gaussian) distribution so the model converges more easily
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])  # the standard ImageNet channel statistics
])
total_data = datasets.ImageFolder(total_datadir, transform=train_transforms)
print(total_data)
print(total_data.class_to_idx)

Output

Dataset ImageFolder
    Number of datapoints: 2142
    Root location: D:\DeepLearning\data\monkeypox_recognition
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
{'Monkeypox': 0, 'Others': 1}
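Normalize standardizes each channel as (x - mean) / std; a minimal sketch of the per-pixel arithmetic for the red channel, using the statistics from the transform above:

```python
mean, std = 0.485, 0.229  # red-channel statistics from the Normalize transform
x = 0.5                   # a pixel value after ToTensor (already scaled to [0, 1])
z = (x - mean) / std      # standardized value fed to the network, about 0.0655
print(round(z, 4))
```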

2.1.6 Splitting the dataset

'''Preparation: split the dataset'''
train_size = int(0.8 * len(total_data))  # training-set size: 80% of the data, truncated to an integer
test_size = len(total_data) - train_size  # test-set size: the remainder
# torch.utils.data.random_split() shuffles total_data and splits it into subsets of the
# requested sizes [train_size, test_size], assigned to train_dataset and test_dataset.
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print("train_dataset={}\ntest_dataset={}".format(train_dataset, test_dataset))
print("train_size={}\ntest_size={}".format(train_size, test_size))

Output

train_dataset=<torch.utils.data.dataset.Subset object at 0x000002A96E08E0D0>
test_dataset=<torch.utils.data.dataset.Subset object at 0x000002A96E04E640>
train_size=1713
test_size=429
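The printed sizes follow directly from the 80/20 split arithmetic over the 2,142 datapoints reported by ImageFolder above:

```python
total = 2142                      # datapoints reported by ImageFolder
train_size = int(0.8 * total)     # 80%, truncated -> 1713
test_size = total - train_size    # remainder -> 429
print(train_size, test_size)
```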

2.1.7 Loading the data

'''Preparation: load the data'''
batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=1)

2.1.8 Inspecting the data

'''Preparation: inspect the data'''
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

Output

Shape of X [N, C, H, W]:  torch.Size([32, 3, 224, 224])
Shape of y:  torch.Size([32]) torch.int64
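The batch shown above happens to be a full batch of 32. Since DataLoader keeps the last partial batch by default (drop_last=False), the 429 test images yield 14 batches, the last holding only 13 images; a quick check:

```python
import math

test_size, batch_size = 429, 32
num_batches = math.ceil(test_size / batch_size)            # 14 batches in total
last_batch = test_size - (num_batches - 1) * batch_size    # the final batch has 13 images
print(num_batches, last_batch)
```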

2.2 Building the ResNeXt-50 model

"""Build the ResNeXt-50 network"""


class BN_Conv2d(nn.Module):
    """
    BN_CONV_RELU
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bias=False):
        super(BN_Conv2d, self).__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=groups, bias=bias),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return F.relu(self.seq(x))


class ResNeXt_Block(nn.Module):
    """
    ResNeXt block with group convolutions
    """

    def __init__(self, in_chnls, cardinality, group_depth, stride):
        super(ResNeXt_Block, self).__init__()
        self.group_chnls = cardinality * group_depth
        self.conv1 = BN_Conv2d(in_chnls, self.group_chnls, 1, stride=1, padding=0)
        self.conv2 = BN_Conv2d(self.group_chnls, self.group_chnls, 3, stride=stride, padding=1, groups=cardinality)
        self.conv3 = nn.Conv2d(self.group_chnls, self.group_chnls * 2, 1, stride=1, padding=0)
        self.bn = nn.BatchNorm2d(self.group_chnls * 2)
        self.short_cut = nn.Sequential(
            nn.Conv2d(in_chnls, self.group_chnls * 2, 1, stride, 0, bias=False),
            nn.BatchNorm2d(self.group_chnls * 2)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.bn(self.conv3(out))
        out += self.short_cut(x)
        return F.relu(out)


class ResNeXt(nn.Module):
    """
    ResNeXt builder
    """

    def __init__(self, layers, cardinality, group_depth, num_classes):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.channels = 64
        self.conv1 = BN_Conv2d(3, self.channels, 7, stride=2, padding=3)
        d1 = group_depth
        self.conv2 = self._make_layers(d1, layers[0], stride=1)
        d2 = d1 * 2
        self.conv3 = self._make_layers(d2, layers[1], stride=2)
        d3 = d2 * 2
        self.conv4 = self._make_layers(d3, layers[2], stride=2)
        d4 = d3 * 2
        self.conv5 = self._make_layers(d4, layers[3], stride=2)
        self.fc = nn.Linear(self.channels, num_classes)  # self.channels is 2048 after the four stages (224x224 input)

    def _make_layers(self, d, blocks, stride):
        strides = [stride] + [1] * (blocks - 1)
        layers = []
        for stride in strides:
            layers.append(ResNeXt_Block(self.channels, self.cardinality, d, stride))
            self.channels = self.cardinality * d * 2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.max_pool2d(out, 3, 2, 1)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = F.avg_pool2d(out, 7)
        out = out.view(out.size(0), -1)
        out = self.fc(out)  # return raw logits; nn.CrossEntropyLoss applies log-softmax internally
        return out
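The parameter savings from the grouped 3×3 convolution in each block can be checked by hand: a grouped convolution splits the channels into `groups` independent convolutions, dividing the weight count by the group count. With the first stage's 128 intermediate channels and cardinality 32, this yields the 4,608 parameters that torchsummary reports for that layer below:

```python
def conv2d_params(cin, cout, k, groups=1):
    # each of the `groups` groups convolves cin//groups inputs into cout//groups outputs
    return groups * (cin // groups) * (cout // groups) * k * k

dense = conv2d_params(128, 128, 3, groups=1)     # an ordinary 3x3 conv: 147456 weights
grouped = conv2d_params(128, 128, 3, groups=32)  # cardinality 32: 32x fewer weights, 4608
print(dense, grouped)
```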

The summary below implies the model was instantiated as a ResNeXt-50 (stage depths [3, 4, 6, 3], cardinality 32, group depth 4, i.e. the standard 32x4d configuration); a plausible reconstruction, with num_classes=4 matching the 4-way output shown in the summary:

model = ResNeXt([3, 4, 6, 3], cardinality=32, group_depth=4, num_classes=4).to(device)
print(summary.summary(model, (3, 224, 224)))

Output

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
         BN_Conv2d-3         [-1, 64, 112, 112]               0
            Conv2d-4          [-1, 128, 56, 56]           8,192
       BatchNorm2d-5          [-1, 128, 56, 56]             256
         BN_Conv2d-6          [-1, 128, 56, 56]               0
            Conv2d-7          [-1, 128, 56, 56]           4,608
       BatchNorm2d-8          [-1, 128, 56, 56]             256
         BN_Conv2d-9          [-1, 128, 56, 56]               0
           Conv2d-10          [-1, 256, 56, 56]          33,024
      BatchNorm2d-11          [-1, 256, 56, 56]             512
           Conv2d-12          [-1, 256, 56, 56]          16,384
      BatchNorm2d-13          [-1, 256, 56, 56]             512
    ResNeXt_Block-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 128, 56, 56]          32,768
      BatchNorm2d-16          [-1, 128, 56, 56]             256
        BN_Conv2d-17          [-1, 128, 56, 56]               0
           Conv2d-18          [-1, 128, 56, 56]           4,608
      BatchNorm2d-19          [-1, 128, 56, 56]             256
        BN_Conv2d-20          [-1, 128, 56, 56]               0
           Conv2d-21          [-1, 256, 56, 56]          33,024
      BatchNorm2d-22          [-1, 256, 56, 56]             512
           Conv2d-23          [-1, 256, 56, 56]          65,536
      BatchNorm2d-24          [-1, 256, 56, 56]             512
    ResNeXt_Block-25          [-1, 256, 56, 56]               0
           Conv2d-26          [-1, 128, 56, 56]          32,768
      BatchNorm2d-27          [-1, 128, 56, 56]             256
        BN_Conv2d-28          [-1, 128, 56, 56]               0
           Conv2d-29          [-1, 128, 56, 56]           4,608
      BatchNorm2d-30          [-1, 128, 56, 56]             256
        BN_Conv2d-31          [-1, 128, 56, 56]               0
           Conv2d-32          [-1, 256, 56, 56]          33,024
      BatchNorm2d-33          [-1, 256, 56, 56]             512
           Conv2d-34          [-1, 256, 56, 56]          65,536
      BatchNorm2d-35          [-1, 256, 56, 56]             512
    ResNeXt_Block-36          [-1, 256, 56, 56]               0
           Conv2d-37          [-1, 256, 56, 56]          65,536
      BatchNorm2d-38          [-1, 256, 56, 56]             512
        BN_Conv2d-39          [-1, 256, 56, 56]               0
           Conv2d-40          [-1, 256, 28, 28]          18,432
      BatchNorm2d-41          [-1, 256, 28, 28]             512
        BN_Conv2d-42          [-1, 256, 28, 28]               0
           Conv2d-43          [-1, 512, 28, 28]         131,584
      BatchNorm2d-44          [-1, 512, 28, 28]           1,024
           Conv2d-45          [-1, 512, 28, 28]         131,072
      BatchNorm2d-46          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-47          [-1, 512, 28, 28]               0
           Conv2d-48          [-1, 256, 28, 28]         131,072
      BatchNorm2d-49          [-1, 256, 28, 28]             512
        BN_Conv2d-50          [-1, 256, 28, 28]               0
           Conv2d-51          [-1, 256, 28, 28]          18,432
      BatchNorm2d-52          [-1, 256, 28, 28]             512
        BN_Conv2d-53          [-1, 256, 28, 28]               0
           Conv2d-54          [-1, 512, 28, 28]         131,584
      BatchNorm2d-55          [-1, 512, 28, 28]           1,024
           Conv2d-56          [-1, 512, 28, 28]         262,144
      BatchNorm2d-57          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-58          [-1, 512, 28, 28]               0
           Conv2d-59          [-1, 256, 28, 28]         131,072
      BatchNorm2d-60          [-1, 256, 28, 28]             512
        BN_Conv2d-61          [-1, 256, 28, 28]               0
           Conv2d-62          [-1, 256, 28, 28]          18,432
      BatchNorm2d-63          [-1, 256, 28, 28]             512
        BN_Conv2d-64          [-1, 256, 28, 28]               0
           Conv2d-65          [-1, 512, 28, 28]         131,584
      BatchNorm2d-66          [-1, 512, 28, 28]           1,024
           Conv2d-67          [-1, 512, 28, 28]         262,144
      BatchNorm2d-68          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-69          [-1, 512, 28, 28]               0
           Conv2d-70          [-1, 256, 28, 28]         131,072
      BatchNorm2d-71          [-1, 256, 28, 28]             512
        BN_Conv2d-72          [-1, 256, 28, 28]               0
           Conv2d-73          [-1, 256, 28, 28]          18,432
      BatchNorm2d-74          [-1, 256, 28, 28]             512
        BN_Conv2d-75          [-1, 256, 28, 28]               0
           Conv2d-76          [-1, 512, 28, 28]         131,584
      BatchNorm2d-77          [-1, 512, 28, 28]           1,024
           Conv2d-78          [-1, 512, 28, 28]         262,144
      BatchNorm2d-79          [-1, 512, 28, 28]           1,024
    ResNeXt_Block-80          [-1, 512, 28, 28]               0
           Conv2d-81          [-1, 512, 28, 28]         262,144
      BatchNorm2d-82          [-1, 512, 28, 28]           1,024
        BN_Conv2d-83          [-1, 512, 28, 28]               0
           Conv2d-84          [-1, 512, 14, 14]          73,728
      BatchNorm2d-85          [-1, 512, 14, 14]           1,024
        BN_Conv2d-86          [-1, 512, 14, 14]               0
           Conv2d-87         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-88         [-1, 1024, 14, 14]           2,048
           Conv2d-89         [-1, 1024, 14, 14]         524,288
      BatchNorm2d-90         [-1, 1024, 14, 14]           2,048
    ResNeXt_Block-91         [-1, 1024, 14, 14]               0
           Conv2d-92          [-1, 512, 14, 14]         524,288
      BatchNorm2d-93          [-1, 512, 14, 14]           1,024
        BN_Conv2d-94          [-1, 512, 14, 14]               0
           Conv2d-95          [-1, 512, 14, 14]          73,728
      BatchNorm2d-96          [-1, 512, 14, 14]           1,024
        BN_Conv2d-97          [-1, 512, 14, 14]               0
           Conv2d-98         [-1, 1024, 14, 14]         525,312
      BatchNorm2d-99         [-1, 1024, 14, 14]           2,048
          Conv2d-100         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-101         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-102         [-1, 1024, 14, 14]               0
          Conv2d-103          [-1, 512, 14, 14]         524,288
     BatchNorm2d-104          [-1, 512, 14, 14]           1,024
       BN_Conv2d-105          [-1, 512, 14, 14]               0
          Conv2d-106          [-1, 512, 14, 14]          73,728
     BatchNorm2d-107          [-1, 512, 14, 14]           1,024
       BN_Conv2d-108          [-1, 512, 14, 14]               0
          Conv2d-109         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-110         [-1, 1024, 14, 14]           2,048
          Conv2d-111         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-112         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-113         [-1, 1024, 14, 14]               0
          Conv2d-114          [-1, 512, 14, 14]         524,288
     BatchNorm2d-115          [-1, 512, 14, 14]           1,024
       BN_Conv2d-116          [-1, 512, 14, 14]               0
          Conv2d-117          [-1, 512, 14, 14]          73,728
     BatchNorm2d-118          [-1, 512, 14, 14]           1,024
       BN_Conv2d-119          [-1, 512, 14, 14]               0
          Conv2d-120         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-121         [-1, 1024, 14, 14]           2,048
          Conv2d-122         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-123         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-124         [-1, 1024, 14, 14]               0
          Conv2d-125          [-1, 512, 14, 14]         524,288
     BatchNorm2d-126          [-1, 512, 14, 14]           1,024
       BN_Conv2d-127          [-1, 512, 14, 14]               0
          Conv2d-128          [-1, 512, 14, 14]          73,728
     BatchNorm2d-129          [-1, 512, 14, 14]           1,024
       BN_Conv2d-130          [-1, 512, 14, 14]               0
          Conv2d-131         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-132         [-1, 1024, 14, 14]           2,048
          Conv2d-133         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-134         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-135         [-1, 1024, 14, 14]               0
          Conv2d-136          [-1, 512, 14, 14]         524,288
     BatchNorm2d-137          [-1, 512, 14, 14]           1,024
       BN_Conv2d-138          [-1, 512, 14, 14]               0
          Conv2d-139          [-1, 512, 14, 14]          73,728
     BatchNorm2d-140          [-1, 512, 14, 14]           1,024
       BN_Conv2d-141          [-1, 512, 14, 14]               0
          Conv2d-142         [-1, 1024, 14, 14]         525,312
     BatchNorm2d-143         [-1, 1024, 14, 14]           2,048
          Conv2d-144         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-145         [-1, 1024, 14, 14]           2,048
   ResNeXt_Block-146         [-1, 1024, 14, 14]               0
          Conv2d-147         [-1, 1024, 14, 14]       1,048,576
     BatchNorm2d-148         [-1, 1024, 14, 14]           2,048
       BN_Conv2d-149         [-1, 1024, 14, 14]               0
          Conv2d-150           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-151           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-152           [-1, 1024, 7, 7]               0
          Conv2d-153           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-154           [-1, 2048, 7, 7]           4,096
          Conv2d-155           [-1, 2048, 7, 7]       2,097,152
     BatchNorm2d-156           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-157           [-1, 2048, 7, 7]               0
          Conv2d-158           [-1, 1024, 7, 7]       2,097,152
     BatchNorm2d-159           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-160           [-1, 1024, 7, 7]               0
          Conv2d-161           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-162           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-163           [-1, 1024, 7, 7]               0
          Conv2d-164           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-165           [-1, 2048, 7, 7]           4,096
          Conv2d-166           [-1, 2048, 7, 7]       4,194,304
     BatchNorm2d-167           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-168           [-1, 2048, 7, 7]               0
          Conv2d-169           [-1, 1024, 7, 7]       2,097,152
     BatchNorm2d-170           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-171           [-1, 1024, 7, 7]               0
          Conv2d-172           [-1, 1024, 7, 7]         294,912
     BatchNorm2d-173           [-1, 1024, 7, 7]           2,048
       BN_Conv2d-174           [-1, 1024, 7, 7]               0
          Conv2d-175           [-1, 2048, 7, 7]       2,099,200
     BatchNorm2d-176           [-1, 2048, 7, 7]           4,096
          Conv2d-177           [-1, 2048, 7, 7]       4,194,304
     BatchNorm2d-178           [-1, 2048, 7, 7]           4,096
   ResNeXt_Block-179           [-1, 2048, 7, 7]               0
          Linear-180                    [-1, 4]           8,196
================================================================
Total params: 37,574,724
Trainable params: 37,574,724
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 379.37
Params size (MB): 143.34
Estimated Total Size (MB): 523.28
----------------------------------------------------------------
None
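A quick sanity check on the summary: the final Linear layer maps 2048 features to a 4-way output, so its parameter count is 2048 × 4 weights plus 4 biases, matching the 8,196 reported for Linear-180:

```python
in_features, num_classes = 2048, 4
fc_params = in_features * num_classes + num_classes  # weight matrix plus bias vector
print(fc_params)  # matches Linear-180 in the summary
```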

2.3 Training the model

2.3.1 Setting hyperparameters

"""Train the model: set hyperparameters"""
loss_fn = nn.CrossEntropyLoss()  # cross-entropy loss, the usual choice for image classification
learn_rate = 1e-4  # learning rate
optimizer1 = torch.optim.SGD(model.parameters(), lr=learn_rate)  # SGD: plain stochastic gradient descent
optimizer2 = torch.optim.Adam(model.parameters(), lr=learn_rate)  # Adam: adaptive per-parameter learning rates
lr_opt = optimizer2
model_opt = optimizer2
# decay the learning rate by 8% every 4 epochs via the built-in scheduler API
lambda1 = lambda epoch: 0.92 ** (epoch // 4)
scheduler = torch.optim.lr_scheduler.LambdaLR(lr_opt, lr_lambda=lambda1)  # select the schedule
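The lambda decays the learning-rate multiplier by 8% every 4 epochs; evaluating it for the first few epochs shows the resulting step schedule (base learning rate taken from above):

```python
base_lr = 1e-4
lambda1 = lambda epoch: 0.92 ** (epoch // 4)
# epochs 0-3 keep the base rate, 4-7 use 0.92x, 8-11 use 0.92^2 = 0.8464x, ...
lrs = [base_lr * lambda1(e) for e in range(12)]
print(lrs)
```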

2.3.2 Writing the training function

"""Train the model: training function"""
# training loop
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # number of training images (1,713 here)
    num_batches = len(dataloader)  # number of batches (ceil(1713 / 32) = 54)

    train_loss, train_acc = 0, 0  # running loss and accuracy

    for X, y in dataloader:  # fetch a batch of images X and true labels y from the dataloader
        X, y = X.to(device), y.to(device)  # move the batch to the GPU

        # compute the prediction error
        pred = model(X)  # network output
        loss = loss_fn(pred, y)  # loss between the network output and the true labels

        # backpropagation
        optimizer.zero_grad()  # clear accumulated gradients
        loss.backward()  # backpropagate to compute the current gradients
        optimizer.step()  # update the network parameters from the gradients

        # accumulate accuracy and loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

2.3.3 Writing the test function

"""Train the model: test function"""
# the test function largely mirrors the training function, but since it performs no
# gradient descent to update the network weights, no optimizer is passed in
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # number of test images (429 here)
    num_batches = len(dataloader)  # number of batches (ceil(429 / 32) = 14)
    test_loss, test_acc = 0, 0

    # gradient tracking is disabled during evaluation to save memory
    with torch.no_grad():  # parameters are not updated at test time; a forward pass suffices
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to