Tensorflow学习笔记(三)模型的保存与加载以及再训练(二)

时间:2024-03-31 12:35:30

Tensorflow学习笔记(三)模型的保存与加载以及再训练(二)

saver.save()模型的保存与加载以及再训练

声明: 参考链接官方文档
大家如果去网上搜索Tensorflow模型保存估计出现最多的就是这个saver.save()方法了!

保存

首先先创建一个Saver对象,如

saver=tf.train.Saver()

Saver()的构造函数有两个可选参数:

max_to_keep 表示要保留的模型文件的最大数量。创建新文件时,将删除旧文件。如果max_to_keep 为None或0,如

saver=tf.train.Saver(max_to_keep=0)

则不会从文件系统中删除旧文件,但只有最后一个检查点保留在checkpoint文件中。max_to_keep 参数默认为5(即保留最近的5个模型文件。)

keep_checkpoint_every_n_hours:除了保留最新的 max_to_keep模型文件之外,您可能还希望每N小时的训练保留一个模型文件。如果您想稍后分析模型在长时间训练期间的进展情况,这将非常有用。例如,传递keep_checkpoint_every_n_hours=2确保每2小时训练保留一个检查点文件。默认值10,000小时可有效禁用该功能。

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'./models/my_model',global_step= step)

第一个参数sess就不用说了,第二个参数是设定保存的路径和名字,第三个参数是用来对保存模型文件进行编号。官方有一个例程:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

完整样例:

import tensorflow as tf # 以下所有代码默认导入
import os
# # 保存模型路径
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
MODEL_NAME = "my_model" # 模型命名
# 创建一个变量
one = tf.Variable(2.0)
# 创建一个占位符,在 Tensorflow 中需要定义 placeholder 的 type ,一般为 float32 形式
num = tf.placeholder(tf.float32,name='input')
# 创建一个加法步骤,注意这里并没有直接计算
sum = tf.add(num,one,name='output')
# 初始化变量,如果定义Variable就必须初始化
init = tf.global_variables_initializer()
# 创建saver对象
saver=tf.train.Saver()
# 创建会话sess
with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),global_step=0)

运行结果如下
Tensorflow学习笔记(三)模型的保存与加载以及再训练(二)
这里os.path.join(MODEL_SAVE_PATH, MODEL_NAME)只是把保存模型的路径和模型的名字合并起来相当于'./models/my_model'

保存后在models这个文件夹中实际会出现4个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。
checkpoint文件保存了一个目录下所有的模型文件列表
my_model-0.data-00000-of-00001文件保存了TensorFlow程序中每一个变量的取值
my_model-0.index文件保存了当前参数名
my_model.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构

加载

模型的加载调用用的是restore()函数,它需要两个参数restore(sess, save_path)save_path是保存的模型路径。

模型的加载有两种方法,第一种需要我们把模型的结构重新定义一次,然后载入对应名字的变量的值。这样显然不是我们所期望的,我们更希望能够直接使用而不是还要把模型重新定义一遍。所以就只介绍另一种方法。

直接上例子:

import tensorflow as tf # 以下所有代码默认导入
MODEL_SAVE_PATH = "./models/" # 保存模型的路径
###模型调用###
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')  # 载入图结构,保存在.meta文件中
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    in_x = sess.graph.get_tensor_by_name('input:0')     #加载输入变量
    y = sess.graph.get_tensor_by_name('output:0')       #加载输出变量
    scores = sess.run(y,feed_dict={in_x: 2.})
    print(scores)

结果:

4.0

tf.train.get_checkpoint_state函数通过checkpoint文件找到模型文件名,因为checkpoint文件保存了一个目录下所有的模型文件列表。(该函数返回的是checkpoint文件CheckpointState proto类型的内容,其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。)详见
这里大家可以以记事本方式打开checkpoint文件查看一下Tensorflow学习笔记(三)模型的保存与加载以及再训练(二)
之前提到.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构所以这里直接读入
tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
剩下的就跟上篇一样加载输入输出变量以及最重要的sess.run()这样就完成了saver.save()模型的保存与加载

再训练

ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)  # 通过 checkpoint 文件定位到最新保存的模型
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)  # 加载最新的模型

再训练只是训练和加载结合起来,在每次训练前先查询是否有之前训练的模型存在,如果存在就加载它的数据就OK了!

希望这篇文章对您有帮助,感谢阅读!