Tensorflow 模型保存、节点修改以及Serving 图优化
文章目录
- Tensorflow 模型保存、节点修改以及Serving 图优化
- 前言 (与正文无关, 可忽略)
- 总览
- 代码地址
- 广而告之
- checkpoint 格式
- 训练代码 & 保存 ckpt
- 加载 ckpt & 检查 graph 结构
- 节点修改
- frozen_graph 格式
- Serving 图优化
- 总结
前言 (与正文无关, 可忽略)
近期打算总结一些 Tensorflow 的基础知识, 方便查阅. 本文的写作动机是考虑到一个小问题: 我们常用 tf.data
系列 API 来生成训练数据, 因此 Train Graph 的输入节点通常是 Iterator 节点 (比如会调用 tf.data.make_one_shot_iterator
以及该对象的 get_next()
方法), 但是在 Serving 的时候, 我在想应该如何处理输入节点, 如何把新增的 tf.placeholder
加入到 Serving 图中.
一种方法是将 Serving Graph 重新写一遍, 输入节点更新成 tf.placeholder
, 然后输入到模型中, 从而生成一个新的 Graph; 但我希望有更简洁的方法, 比如能不能直接将 Iterator 输入节点替换成 tf.placeholder
, 这样即便我不知道模型代码是如何写的, 也能构建好 Serving 图. 在该问题的指引下, 对 TF 模型的保存与加载, Graph/MetaGraph 等概念有了稍微深入的了解.
总览
本文介绍 Tensorflow 模型部分保存方式, 主要包含 checkpoint
格式、frozen_graph
格式(SavedModel
格式暂略), 通过代码实例了解模型的保存方式, Serving 图的优化以及对 Serving 图中的节点进行修改更新.
代码地址
本文代码在 Python 3.5.2
| Tensorflow 1.15.0
环境下测试成功.
本文所有代码均可以从 https://github.com/axzml/BlogShare/tree/master/Tensorflow/GraphDef 下载.
广而告之
可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号, 可以及时获取最新原创技术文章更新:
另外可以看看知乎专栏 PoorMemory-机器学习, 以后文章也会发在知乎专栏中.
checkpoint 格式
训练代码 & 保存 ckpt
写了一个简单的训练代码(train.py
)如下, 五脏俱全, 其中定义了三个主要函数:
-
data_generator()
: 生成 Fake 数据参与模型训练
-
model()
: 定义了简单的神经网络
-
train()
: 定义训练代码, 调用 tf.train.Saver()
以 checkpoint 的形式保存模型
# _*_ coding:utf-8 _*_
## train.py
import tensorflow as tf
import os
import numpy as np
from os.path import join, exists
batch_size = 2
steps = 10
epochs = 1
emb_dim = 4
sample_num = epochs * steps * batch_size
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
def data_generator():
"""产生 Fake 训练数据"""
dataset = tf.data.Dataset.from_tensor_slices((np.random.randn(sample_num, emb_dim),\
np.random.randn(sample_num)))
dataset = dataset.repeat(epochs).batch(batch_size)
iterator = tf.data.make_one_shot_iterator(dataset)
feature, label = iterator.get_next()
return feature, label
def model(feature, params=[10, 5, 1]):
"""定义模型, 3层DNN"""
fc1 = tf.layers.dense(feature, params[0], activation=tf.nn.relu, name='fc1')
fc2 = tf.layers.dense(fc1, params[1], activation=tf.nn.relu, name='fc2')
fc3 = tf.layers.dense(fc2, params[2], activation=tf.nn.sigmoid, name='fc3')
out = tf.identity(fc3, name='output')
return out
def train():
feature, label = data_generator()
output = model(feature)
loss = tf.reduce_mean(tf.square(output - label))
train_op = tf.train.AdamOptimizer(learning_rate=0.1, name='Adam').minimize(loss)
saver = tf.train.Saver()
if exists(checkpoint_dir):
os.system('rm -rf {}'.format(checkpoint_dir))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
local_step = 0
save_freq = 2
while True:
local_step += 1
_, loss_val = sess.run([train_op, loss])
if local_step % save_freq == 0:
saver.save(sess, saver_dir)
print('loss: {:.4f}'.format(loss_val))
except tf.errors.OutOfRangeError:
print("train end!")
if __name__ == '__main__':
train()
运行 python train.py
会在当前目录下生成 checkpoint_dir
目录, 其组成如下:
checkpoint_dir/
|-- 0.data-00000-of-00001 ## 记录了网络参数值
|-- 0.index ## 记录了网络参数名
|-- 0.meta ## 保存 MetaGraphDef, 该文件以 pb 格式记录了网络结构
`-- checkpoint ## 该文件记录了最新的 ckpt
加载 ckpt & 检查 graph 结构
checkpoint
格式的模型需要在 Tensorflow 框架下进行加载. 比如编写 eval.py
进行 inference, 代码如下:
#_*_ coding:utf-8 _*_
## eval.py
import tensorflow as tf
import os
from os.path import join, exists
import numpy as np
emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
np.random.seed(123)
test_data = np.random.randn(4, emb_dim) ## 生成测试数据
def eval_graph():
with tf.Session() as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, model_file)
output = sess.run(['output:0'], feed_dict={
'IteratorGetNext:0': test_data
})
print('eval_graph:\n{}'.format(output))
if __name__ == '__main__':
eval_graph()
在上面代码中, 注意到输入和输出节点名分别为 output
以及 IteratorGetNext
. 对于输出节点, 由于在 train.py
的 model()
函数中使用
out = tf.identity(fc3, name='output')
对输出节点重新命名为 output
, 因此输出节点的名字非常好确定. 但是输入节点的名字却不太好确定, 原因是训练时采用 tf.data
API 来传入数据, 没有显式地对输入节点进行命名. 不过由于保存模型时网络结构都已经存放在 0.meta
文件中了, 因此可以通过解析该文件来查看网络的输入节点, 具体方法如下:
#_*_ coding:utf-8 _*_
## check_graph.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from google.protobuf import text_format
import os
from os.path import join, exists
import numpy as np
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
def read_pb_meta(meta_file):
"""读取 pb 格式的 meta 文件"""
meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
return meta_graph_def
def read_txt_meta(txt_meta_file):
"""读取文本格式的 meta 文件"""
meta_graph = MetaGraphDef()
with open(txt_meta_file, 'rb') as f:
text_format.Merge(f.read(), meta_graph)
return meta_graph
def read_pb_graph(graph_file):
"""读取 pb 格式的 graph_def 文件"""
try:
with tf.gfile.GFile(graph_file, 'rb') as pb:
graph_def = tf.GraphDef()
graph_def.ParseFromString(pb.read())
except IOError as e:
raise Exception("Parse '{}' Failed!".format(graph_file))
return graph_def
def check_graph_def(graph_def):
"""检查 graph_def 中的各节点"""
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
name=""
)
print('===> {}'.format(type(graph)))
for op in graph.get_operations():
print(op.name, op.values()) ## 打印网络结构
def check_graph(graph_file):
"""检查 pb 格式的 graph_def 文件中的各节点"""
graph_def = read_pb_graph(graph_file)
check_graph_def(graph_def)
if __name__ == '__main__':
check_graph_def(read_pb_meta(meta_file).graph_def)
输出结果如下图所示, 可以发现距离网络参数 fc1/kernel
最近的节点是 IteratorGetNext
, 因此输入节点的名字基本可以确认是它了.
节点修改
现在回到 “前言” 中提到的问题, 如果我希望使用自行创建的 tf.placeholder
节点作为 Graph 的输入节点, 而不是采用 IteratorGetNext
, 应该如何实现. 一方面可以重新将 Tensorflow Graph 写一遍, 使用 tf.placeholder
作为输入; 另一方面其实可以考虑将 IteratorGetNet
节点用自定义的节点给替换掉, 这一步参考了博文 如何在建好TF图后修改图. 具体做法如下, 代码在 infer.py
中:
#_*_ coding:utf-8 _*_
## infer.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
import os
from os.path import join, exists
import numpy as np
emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
np.random.seed(123)
test_data = np.random.randn(4, emb_dim)
def read_pb_meta(meta_file):
meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
return meta_graph_def
def update_node(graph, src_node_name, tar_node):
"""
@params:
graph : tensorflow Graph object
src_node_name : source node name to be modified
tar_node : target node
"""
input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
for op in input.consumers():
idx_list = []
for idx, item in enumerate(op.inputs):
if src_node_name in item.name:
idx_list.append(idx)
for idx in idx_list:
op._update_input(idx, tar_node)
def modify_graph():
meta_graph_def = read_pb_meta(meta_file)
with tf.Graph().as_default() as graph:
tf.import_graph_def(meta_graph_def.graph_def, name="")
input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
update_node(graph, 'IteratorGetNext', input_ph)
with tf.Session(graph=graph) as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, model_file)
output = sess.run(['output:0'], feed_dict={
'input:0': test_data
})
print('modify_graph:\n{}'.format(output))
if __name__ == '__main__':
modify_graph()
该文件定义了函数 update_node
来实现对 graph 中节点的替换, 函数如下:
def update_node(graph, src_node_name, tar_node):
"""
@params:
graph : tensorflow Graph object
src_node_name : source node name to be modified
tar_node : target node
"""
input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
for op in input.consumers():
idx_list = []
for idx, item in enumerate(op.inputs):
if src_node_name in item.name:
idx_list.append(idx)
for idx in idx_list:
op._update_input(idx, tar_node)
其中 src_node_name
表示要被替换掉的节点名字, 比如希望替换 IteratorGetNext
. 通过该名字在 graph
中找到对应的节点 input
, 然后调用 input.consumers()
找到使用该节点的 op
, 再通过更新 op
的输入 (op.inputs
) 来实现对节点的替换. 由于替换的方法 op._update_input
需要使用索引 idx
, 因此用 idx_list
来记录要替换节点的索引.
frozen_graph 格式
前面介绍的 checkpoint
格式将网络结构和参数分开保存, 而 frozen_graph
格式则会将网络参数以 Const 节点的形式写入到 GraphDef, 并保存到统一的 protobuf 文件中, 由于 protobuf 是跨语言、跨平台序列化数据协议, 因此还可以用 C++/Java/Python 等对模型进行加载.
下面写了个简单的将 ckpt 转换为 frozen_graph 的例子 frozen_graph.py
, 代码如下:
#_*_ coding:utf-8 _*_
## frozen_graph.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import dtypes
from tensorflow.python.tools import optimize_for_inference_lib
import os
from os.path import join, exists
import numpy as np
emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
np.random.seed(123)
test_data = np.random.randn(4, emb_dim)
def read_pb_meta(meta_file):
meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
return meta_graph_def
def update_node(graph, src_node_name, tar_node):
"""
@params:
graph : tensorflow Graph object
src_node_name : source node name to be modified
tar_node : target node
"""
input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
for op in input.consumers():
idx_list = []
for idx, item in enumerate(op.inputs):
if src_node_name in item.name:
idx_list.append(idx)
for idx in idx_list:
op._update_input(idx, tar_node)
def check_graph_def(graph_def):
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
name=""
)
print('===> {}'.format(type(graph)))
for op in graph.get_operations():
print(op.name, op.values()) ## 打印网络结构
def write_frozen_graph():
meta_graph_def = read_pb_meta(meta_file)
with tf.Graph().as_default() as graph:
tf.import_graph_def(meta_graph_def.graph_def, name="")
input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
update_node(graph, 'IteratorGetNext', input_ph)
with tf.Session(graph=graph) as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, model_file)
input_node_names = ['input']
##placeholder_type_enum = [dtypes.float64.as_datatype_enum]
placeholder_type_enum = [input_ph.dtype.as_datatype_enum]
output_node_names = ['output']
## 对 graph 进行优化, 把和 inference 无关的节点给删除, 比如 Saver 有关的节点
graph_def = optimize_for_inference_lib.optimize_for_inference(
graph.as_graph_def(), input_node_names, output_node_names, placeholder_type_enum
)
check_graph_def(graph_def)
## 将 ckpt 转换为 frozen_graph, 网络权重和结构写入统一 pb 文件中, 参数以 Const 的形式保存
frozen_graph = tf.graph_util.convert_variables_to_constants(sess,
graph_def, output_node_names)
out_graph_path = os.path.join('.', "frozen_model.pb")
with tf.gfile.GFile(out_graph_path, "wb") as f:
f.write(frozen_graph.SerializeToString())
def read_frozen_graph():
with tf.Graph().as_default() as graph:
graph_def = tf.GraphDef()
with open("frozen_model.pb", 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# print(graph_def)
with tf.Session(graph=graph) as sess:
output = sess.run(['output:0'], feed_dict={
'input:0': test_data
})
print('frozen_graph:\n{}'.format(output))
if __name__ == '__main__':
write_frozen_graph()
read_frozen_graph()
其中 write_frozen_graph()
中调用 optimize_for_inference_lib.optimize_for_inference
对 Graph 节点进行优化, 将在下一节进行介绍. 此外还调用 tf.graph_util.convert_variables_to_constants
将 ckpt 转换为 frozen_graph, 参数以 Const 的形式保存:
Serving 图优化
在上一节生成 frozen_graph 时, 调用了 optimize_for_inference_lib.optimize_for_inference
对 Graph 节点进行优化, 本节简要对其进行说明. 在调用该函数前如果打印从 checkpoint 中加载的 graph 时, 会发现结构中包含很多在训练时需要但在线 Serving 时并不需要的 Op, 如优化算法 Adam
, 模型保存 Saver
, 梯度 gradients
等等, 如下图:
optimize_for_inference_lib.optimize_for_inference
函数的一个主要工作就是将 graph 在 Serving 时无用的 Op 给去除.
该函数定义在 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference_lib.py,
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
placeholder_type_enum, toco_compatible=False):
## ..... 显示核心代码
optimized_graph_def = strip_unused_lib.strip_unused(
optimized_graph_def, input_node_names, output_node_names,
placeholder_type_enum)
optimized_graph_def = graph_util.remove_training_nodes(
optimized_graph_def, output_node_names)
## ....
return optimized_graph_def
其中 strip_unused_lib.strip_unused 定义如下:
def strip_unused(input_graph_def, input_node_names, output_node_names,
placeholder_type_enum):
"""Removes unused nodes from a GraphDef.
Args:
input_graph_def: A graph with nodes we want to prune.
input_node_names: A list of the nodes we use as inputs.
output_node_names: A list of the output nodes.
placeholder_type_enum: The AttrValue enum for the placeholder data type, or
a list that specifies one value per input node name.
Returns:
A `GraphDef` with all unnecessary ops removed.
Raises:
ValueError: If any element in `input_node_names` refers to a tensor instead
of an operation.
KeyError: If any element in `input_node_names` is not found in the graph.
"""
for name in input_node_names:
if ":" in name:
raise ValueError(f"Name '{name}' appears to refer to a Tensor, not an "
"Operation.")
# Here we replace the nodes we're going to override as inputs with
# placeholders so that any unused nodes that are inputs to them are
# automatically stripped out by extract_sub_graph().
not_found = {name for name in input_node_names}
inputs_replaced_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name in input_node_names:
not_found.remove(node.name)
placeholder_node = node_def_pb2.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
if isinstance(placeholder_type_enum, list):
input_node_index = input_node_names.index(node.name)
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum[
input_node_index]))
else:
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum))
if "_output_shapes" in node.attr:
placeholder_node.attr["_output_shapes"].CopyFrom(node.attr[
"_output_shapes"])
if "shape" in node.attr:
placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
inputs_replaced_graph_def.node.extend([placeholder_node])
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
if not_found:
raise KeyError(f"The following input nodes were not found: {not_found}.")
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
output_node_names)
return output_graph_def
该代码需要传入 graph_def
, 输入节点名字 input_node_names
以及输出节点名字 output_node_names
, 前面一大段代码是为了用 Placeholder
替换原本的输入节点, 算是将整个 Graph 重新写了一遍. 之后在 graph_util.extract_sub_graph 函数中, 利用 BFS 算法保留 Serving 时需要的节点, 而将不需要的节点全部给去除:
def extract_sub_graph(graph_def, dest_nodes):
"""Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
Args:
graph_def: A graph_pb2.GraphDef proto.
dest_nodes: An iterable of strings specifying the destination node names.
Returns:
The GraphDef of the sub-graph.
Raises:
TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
"""
## ... BFS 遍历 Serving 时用到的节点
nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
nodes_to_keep_list = sorted(
list(nodes_to_keep), key=lambda n: name_to_seq_num[n])
# Now construct the output GraphDef
out = graph_pb2.GraphDef()
for n in nodes_to_keep_list:
out.node.extend([copy.deepcopy(name_to_node[n])])
out.library.CopyFrom(graph_def.library)
out.versions.CopyFrom(graph_def.versions)
return out
其中 BFS 函数定义如下:
def _node_name(n):
if n.startswith("^"):
return n[1:]
else:
return n.split(":")[0]
def _extract_graph_summary(graph_def):
"""Extracts useful information from the graph and returns them."""
name_to_input_name = {} # Keyed by the dest node name.
name_to_node = {} # Keyed by node name.
# Keeps track of node sequences. It is important to still output the
# operations in the original order.
name_to_seq_num = {} # Keyed by node name.
seq = 0
for node in graph_def.node:
n = _node_name(node.name)
name_to_node[n] = node
name_to_input_name[n] = [_node_name(x) for x in node.input]
### ....
name_to_seq_num[n] = seq
seq += 1
return name_to_input_name, name_to_node, name_to_seq_num
def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
"""Breadth first search for reachable nodes from target nodes."""
nodes_to_keep = set()
# Breadth first search to find all the nodes that we should keep.
next_to_visit = list(target_nodes)
while next_to_visit:
node = next_to_visit[0]
del next_to_visit[0]
if node in nodes_to_keep:
# Already visited this node.
continue
nodes_to_keep.add(node)
if node in name_to_input_name:
next_to_visit += name_to_input_name[node]
return nodes_to_keep
之所以把这几段代码单独拎出来, 可以在合适的时候拿出来对 graph_def
进行调试, 打印中间结果. 经过 optimize_for_inference_lib.optimize_for_inference
的处理后, graph 更为简洁轻量, 打印其中的 Op 得到:
可以看到, 训练中会用到的 Adam
, Saver
等节点全部被移除了, 整个 graph 变得异常干净整洁.
总结
写文章就是, 一鼓作气, 再而衰, 三而竭, 再一鼓作气.
我要去玩耍了.