Tensorflow学习笔记(三)模型的保存与加载以及再训练(二)
saver.save()模型的保存与加载以及再训练
声明: 参考链接官方文档
大家如果去网上搜索Tensorflow模型保存估计出现最多的就是这个saver.save()方法了!
保存
首先先创建一个Saver对象,如
saver=tf.train.Saver()
Saver()的构造函数有两个可选参数:
max_to_keep
表示要保留的模型文件的最大数量。创建新文件时,将删除旧文件。如果max_to_keep
为None或0,如
saver=tf.train.Saver(max_to_keep=0)
则不会从文件系统中删除旧文件,但只有最后一个检查点保留在checkpoint文件中。max_to_keep
参数默认为5(即保留最近的5个模型文件。)
keep_checkpoint_every_n_hours
:除了保留最新的 max_to_keep模型文件之外,您可能还希望每N小时的训练保留一个模型文件。如果您想稍后分析模型在长时间训练期间的进展情况,这将非常有用。例如,传递keep_checkpoint_every_n_hours=2
确保每2小时训练保留一个检查点文件。默认值10,000小时可有效禁用该功能。
创建完saver对象后,就可以保存训练好的模型了,如:
saver.save(sess,'./models/my_model',global_step= step)
第一个参数sess就不用说了,第二个参数是设定保存的路径和名字,第三个参数是用来对保存模型文件进行编号。官方有一个例程:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
完整样例:
import tensorflow as tf # 以下所有代码默认导入
import os
# # 保存模型路径
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
MODEL_NAME = "my_model" # 模型命名
# 创建一个变量
one = tf.Variable(2.0)
# 创建一个占位符,在 Tensorflow 中需要定义 placeholder 的 type ,一般为 float32 形式
num = tf.placeholder(tf.float32,name='input')
# 创建一个加法步骤,注意这里并没有直接计算
sum = tf.add(num,one,name='output')
# 初始化变量,如果定义Variable就必须初始化
init = tf.global_variables_initializer()
# 创建saver对象
saver=tf.train.Saver()
# 创建会话sess
with tf.Session() as sess:
sess.run(init)
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),global_step=0)
运行结果如下
这里os.path.join(MODEL_SAVE_PATH, MODEL_NAME)
只是把保存模型的路径和模型的名字合并起来相当于'./models/my_model'
保存后在models这个文件夹中实际会出现4个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。checkpoint
文件保存了一个目录下所有的模型文件列表my_model-0.data-00000-of-00001
文件保存了TensorFlow程序中每一个变量的取值my_model-0.index
文件保存了当前参数名my_model.meta
文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
加载
模型的加载调用用的是restore()
函数,它需要两个参数restore(sess, save_path)
,save_path
是保存的模型路径。
模型的加载有两种方法,第一种需要我们把模型的结构重新定义一次,然后载入对应名字的变量的值。这样显然不是我们所期望的,我们更希望能够直接使用而不是还要把模型重新定义一遍。所以就只介绍另一种方法。
直接上例子:
import tensorflow as tf # 以下所有代码默认导入
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
###模型调用###
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta') # 载入图结构,保存在.meta文件中
with tf.Session() as sess:
saver.restore(sess, ckpt.model_checkpoint_path)
in_x = sess.graph.get_tensor_by_name('input:0') #加载输入变量
y = sess.graph.get_tensor_by_name('output:0') #加载输出变量
scores = sess.run(y,feed_dict={in_x: 2.})
print(scores)
结果:
4.0
tf.train.get_checkpoint_state
函数通过checkpoint
文件找到模型文件名,因为checkpoint文件保存了一个目录下所有的模型文件列表。(该函数返回的是checkpoint文件CheckpointState proto类型的内容,其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。)详见
这里大家可以以记事本方式打开checkpoint
文件查看一下
之前提到.meta
文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构所以这里直接读入tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
剩下的就跟上篇一样加载输入输出变量以及最重要的sess.run()
这样就完成了saver.save()模型的保存与加载
再训练
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH) # 通过 checkpoint 文件定位到最新保存的模型
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path) # 加载最新的模型
再训练只是训练和加载结合起来,在每次训练前先查询是否有之前训练的模型存在,如果存在就加载它的数据就OK了!
希望这篇文章对您有帮助,感谢阅读!