#coding=utf-8
import tensorflow as tf
from tensorflow.python.framework import graph_util
# Build a tiny graph: output = x * v1, where v1 is a variable initialised from N(mean=1).
x = tf.placeholder(shape=[1], dtype=tf.float32, name='x')
variable_1 = tf.get_variable('v1', [1], tf.float32, initializer=tf.random_normal_initializer(mean=1))
output = tf.multiply(x, variable_1, name='mul')
initial_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initial_op)
    # Take the graph definition (the graph structure) out of the default graph.
    graph_def = tf.get_default_graph().as_graph_def()
    # Convert the variables needed to compute the 'mul' node into constants holding
    # their current values in `sess`; this has to happen while the session is alive.
    out_graph = graph_util.convert_variables_to_constants(sess, graph_def, ['mul'])
    print(sess.run(output, {x: [5]}))
    print(sess.run(variable_1))

# Serialize the frozen graph definition to a string and write it to the .pb file.
with tf.gfile.GFile('./model.pb', 'wb') as f:
    f.write(out_graph.SerializeToString())
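# 结果 (Result): the script above prints x * v1 for x = [5], then the value of v1.
# 读取.pb文件 (Reading the .pb file):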
#coding=utf-8
import tensorflow as tf
from tensorflow.python.platform import gfile
# `k` is not used below. It is created in the default graph, so its Const node also
# appears in the graph dump printed after the import.
k = tf.constant([1, 2, 3], dtype=tf.float32)
with tf.Session() as sess:
    model_filename = 'model.pb'
    # Open the .pb file (tf.gfile.GFile replaces the deprecated FastGFile).
    with tf.gfile.GFile(model_filename, 'rb') as f:
        # Create an empty GraphDef; printing it here shows it holds no nodes yet.
        graph_def = tf.GraphDef()
        print(graph_def)
        # Parse the serialized contents of the .pb file into the GraphDef.
        graph_def.ParseFromString(f.read())
    # Load the graph definition into the default graph. With return_elements, the call
    # returns a *list* (one entry per requested name), so take the single tensor.
    v1 = tf.import_graph_def(graph_def, return_elements=['v1:0'])[0]
    print(tf.get_default_graph().as_graph_def())
    # Imported nodes are placed under the 'import/' name prefix.
    print(tf.get_default_graph().get_tensor_by_name('import/x:0'))
    print(v1.name)