1. 将 TensorFlow 模型文件打包成 PB 文件
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
# Export a checkpointed TF1-style model as a frozen PB file.
#
# Steps: rebuild the model graph, restore the checkpoint weights, dump the
# (unfrozen) GraphDef to disk, then let the freeze_graph tool fold the
# checkpoint variables into constants and write the frozen PB.
with tf.Graph().as_default():
    with tf.device("/cpu:0"):
        config = tf.ConfigProto(allow_soft_placement=True)
        # Enter the Session itself rather than `.as_default()`: the latter
        # does not close the session when the block exits.
        with tf.Session(config=config) as sess:
            # `Your_Model_Name` is a placeholder for the user's model class;
            # it is expected to expose `build_graph()`.
            model = Your_Model_Name()
            model.build_graph()
            # `tf.initialize_all_variables` is deprecated in favour of
            # `tf.global_variables_initializer`.
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            ckpt_path = "/your/model/path"
            saver.restore(sess, ckpt_path)
            graphdef = tf.get_default_graph().as_graph_def()
            # Binary GraphDef (no weights yet); consumed by freeze_graph below.
            tf.train.write_graph(sess.graph_def, "/your/save/path/", "save_name.pb", as_text=False)
            # In-memory alternative: freeze and trim directly from the live
            # session. The result is not written to disk in this snippet.
            frozen_graph = tf.graph_util.convert_variables_to_constants(
                sess, graphdef, ['output/node/name'])
            frozen_graph_trim = tf.graph_util.remove_training_nodes(frozen_graph)
            # File-based freezing. Names must match the graph exactly (no
            # padding spaces), and input_saver must be empty when no saver
            # file is supplied, otherwise the tool looks for that file.
            freeze_graph.freeze_graph(
                input_graph='/your/save/path/save_name.pb',
                input_saver='',
                input_binary=True,
                input_checkpoint=ckpt_path,
                output_node_names='output/node/name',
                restore_op_name='save/restore_all',
                filename_tensor_name='save/Const:0',
                output_graph='frozen_name.pb',
                clear_devices=True,
                initializer_nodes="")
2. 读取并使用 PB 文件
# Load a frozen PB graph and run one inference pass.
output_graph_def = tf.GraphDef()
with open("your_name.pb", "rb") as f:
    output_graph_def.ParseFromString(f.read())
# name="" imports the nodes without a name prefix, so the original node
# names can be used for lookup below.
_ = tf.import_graph_def(output_graph_def, name="")
# NOTE(review): `sess` and `in_data` are not defined in this snippet; they
# are assumed to be a tf.Session bound to the graph imported above and the
# input batch, respectively -- confirm against the surrounding code.
# Tensor names must have the form "<op_name>:<output_index>"; a bare op name
# is rejected by get_tensor_by_name.
node_in = sess.graph.get_tensor_by_name("input_node_name:0")
model_out = sess.graph.get_tensor_by_name("out_node_name:0")
feed_dict = {node_in: in_data}
pred = sess.run(model_out, feed_dict)
以上这篇将 TensorFlow 模型打包成 PB 文件及 PB 文件读取方式就是小编分享给大家的全部内容了，希望能给大家一个参考，也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/hustchenze/article/details/83660960