Pytorch实用教程:pytorch中 argmax(dim)用法详解

时间:2024-06-01 19:12:08

argmax(dim) 是 PyTorch 中的一个函数,用于找出指定维度上最大值的索引。argmax 函数是在多维张量上进行操作的,通过 dim 参数可以指定在哪一个维度上查找最大值。

参数解释

  • dim: 指定要在哪个维度上执行寻找最大值的操作。维度的索引从 0 开始,对应于张量的各个轴。

返回值

  • 返回一个新的张量,包含了指定维度 dim 上每个位置最大值的索引。

使用场景

在深度学习中,argmax 常用于分类任务的输出处理。例如,在处理模型的输出时,经常需要从 softmax 层输出的概率分布中找出概率最高的类别索引。

示例代码

假设你有一个模型输出了每个类别的预测分数(或概率),现在你想知道每个样本最可能属于哪个类别:

import torch

# 假设 outputs 是一个模型的输出,形状为 [batch_size, num_classes]
# 每行代表一个样本对于各类别的预测分数
outputs = torch.tensor