【Tensorflow教程笔记】常用模块 :变量的保存与恢复

时间:2025-02-26 11:40:26
  1. 基础
    TensorFlow 基础
    TensorFlow 模型建立与训练
    基础示例:多层感知机(MLP)
    卷积神经网络(CNN)
    循环神经网络(RNN)
    深度强化学习(DRL)
    Keras Pipeline
    自定义层、损失函数和评估指标
    常用模块 :变量的保存与恢复
    常用模块 TensorBoard:训练过程可视化
    常用模块 :数据集的构建与预处理
    常用模块 TFRecord :TensorFlow 数据集存储格式
    常用模块 :图执行模式
    常用模块 :TensorFlow 动态数组
    常用模块 :GPU 的使用与分配

  2. 部署
    TensorFlow 模型导出
    TensorFlow Serving
    TensorFlow Lite

  3. 大规模训练与加速
    TensorFlow 分布式训练
    使用 TPU 训练 TensorFlow 模型

  4. 扩展
    TensorFlow Hub 模型复用
    TensorFlow Datasets 数据集载入

  5. 附录
    强化学习基础简介


目录

  • 保存参数
  • 载入之前保存的参数
  • 保存变量+恢复变量
  • `` VS ``
  • 实例
  • 使用 `` 删除旧的 Checkpoint 以及自定义文件编号

Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署” 章节中的 SavedModel

很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle 存储 。但不幸的是,TensorFlow 的变量类型 ResourceVariable 并不能被序列化。

好在 TensorFlow 提供了 这一强大的变量保存与恢复类,可以使用其 save()restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言, 或者 实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:

checkpoint = tf.train.Checkpoint(model=model)

这里 () 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 的模型实例 model 和一个继承 的优化器 optimizer ,我们可以这样写:

checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

这里 myAwesomeModel 是我们为待保存的模型 model 所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。

保存参数

接下来,当模型训练完成需要保存的时候,使用:

checkpoint.save(save_path_with_prefix)

就可以。 save_path_with_prefix 是保存文件的目录 + 前缀

  • 例如,在源代码目录建立一个名为 save 的文件夹并调用一次 ('./save/') ,我们就可以在 save 目录下发现名为 checkpoint-00000-of-00001 的三个文件,这些文件就记录了变量信息。() 方法可以运行多次,每运行一次都会得到一个 .index 文件和 .data 文件,序号依次累加。

载入之前保存的参数

当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:

model_to_be_restored = MyModel()                                        # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)

即可恢复模型变量。 save_path_with_prefix_and_index 是之前保存的文件的目录 + 前缀 + 编号

  • 例如,调用 ('./save/-1') 就可以载入前缀为 ,序号为 1 的文件来恢复模型。

当保存了多个文件时,我们往往想载入最近的一个。可以使用 .latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次 checkpoint 的文件名。

  • 例如如果 save 目录下有 的 10 个保存文件, .latest_checkpoint('./save') 即返回 ./save/-10

保存变量+恢复变量

总体而言,恢复与保存变量的典型代码框架如下:

#  模型训练阶段

model = MyMod