如下所示:
1
2
3
4
5
6
7
8
9
10
|
from keras import backend as K
from keras.models import load_model
models = load_model( 'models.hdf5' )
image = r 'image.png'
images = cv2.imread(r 'image.png' )
image_arr = process_image(image, ( 224 , 224 , 3 ))
image_arr = np.expand_dims(image_arr, axis = 0 )
layer_1 = K.function([base_model.get_input_at( 0 )], [base_model.get_layer( 'layer_name' ).output])
f1 = layer_1([image_arr])[ 0 ]
|
加载训练好并保存的网络模型
加载数据(图像),并将数据处理成array形式
指定输出层
将处理后的数据输入,然后获取输出
其中,K.function有两种不同的写法:
1. 获取名为layer_name的层的输出
layer_1 = K.function([base_model.get_input_at(0)], [base_model.get_layer('layer_name').output]) #指定输出层的名称
2. 获取第n层的输出
layer_1 = K.function([model.get_input_at(0)], [model.layers[5].output]) #指定输出层的序号(层号从0开始)
另外,需要注意的是,书写不规范会导致报错:
报错:
TypeError: inputs to a TensorFlow backend function should be a list or tuple
将该句:
f1 = layer_1(image_arr)[0]
修改为:
f1 = layer_1([image_arr])[0]
补充知识:keras.backend.function()
如下所示:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
def function(inputs, outputs, updates = None , * * kwargs):
"""Instantiates a Keras function.
Arguments:
inputs: List of placeholder tensors.
outputs: List of output tensors.
updates: List of update ops.
**kwargs: Passed to `tf.Session.run`.
Returns:
Output values as Numpy arrays.
Raises:
ValueError: if invalid kwargs are passed in.
"""
if kwargs:
for key in kwargs:
if (key not in tf_inspect.getargspec(session_module.Session.run)[ 0 ] and
key not in tf_inspect.getargspec(Function.__init__)[ 0 ]):
msg = ( 'Invalid argument "%s" passed to K.function with Tensorflow '
'backend' ) % key
raise ValueError(msg)
return Function(inputs, outputs, updates = updates, * * kwargs)
|
这是keras.backend.function()的源码。其中函数定义开头的注释就是官方文档对该函数的解释。
我们可以发现function()函数返回的是一个Function对象。下面是Function类的定义。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
|
class Function( object ):
"""Runs a computation graph.
Arguments:
inputs: Feed placeholders to the computation graph.
outputs: Output tensors to fetch.
updates: Additional update ops to be run at function call.
name: a name to help users identify what this function does.
"""
def __init__( self , inputs, outputs, updates = None , name = None ,
* * session_kwargs):
updates = updates or []
if not isinstance (inputs, ( list , tuple )):
raise TypeError( '`inputs` to a TensorFlow backend function '
'should be a list or tuple.' )
if not isinstance (outputs, ( list , tuple )):
raise TypeError( '`outputs` of a TensorFlow backend function '
'should be a list or tuple.' )
if not isinstance (updates, ( list , tuple )):
raise TypeError( '`updates` in a TensorFlow backend function '
'should be a list or tuple.' )
self .inputs = list (inputs)
self .outputs = list (outputs)
with ops.control_dependencies( self .outputs):
updates_ops = []
for update in updates:
if isinstance (update, tuple ):
p, new_p = update
updates_ops.append(state_ops.assign(p, new_p))
else :
# assumed already an op
updates_ops.append(update)
self .updates_op = control_flow_ops.group( * updates_ops)
self .name = name
self .session_kwargs = session_kwargs
def __call__( self , inputs):
if not isinstance (inputs, ( list , tuple )):
raise TypeError( '`inputs` should be a list or tuple.' )
feed_dict = {}
for tensor, value in zip ( self .inputs, inputs):
if is_sparse(tensor):
sparse_coo = value.tocoo()
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1 ),
np.expand_dims(sparse_coo.col, 1 )), 1 )
value = (indices, sparse_coo.data, sparse_coo.shape)
feed_dict[tensor] = value
session = get_session()
updated = session.run(
self .outputs + [ self .updates_op],
feed_dict = feed_dict,
* * self .session_kwargs)
return updated[: len ( self .outputs)]
|
所以,function函数利用我们之前已经创建好的comuptation graph。遵循计算图,从输入到定义的输出。这也是为什么该函数经常用于提取中间层结果。
以上这篇keras K.function获取某层的输出操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/qq_37974048/article/details/102727653