【深度学习】多分类任务评估指标sklearn和torchmetrics对比
- 说明
- sklearn代码
- torchmetrics代码
- 两个MultiClassReport类的对比分析
- 1. 代码结构与实现方式
- 2. 数据处理与内存使用
- 3. 性能与效率
- 二分类任务评估指标
- 1. 准确率(Accuracy)
- 2. 精确率(Precision)
- 3. 召回率(Recall)
- 4. F1值(F1-score)
- 多分类评估指标
- 1. 混淆矩阵(Confusion Matrix)
- 2. 准确率(Accuracy)
- 3. 精确率(Precision)
- 4. 召回率(Recall)
- 5. F1值(宏平均)
说明
sklearn和torchmetrics两个metric代码跑模型的输出结果一致,对比他们的区别。评估指标写在下面
sklearn代码
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
class MultiClassReport():
"""
Accuracy, F1 Score, Precision and Recall for multi - class classification task.
"""
def __init__(self, name='MultiClassReport', average='macro'):
super(MultiClassReport, self).__init__()
self.average = average
self._name = name
self.reset()
def reset(self):
"""
Resets all the metric state.
"""
self.y_prob = []
self.y_true = []
def update(self, probs, labels):
# 将Tensor转换为numpy数组并添加到相应列表中
if isinstance(probs, torch.Tensor):
if probs.requires_grad:
probs = probs.detach()
probs = probs.cpu().numpy()
if isinstance(labels, torch.Tensor):
if labels.requires_grad:
labels = labels.detach()
labels = labels.cpu().numpy()
self.y_prob.extend(probs)
self.y_true.extend(labels)
self.y_prob.extend(probs)
self.y_true.extend(labels)
def accumulate(self):
accuracy = accuracy_score(self.y_true, np.argmax(self.y_prob, axis=1))
f1 = f1_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
precision = precision_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
recall = recall_score(self.y_true, np.argmax(self.y_prob, axis=1), average=self.average)
return accuracy, f1, precision, recall
def name(self):
"""
Returns metric name
"""
return self._name
torchmetrics代码
from torchmetrics import Accuracy, F1Score, Precision, Recall
from model import polarity_classes, device
# 创建评估指标对象
accuracy_metric = Accuracy(task='multiclass', num_classes=polarity_classes).to(device)
f1_metric = F1Score(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
precision_metric = Precision(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
recall_metric = Recall(task='multiclass', num_classes=polarity_classes, average='macro').to(device)
class MultiClassReport():
"""
Accuracy, F1 Score, Precision and Recall for multi-class classification task.
average:micro、macro
"""
def __init__(self, name='MultiClassReport', average='macro'):
super(MultiClassReport, self).__init__()
self.average = average
self._name = name
def reset(self):
"""
Resets all the metric state.
"""
accuracy_metric.reset()
f1_metric.reset()
precision_metric.reset()
recall_metric.reset()
def update(self, probs, labels):
accuracy_metric.update(probs, labels)
f1_metric.update(probs, labels)
precision_metric.update(probs, labels)
recall_metric.update(probs, labels)
def accumulate(self):
accuracy = accuracy_metric.compute()
f1 = f1_metric.compute()
precision = precision_metric.compute()
recall = recall_metric.compute()
return accuracy, f1, precision, recall
def name(self):
"""
Returns metric name
"""
return self._name
两个MultiClassReport类的对比分析
1. 代码结构与实现方式
-
sklearn
版本:- 代码逻辑较为清晰直接。在
update
方法中,将输入的PyTorch
张量转换为numpy
数组,并存储到y_prob
和y_true
列表中。在accumulate
方法中,直接使用sklearn
的accuracy_score
、f1_score
、precision_score
和recall_score
函数基于存储的列表数据计算评估指标。
- 代码逻辑较为清晰直接。在
-
torchmetrics
版本:- 利用
torchmetrics
库提供的专门的评估指标类(Accuracy
、F1Score
、Precision
、Recall
)。在update
方法中,直接调用这些类的update
方法来处理输入数据,内部有自己的状态管理机制。在accumulate
方法中,通过调用相应类的compute
方法获取评估指标值。 - 这种方式与
PyTorch
的生态系统集成得更好,尤其是在基于PyTorch
进行深度学习项目开发时,可以方便地在GPU
上进行计算(如果device
是GPU
),并且可以利用torchmetrics
库的其他特性,如分布式训练支持等。
- 利用
2. 数据处理与内存使用
-
sklearn
版本:- 在
update
方法中不断扩展y_prob
和y_true
列表来存储数据。如果处理大量数据,可能会占用较多内存,因为它需要将所有的预测概率和真实标签都保存在内存中。 - 每次计算评估指标时,都需要对整个存储的数组进行操作,如
np.argmax
等,这在数据量较大时可能会有一定的计算开销。
- 在
-
torchmetrics
版本:- 虽然
torchmetrics
类内部也需要存储一定的状态信息,但它们的设计可能更高效地利用内存和处理数据更新。例如,它们可能会采用增量计算的方式,而不是像sklearn
版本那样一次性处理所有数据。 - 在处理大规模数据或长时间训练过程中,
torchmetrics
版本可能在内存管理和计算效率方面更有优势。
- 虽然
3. 性能与效率
-
sklearn
版本:- 在小规模数据和简单场景下,性能表现良好。但随着数据量的增加和模型复杂度的提高,由于数据转换和计算方式的原因,可能会出现性能瓶颈。
-
torchmetrics
版本:- 设计初衷就是为了在
PyTorch
深度学习环境中高效运行,特别是在利用GPU
计算资源时,能够更高效地更新和计算评估指标,更适合大规模数据和复杂模型的评估场景。
- 设计初衷就是为了在
二分类任务评估指标
TP(True Positive)是真正例,TN(True Negative)是真反例,FP(False Positive)是假正例,FN(False Negative)是假反例。
1. 准确率(Accuracy)
准确率是指在所有预测样本中,预测正确的样本所占的比例。它衡量的是模型整体预测正确的程度。
2. 精确率(Precision)
精确率是指在所有被预测为正类的样本中,真正为正类的样本所占的比例。
3. 召回率(Recall)
召回率是指在所有实际为正类的样本中,被模型正确预测为正类的样本所占的比例。
4. F1值(F1-score)
F1值是精确率和召回率的调和平均数,它综合考虑了精确率和召回率两个指标,能够更全面地评估模型的性能。
多分类评估指标
1. 混淆矩阵(Confusion Matrix)
它是一个方阵,用来展示分类模型在每个类别上的预测对错情况。行代表真实类别,列代表预测类别,某个位置的值就是实际是某类却被预测成另一类的样本数量,能直观呈现模型对各类别预测的混淆情况。
2. 准确率(Accuracy)
就是模型预测正确的样本数占总样本数的比例,反映整体预测正确程度。
3. 精确率(Precision)
-
类别精确率:对于每个类别,是预测为该类且正确的样本数除以预测为该类的样本数,看预测某类时的准确程度。
-
宏平均精确率(Macro-average Precision):先算出每个类别的精确率,再求平均,平等看待每个类别。
-
微平均精确率(Micro-average Precision):将所有类别预测对的情况汇总除以预测的总数,从整体上看预测的精准情况,对类别不平衡不太敏感。
4. 召回率(Recall)
-
类别召回率:对于每个类别,是预测为该类且正确的样本数除以实际是该类的样本数,体现对该类样本的召回能力。
-
宏平均召回率(Macro- average Recall):先算出每个类别的召回率,再求平均,衡量对每个类别样本的召回水平。
-
微平均召回率(Micro-average Recall):从整体角度,用所有类别预测正确的样本总数除以实际各类别样本总数,综合评估召回情况。
5. F1值(宏平均)
-
类别F1值:是类别精确率和召回率的调和平均数,综合二者信息。
-
宏平均F1值(Macro-average F1):先计算每个类别的F1值,再平均,更全面地体现模型对各分类的整体性能。
-
微平均F1值(Micro-average F1):基于微平均精确率(Precision micro)和微平均召回率(Recall micro)来计算
微平均