多分类 交叉熵 纯python实现

时间:2025-03-27 08:30:02

之前只知道二分类交叉熵的公式,或者调包实现过,从来没用纯python实现过。正好有个机会需要写这么一个函数,特此记录一下。

问题:给定一个logits输出x, 和真实的标签y要求计算其交叉熵

  1. 首先要计算x的概率,用python实现softmax
  2. 然后找到标签对应的概率

话不多说直接上代码:

import math
def softmax(x):
    m, n = len(x), len(x[0])
    for i in range(m):
        cur_m = max(x[i])
        for j in range(n):
            x[i][j] = (x[i][j]-cur_m)
        cur_s = sum(x[i])
        for j in range(n):
            x[i][j] = x[i][j]/cur_s
    return x
            
def crossentropy(y_hat, y):
    # y_hat: [N, C]
    # y: [N]
    loss = 0
    batch_size = len(y)
    y_softmax = softmax(y_hat)
    for i in range(batch_size):
        loss -= (y_softmax[i][y[i]])
    return loss

x = [[0.1, 0.2, 0.7], [0.5, 0.2, 0.3]]
y = [2, 0]
print(cross_entropy(x, y))

关于softmax的原理和注意事项参考:

/Answer3664/article/details/92070045

关于交叉熵的原理和实现参考:

/Answer3664/article/details/92804033

按照自己想法写的,有任何改进的地方欢迎批评指正。