keras封装的比较厉害,官网给的例子写的云里雾里,
在*找到了答案
You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).
1
2
3
4
|
def custom_loss_wrapper(input_tensor):
def custom_loss(y_true, y_pred):
return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
return custom_loss
|
1
2
3
4
5
|
input_tensor = Input (shape = ( 10 ,))
hidden = Dense( 100 , activation = 'relu' )(input_tensor)
out = Dense( 1 , activation = 'sigmoid' )(hidden)
model = Model(input_tensor, out)
model. compile (loss = custom_loss_wrapper(input_tensor), optimizer = 'adam' )
|
You can verify that input_tensor and the loss value will change as different X is passed to the model.
1
2
3
4
5
|
X = np.random.rand( 1000 , 10 )
y = np.random.randint( 2 , size = 1000 )
model.test_on_batch(X, y) # => 1.1974642
X * = 1000
model.test_on_batch(X, y) # => 511.15466
|
fit_generator
fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.
Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)
1
2
3
4
|
### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs = [inputX_1,inputX_2],outputs = ...)
|
补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)
首先辨析一下概念:
1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的
2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程
一、keras自定义损失函数
在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:
1
2
3
4
5
6
7
|
# 方式一
def vae_loss(x, x_decoded_mean):
xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.mean( 1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis = - 1 )
return xent_loss + kl_loss
vae. compile (optimizer = 'rmsprop' , loss = vae_loss)
|
或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
def __init__( self , * * kwargs):
self .is_placeholder = True
super (CustomVariationalLayer, self ).__init__( * * kwargs)
def vae_loss( self , x, x_decoded_mean_squash):
x = K.flatten(x)
x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
kl_loss = - 0.5 * K.mean( 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis = - 1 )
return K.mean(xent_loss + kl_loss)
def call( self , inputs):
x = inputs[ 0 ]
x_decoded_mean_squash = inputs[ 1 ]
loss = self .vae_loss(x, x_decoded_mean_squash)
self .add_loss(loss, inputs = inputs)
# We don't use this output.
return x
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae. compile (optimizer = 'rmsprop' , loss = None )
|
在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置
注意事项:
1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar
2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错
有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如
discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)
二、keras中的样本权重
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
# Import
import numpy as np
from sklearn.utils import class_weight
# Example model
model = Sequential()
model.add(Dense( 32 , activation = 'relu' , input_dim = 100 ))
model.add(Dense( 1 , activation = 'sigmoid' ))
# Use binary crossentropy loss
model. compile (optimizer = 'rmsprop' ,
loss = 'binary_crossentropy' ,
metrics = [ 'accuracy' ])
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight( 'balanced' ,
np.unique(y_train),
y_train)
# Add the class weights to the training
model.fit(x_train, y_train, epochs = 10 , batch_size = 32 , class_weight = weights)
|
Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].
以上这篇keras 自定义loss层+接受输入实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/u013608336/article/details/82559469