I previously only knew the cross-entropy formula for binary classification, or had implemented it by calling a library; I had never written it in pure Python. Recently I had an occasion where I needed to write such a function, so I'm recording it here.
Problem: given a logits output x and the true labels y, compute the cross-entropy.
- First, convert x to probabilities by implementing softmax in Python
- Then, pick out the probability corresponding to each label and take its negative log (see the formulas below)
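In formulas, for a batch of N rows of logits with C classes (subtracting the row max for numerical stability, and averaging over the batch), the two steps are:

$$
\mathrm{softmax}(x_i)_j = \frac{e^{x_{ij} - \max_k x_{ik}}}{\sum_{c=1}^{C} e^{x_{ic} - \max_k x_{ik}}}, \qquad
L = -\frac{1}{N}\sum_{i=1}^{N} \log \mathrm{softmax}(x_i)_{y_i}
$$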
Without further ado, here's the code:
```python
import math

def softmax(x):
    # x: [N, C] rows of logits; converts each row to probabilities in place.
    m, n = len(x), len(x[0])
    for i in range(m):
        # Subtract the row max before exponentiating for numerical stability.
        cur_m = max(x[i])
        for j in range(n):
            x[i][j] = math.exp(x[i][j] - cur_m)
        cur_s = sum(x[i])
        for j in range(n):
            x[i][j] = x[i][j] / cur_s
    return x

def cross_entropy(y_hat, y):
    # y_hat: [N, C] logits
    # y: [N] integer class labels
    loss = 0.0
    batch_size = len(y)
    y_softmax = softmax(y_hat)
    for i in range(batch_size):
        # The loss is the negative log of the probability assigned
        # to the true class, not the probability itself.
        loss -= math.log(y_softmax[i][y[i]])
    return loss / batch_size  # average over the batch

x = [[0.1, 0.2, 0.7], [0.5, 0.2, 0.3]]
y = [2, 0]
print(cross_entropy(x, y))  # ≈ 0.8539
```
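As a quick sanity check (assuming PyTorch is available; this is not part of the original snippet), the result can be compared against `torch.nn.functional.cross_entropy`, which also takes raw logits and averages over the batch by default:

```python
# Optional cross-check against PyTorch, assuming it is installed.
import torch
import torch.nn.functional as F

x = [[0.1, 0.2, 0.7], [0.5, 0.2, 0.3]]
y = [2, 0]

# F.cross_entropy expects float logits [N, C] and integer labels [N];
# its default reduction is 'mean', matching the average above.
ref = F.cross_entropy(torch.tensor(x), torch.tensor(y))
print(ref.item())  # ≈ 0.8539, same as the pure-Python version
```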
For the principle behind softmax and its caveats, see:
/Answer3664/article/details/92070045
For the principle behind cross-entropy and its implementation, see:
/Answer3664/article/details/92804033
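One caveat worth spelling out is why the code subtracts the row max before exponentiating: `math.exp` overflows for large logits, while the shifted version is mathematically identical and safe. A minimal sketch of my own to illustrate this (not from the linked posts):

```python
import math

logits = [1000.0, 1000.5]

# Naive softmax overflows: math.exp(1000.0) raises OverflowError.
try:
    naive = [math.exp(v) for v in logits]
except OverflowError:
    print("naive softmax overflowed")

# Subtracting the max first is equivalent, since the factor exp(-m)
# cancels in the ratio exp(x - m) / sum(exp(x - m)).
m = max(logits)
shifted = [math.exp(v - m) for v in logits]
s = sum(shifted)
print([v / s for v in shifted])  # [0.3775..., 0.6224...]
```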
This was written from my own understanding; corrections and suggestions for improvement are welcome.