-
基础
TensorFlow 基础
TensorFlow 模型建立与训练
基础示例:多层感知机(MLP)
卷积神经网络(CNN)
循环神经网络(RNN)
深度强化学习(DRL)
Keras Pipeline
自定义层、损失函数和评估指标
常用模块 :变量的保存与恢复
常用模块 TensorBoard:训练过程可视化
常用模块 :数据集的构建与预处理
常用模块 TFRecord :TensorFlow 数据集存储格式
常用模块 :图执行模式
常用模块 :TensorFlow 动态数组
常用模块 :GPU 的使用与分配 -
部署
TensorFlow 模型导出
TensorFlow Serving
TensorFlow Lite -
大规模训练与加速
TensorFlow 分布式训练
使用 TPU 训练 TensorFlow 模型 -
扩展
TensorFlow Hub 模型复用
TensorFlow Datasets 数据集载入 -
附录
强化学习基础简介
目录
- 保存参数
- 载入之前保存的参数
- 保存变量+恢复变量
- `` VS ``
- 实例
- 使用 `` 删除旧的 Checkpoint 以及自定义文件编号
Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),请参考 “部署” 章节中的 SavedModel 。
很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。可能你第一个想到的是用 Python 的序列化模块 pickle
存储 。但不幸的是,TensorFlow 的变量类型
ResourceVariable
并不能被序列化。
好在 TensorFlow 提供了 这一强大的变量保存与恢复类,可以使用其
save()
和 restore()
方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言, 、
、
或者
实例都可以被保存。其使用方法非常简单,我们首先声明一个 Checkpoint:
checkpoint = tf.train.Checkpoint(model=model)
这里 ()
接受的初始化参数比较特殊,是一个 **kwargs
。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。例如,如果我们希望保存一个继承 的模型实例
model
和一个继承 的优化器
optimizer
,我们可以这样写:
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
这里 myAwesomeModel
是我们为待保存的模型 model
所取的任意键名。注意,在恢复变量的时候,我们还将使用这一键名。
保存参数
接下来,当模型训练完成需要保存的时候,使用:
checkpoint.save(save_path_with_prefix)
就可以。 save_path_with_prefix
是保存文件的目录 + 前缀。
- 例如,在源代码目录建立一个名为 save 的文件夹并调用一次
('./save/')
,我们就可以在 save 目录下发现名为checkpoint
、、
-00000-of-00001
的三个文件,这些文件就记录了变量信息。()
方法可以运行多次,每运行一次都会得到一个 .index 文件和 .data 文件,序号依次累加。
载入之前保存的参数
当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:
model_to_be_restored = MyModel() # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)
即可恢复模型变量。 save_path_with_prefix_and_index
是之前保存的文件的目录 + 前缀 + 编号。
- 例如,调用
('./save/-1')
就可以载入前缀为 ,序号为 1 的文件来恢复模型。
当保存了多个文件时,我们往往想载入最近的一个。可以使用 .latest_checkpoint(save_path)
这个辅助函数返回目录下最近一次 checkpoint 的文件名。
- 例如如果 save 目录下有
到
的 10 个保存文件,
.latest_checkpoint('./save')
即返回./save/-10
。
保存变量+恢复变量
总体而言,恢复与保存变量的典型代码框架如下:
# 模型训练阶段
model = MyMod