对于np.argmax()让我迷惑了很久,尤其是其中的axis=1的比较结果。
一、np.argmax()的理解
1、最简单的例子
假定现在有一个数组a = [3, 1, 2, 4, 6, 1]现在要算数组a中最大数的索引是多少。最直接的思路,先假定第0个数最大,然后拿这个和后面的数比,找到大的就更新索引。代码如下
1
2
3
4
5
6
7
8
|
a = [ 3 , 1 , 2 , 4 , 6 , 1 ]
maxindex = 0
i = 0
for tmp in a:
if tmp > a[maxindex]:
maxindex = i
i + = 1
print (maxindex)
|
这个问题可以帮助我们理解argmax.
2、函数的解释
一维数组
1
2
3
|
import numpy as np
a = np.array([ 3 , 1 , 2 , 4 , 6 , 1 ])
print (np.argmax(a))
|
argmax返回的是最大数的索引.argmax有一个参数axis,默认是0,表示第几维的最大值。
二维数组
1
2
3
4
5
|
import numpy as np
a = np.array([[ 1 , 5 , 5 , 2 ],
[ 9 , 6 , 2 , 8 ],
[ 3 , 7 , 9 , 1 ]])
print (np.argmax(a, axis = 0 ))
|
为了描述方便,a就表示这个二维数组。np.argmax(a, axis=0)的含义是a[0][j],a[1][j],a[2]j中最大值的索引。从a[0][j]开始,最大值索引最初为(0,0,0,0),拿a[0][j]和a[1][j]作比较,9大于1,6大于5,8大于2,所以最大值索引由(0,0,0,0)更新为(1,1,0,1),再和a[2][j]作比较,7大于6,9大于5所以更新为(1,2,2,1)。
再分析下面的输出.
1
2
3
4
5
|
import numpy as np
a = np.array([[ 1 , 5 , 5 , 2 ],
[ 9 , 6 , 2 , 8 ],
[ 3 , 7 , 9 , 1 ]])
print (np.argmax(a, axis = 1 ))
|
np.argmax(a, axis=1)的含义是a[i][0],a[i][1],a[i][2],a[i]3中最大值的索引。从a[i][0]开始,a[i][0]对应的索引为(0,0,0),先假定它就是最大值索引(思路和上节简单例子完全一致)拿a[i][0]和a[i][1]作比较,5大于1,7大于3所以最大值索引由(0,0,0)更新为(1,0,1),再和a[i][2]作比较,9大于7,更新为(1,0,2),再和a[i][3]作比较,不用更新,最终值为(1,0,2)
三维数组
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
import numpy as np
a = np.array([
[
[ 1 , 5 , 5 , 2 ],
[ 9 , - 6 , 2 , 8 ],
[ - 3 , 7 , - 9 , 1 ]
],
[
[ - 1 , 5 , - 5 , 2 ],
[ 9 , 6 , 2 , 8 ],
[ 3 , 7 , 9 , 1 ]
]
])
print (np.argmax(a, axis = 0 ))
|
np.argmax(a, axis=0)的含义是a[0][j][k],a[1][j][k] (j=0,1,2,k=0,1,2,3)中最大值的索引。
从a[0][j][k]开始,a[0][j][k]对应的索引为((0,0,0,0),(0,0,0,0),(0,0,0,0)),拿a[0][j][k]和a[1][j][k]对应项作比较6大于-6,3大于-3,9大于-9,所以更新这几个位置的索引,将((0,0,0,0),(0,0,0,0),(0,0,0,0))更新为((0,0,0,0),(0,1,0,0),(1,0,1,0)).。
再看axis=1的情况
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
import numpy as np
a = np.array([
[
[ 1 , 5 , 5 , 2 ],
[ 9 , - 6 , 2 , 8 ],
[ - 3 , 7 , - 9 , 1 ]
],
[
[ - 1 , 5 , - 5 , 2 ],
[ 9 , 6 , 2 , 8 ],
[ 3 , 7 , 9 , 1 ]
]
])
print (np.argmax(a, axis = 1 ))
|
np.argmax(a, axis=1)的含义是a[i][0][k],a[i][1][k] (i=0,1,k=0,1,2,3)中最大值的索引。从a[i][0][k]开始,a[i][0][k]对应的索引为((0,0,0,0),(0,0,0,0)),拿a[i][0][k]和a[i][1][k]对应项作比较,9大于1,8大于2,9大于-1,6大于5,2大于-5,8大于2,所以更新这几个位置的索引,将((0,0,0,0),(0,0,0,0))更新为((1,0,0,1),(1,1,1,1)),现在最大值对应的数组为((9,5,5,8),(9,6,2,8))。
再拿((9,5,5,8),(9,6,2,8))和a[i][2][k]对应项从比较,7大于5,7大于6,9大于2.更新这几个位置的索引。
将((1,0,0,1),(1,1,1,1))更新为((1,2,0,1),(1,2,2,1)).axis=2的情况也是类似的。
二、关于axis的理解
设置axis的主要原因是方便我们进行多个维度的计算。
通过例子来进行理解
比如:
1
2
3
4
5
6
|
a = np.array([[ 1 , 2 , 3 ],
[ 2 , 3 , 4 ],
[ 5 , 4 , 3 ],
[ 8 , 7 , 2 ]])
np.argmax(a, 0 ) #输出:array([ 3 , 3 , 1 ]
np.argmax(a, 1 ) #输出:array([ 2 , 2 , 0 , 0 ]
|
axis = 0:
你就这么想,0是最大的范围,所有的数组都要进行比较,只是比较的是这些数组相同位置上的数(我的理解是0 列比较输出):
1
2
3
4
5
|
a[ 0 ] = array([ 1 , 2 , 3 ])
a[ 1 ] = array([ 2 , 3 , 4 ])
a[ 2 ] = array([ 5 , 4 , 3 ])
a[ 3 ] = array([ 8 , 7 , 2 ])
# output : [3, 3, 1]
|
axis = 1: (行比较输出)
等于1的时候,比较范围缩小了,只会比较每个数组内的数的大小,结果也会根据有几个数组,产生几个结果。
1
2
3
4
|
a[ 0 ] = array([ 1 , 2 , 3 ]) #2
a[ 1 ] = array([ 2 , 3 , 4 ]) #2
a[ 2 ] = array([ 5 , 4 , 3 ]) #0
a[ 3 ] = array([ 8 , 7 , 2 ]) #0
|
特例
这是里面都是数组长度一致的情况,如果不一致,axis最大值为最小的数组长度-1,超过则报错。
当不一致的时候,axis=0的比较也就变成了每个数组的和的比较。
比较示例如下
当数组长度都一样时
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
import numpy as np
a = np.array([
[
[ 1 , 5 , 5 , 2 ],
[ 9 , - 6 , 2 , 8 ],
[ - 3 , 7 , - 9 , 1 ]
],
[
[ - 1 , 5 , - 5 , 2 ],
[ 9 , 6 , 2 , 8 ],
[ 3 , 7 , 9 , 1 ]
]
])
print (np.argmax(a, axis = 0 ))
print (np.argmax(a, axis = 1 ))
|
输出为
[[0 0 0 0]
[0 1 0 0]
[1 0 1 0]]
[[1 2 0 1][1 2 2 1]]
当数组长度都不一样时,
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
a = np.array([
[
[ 1 , 5 , 5 , 2 ],
[ 9 , - 6 , 2 , 8 ],
[ - 3 , 7 , - 9 , 1 ]
],
[
[ - 1 , 5 , - 5 , 2 ],
[ 9 , 6 , 2 , 8 ],
[ 3 , 7 , 9 ]
]
])
print (np.argmax(a, axis = 0 ))
print (np.argmax(a, axis = 1 ))
|
输出为
[0 1 1]
[1 1]
numpy 的argmax的参数axis=0/1的概念
对numpy的argmax一直记不得默认是行还是列搜索,总是用糊涂,每次都要查资料,今天突然醒悟。
先列后行,为什么呢?
看下面的一个列表,就知道了。
1
2
3
4
5
|
>>b = np.array([ 1 , 2 , 3 , 4 , 3 , 2 , 1 ])
>>np.argmax(b)
>> 3
>>np.argmax(b, axis = 0 )
>> 3
|
默认axis=0,列表只有一个维度,自然就是一行数据的最大数的索引。
那么对于二维向量,只需要记住axis是坐标轴的方向,不是行列的概念。
在Numpy库中:
轴用来为超过一维的数组定义的属性,二维数据拥有两个轴:
第0轴沿着行的垂直往下,第1轴沿着列的方向水平延伸。简单的来记就是axis=0代表往跨行(down),而axis=1代表跨列(across)。
所以axis=0代表的就是列查找,axis=1代表着行查找。
1
2
3
4
5
6
7
|
>>a = np.array([[ 1 , 5 , 5 , 2 ],
[ 9 , 6 , 2 , 8 ],
[ 3 , 7 , 9 , 1 ]])
>>np.argmax(a,axis = 0 )
>>array([ 1 , 2 , 2 , 1 ], dtype = int64)
>>np.argmax(a,axis = 1 )
>>array([ 1 , 0 , 2 ], dtype = int64)
|
结论:
argmax返回的是最大数的索引。argmax有一个参数axis,默认是0,表示每一列的最大值的索引,axis=1表示每一行的最大值的索引。
以上为个人经验,希望能给大家一个参考,也希望大家多多支持服务器之家。
原文链接:https://blog.csdn.net/XYKenny/article/details/98865532