【深度学习】多分类任务评估指标sklearn和torchmetrics对比

时间:2024-11-05 11:28:56

【深度学习】多分类任务评估指标sklearn和torchmetrics对比

  • 说明
  • sklearn代码
  • torchmetrics代码
  • 两个MultiClassReport类的对比分析
    • 1. 代码结构与实现方式
    • 2. 数据处理与内存使用
    • 3. 性能与效率
  • 二分类任务评估指标
    • 1. 准确率(Accuracy)
    • 2. 精确率(Precision)
    • 3. 召回率(Recall)
    • 4. F1值(F1-score)
  • 多分类评估指标
    • 1. 混淆矩阵(Confusion Matrix)
    • 2. 准确率(Accuracy)
    • 3. 精确率(Precision)
    • 4. 召回率(Recall)
    • 5. F1值(宏平均)

说明

sklearn和torchmetrics两个metric代码跑模型的输出结果一致,对比他们的区别。评估指标写在下面

sklearn代码

import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

class MultiClassReport():
    """
    Accuracy, F1 Score, Precision and Recall for multi - class classification task.
    """

    def __init__(self, name='MultiClassReport', average='macro'):
        super(MultiClassReport, self).__init__()
        self.average = average
        self._name = name
        self.reset()

    def reset(self):
        """
        Resets all the metric state.
        """
        self.y_prob = []
        self.y_true = []

    def update(self, probs, labels):
        # 将Tensor转换为numpy数组并添加到相应列表中
        if isinstance(probs, torch.Tensor):
            if probs.requires_grad:
                probs = probs.detach()
            probs = probs.cpu().numpy()
        if isinstance(labels, torch.Tensor):
            if labels.requires_grad:
                labels = labels.detach()
            labels = labels.cpu().numpy()
        self.y_prob.extend(probs)
        self.y_true.extend(labels)
        self.y_prob.extend(probs)
        self.y_true.extend(labels)

    def accumulate(self):
        accuracy = accuracy_score(self.y_true, np.argmax(self.y_prob, axis=1))
        f1 = f1_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        precision = precision_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        recall = recall_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
        return accuracy, f1, precision, recall

    def name(self):
        """
        Returns metric name
        """
        return self._name

torchmetrics代码

from torchmetrics import Accuracy, F1Score, Precision, Recall
from model import polarity_classes, device

# 创建评估指标对象
accuracy_metric = Accuracy(task='multiclass', num_classes=polarity_classes).to(device)
f1_metric = F1Score(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
precision_metric = Precision(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
recall_metric = Recall(task='multiclass', num_classes=polarity_classes, average='macro').to(device)

class MultiClassReport():
    """
    Accuracy, F1 Score, Precision and Recall for multi-class classification task.
    average:micro、macro
    """

    def __init__(self, name='MultiClassReport', average='macro'):
        super(MultiClassReport, self).__init__()
        self.average = average
        self._name = name

    def reset(self):
        """
        Resets all the metric state.
        """
        accuracy_metric.reset()
        f1_metric.reset()
        precision_metric.reset()
        recall_metric.reset()

    def update(self, probs, labels):
        accuracy_metric.update(probs, labels)
        f1_metric.update(probs, labels)
        precision_metric.update(probs, labels)
        recall_metric.update(probs, labels)

    def accumulate(self):
        accuracy = accuracy_metric.compute()
        f1 = f1_metric.compute()
        precision = precision_metric.compute()
        recall = recall_metric.compute()
        return accuracy, f1, precision, recall

    def name(self):
        """
        Returns metric name
        """
        return self._name

两个MultiClassReport类的对比分析

1. 代码结构与实现方式

  • sklearn版本
    • 代码逻辑较为清晰直接。在update方法中,将输入的PyTorch张量转换为numpy数组,并存储到y_proby_true列表中。在accumulate方法中,直接使用sklearnaccuracy_scoref1_scoreprecision_scorerecall_score函数基于存储的列表数据计算评估指标。
  • torchmetrics版本
    • 利用torchmetrics库提供的专门的评估指标类(AccuracyF1ScorePrecisionRecall)。在update方法中,直接调用这些类的update方法来处理输入数据,内部有自己的状态管理机制。在accumulate方法中,通过调用相应类的compute方法获取评估指标值。
    • 这种方式与PyTorch的生态系统集成得更好,尤其是在基于PyTorch进行深度学习项目开发时,可以方便地在GPU上进行计算(如果deviceGPU),并且可以利用torchmetrics库的其他特性,如分布式训练支持等。

2. 数据处理与内存使用

  • sklearn版本
    • update方法中不断扩展y_proby_true列表来存储数据。如果处理大量数据,可能会占用较多内存,因为它需要将所有的预测概率和真实标签都保存在内存中。
    • 每次计算评估指标时,都需要对整个存储的数组进行操作,如np.argmax等,这在数据量较大时可能会有一定的计算开销。
  • torchmetrics版本
    • 虽然torchmetrics类内部也需要存储一定的状态信息,但它们的设计可能更高效地利用内存和处理数据更新。例如,它们可能会采用增量计算的方式,而不是像sklearn版本那样一次性处理所有数据。
    • 在处理大规模数据或长时间训练过程中,torchmetrics版本可能在内存管理和计算效率方面更有优势。

3. 性能与效率

  • sklearn版本
    • 在小规模数据和简单场景下,性能表现良好。但随着数据量的增加和模型复杂度的提高,由于数据转换和计算方式的原因,可能会出现性能瓶颈。
  • torchmetrics版本
    • 设计初衷就是为了在PyTorch深度学习环境中高效运行,特别是在利用GPU计算资源时,能够更高效地更新和计算评估指标,更适合大规模数据和复杂模型的评估场景。

二分类任务评估指标

TP(True Positive)是真正例,TN(True Negative)是真反例,FP(False Positive)是假正例,FN(False Negative)是假反例。

1. 准确率(Accuracy)

准确率是指在所有预测样本中,预测正确的样本所占的比例。它衡量的是模型整体预测正确的程度。
在这里插入图片描述

2. 精确率(Precision)

精确率是指在所有被预测为正类的样本中,真正为正类的样本所占的比例。
![在这里插入图片描述](https://i-blog.****img.cn/direct/0c6f404c85df462ab1c4831f84cbd94d.png

3. 召回率(Recall)

召回率是指在所有实际为正类的样本中,被模型正确预测为正类的样本所占的比例。
在这里插入图片描述

4. F1值(F1-score)

F1值是精确率和召回率的调和平均数,它综合考虑了精确率和召回率两个指标,能够更全面地评估模型的性能。
在这里插入图片描述

多分类评估指标

1. 混淆矩阵(Confusion Matrix)

它是一个方阵,用来展示分类模型在每个类别上的预测对错情况。行代表真实类别,列代表预测类别,某个位置的值就是实际是某类却被预测成另一类的样本数量,能直观呈现模型对各类别预测的混淆情况。
在这里插入图片描述

2. 准确率(Accuracy)

就是模型预测正确的样本数占总样本数的比例,反映整体预测正确程度。
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

3. 精确率(Precision)

  • 类别精确率:对于每个类别,是预测为该类且正确的样本数除以预测为该类的样本数,看预测某类时的准确程度。
    在这里插入图片描述

  • 宏平均精确率(Macro-average Precision):先算出每个类别的精确率,再求平均,平等看待每个类别。
    在这里插入图片描述

  • 微平均精确率(Micro-average Precision):将所有类别预测对的情况汇总除以预测的总数,从整体上看预测的精准情况,对类别不平衡不太敏感。
    在这里插入图片描述

4. 召回率(Recall)

  • 类别召回率:对于每个类别,是预测为该类且正确的样本数除以实际是该类的样本数,体现对该类样本的召回能力。
    在这里插入图片描述

  • 宏平均召回率(Macro- average Recall):先算出每个类别的召回率,再求平均,衡量对每个类别样本的召回水平。
    在这里插入图片描述

  • 微平均召回率(Micro-average Recall):从整体角度,用所有类别预测正确的样本总数除以实际各类别样本总数,综合评估召回情况。
    在这里插入图片描述

5. F1值(宏平均)

  • 类别F1值:是类别精确率和召回率的调和平均数,综合二者信息。
    在这里插入图片描述

  • 宏平均F1值(Macro-average F1):先计算每个类别的F1值,再平均,更全面地体现模型对各分类的整体性能。
    在这里插入图片描述

  • 微平均F1值(Micro-average F1):基于微平均精确率(Precision micro)和微平均召回率(Recall micro)来计算
    微平均
    在这里插入图片描述