本文为大家分享了TensorFLow用Saver保存和恢复变量的具体代码,供大家参考,具体内容如下
建立文件tensor_save.py, 保存变量v1,v2的tensor到checkpoint files中,名称分别设置为v3,v4。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
import tensorflow as tf
# Create some variables.
v1 = tf.Variable( 3 , name = "v1" )
v2 = tf.Variable( 4 , name = "v2" )
# Create model
y = tf.add(v1,v2)
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver({ 'v3' :v1, 'v4' :v2})
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
print ( "v1 = " , v1. eval ())
print ( "v2 = " , v2. eval ())
# Save the variables to disk.
save_path = saver.save(sess, "f:/tmp/model.ckpt" )
print ( "Model saved in file: " , save_path)
|
建立文件tensor_restror.py, 将checkpoint files中名称分别为v3,v4的tensor分别恢复到变量v3,v4中。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
import tensorflow as tf
# Create some variables.
v3 = tf.Variable( 0 , name = "v3" )
v4 = tf.Variable( 0 , name = "v4" )
# Create model
y = tf.mul(v3,v4)
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "f:/tmp/model.ckpt" )
print ( "Model restored." )
print ( "v3 = " , v3. eval ())
print ( "v4 = " , v4. eval ())
print ( "y = " ,sess.run(y))
|
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。
原文链接:http://blog.csdn.net/muyiyushan/article/details/68486497