numpy之axis如何理解

时间:2021-07-04 21:25:05

转载自:

fangjian1204


http://blog.csdn.net/fangjian1204/article/details/53055219


简单的来记就是axis=0代表往跨行(down),而axis=1代表跨列(across)

python进行科学计算必不可少的模块,随着深度学习越来越火,numpy也越来越流行。了解numpy的人知道,在numpy中,有很多的函数都涉及到axis,很多函数根据axis的取值不同,得到的结果也完全不同。可以说,axis让numpy的多维数组变的更加灵活,但也让numpy变得越发难以理解。这里通过详细的例子来学习下,axis到底是什么,它在numpy中的作用到底如何。

为什么会有axis这个东西,原因很简单:numpy是针对矩阵或者多为数组进行运算的,而在多维数组中,对数据的操作有太多的可能,我们先来看一个例子。比如我们有一个二维数组:

>>> import numpy as np
>>> data = np.array([
... [1,2,1],
... [0,3,1],
... [2,1,4],
... [1,3,1]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

这个数组代表了样本数据的特征,其中每一行代表一个样本的三个特征,每一列是不同样本的特征。如果在分析样本的过程中需要对每个样本的三个特征求和,该如何处理?简单:

>>> np.sum(data, axis=1)
array([4, 4, 7, 5])
  • 1
  • 2

那如果想求每种特征的最小值,该如何处理?也简单:

>>> np.min(data, axis=0)
array([0, 1, 1])
  • 1
  • 2

又如果想得知所有样本所有特征的平均值呢?还是很简单:

>>> np.average(data)
1.6666666666666667
  • 1
  • 2

由此可以看出,通过不同的axis,numpy会沿着不同的方向进行操作:如果不设置,那么对所有的元素操作;如果axis=0,则沿着纵轴进行操作;axis=1,则沿着横轴进行操作。但这只是简单的二位数组,如果是多维的呢?可以总结为一句话:设axis=i,则numpy沿着第i个下标变化的放下进行操作。例如刚刚的例子,可以将表示为:data =[[a00, a01],[a10,a11]],所以axis=0时,沿着第0个下标变化的方向进行操作,也就是a00->a10, a01->a11,也就是纵坐标的方向,axis=1时也类似。下面我们举一个四维的求sum的例子来验证一下:

>>> data = np.random.randint(0, 5, [4,3,2,3])
>>> data
array([[[[4, 1, 0], [4, 3, 0]], [[1, 2, 4], [2, 2, 3]], [[4, 3, 3], [4, 2, 3]]], [[[4, 0, 1], [1, 1, 1]], [[0, 1, 0], [0, 4, 1]], [[1, 3, 0], [0, 3, 0]]], [[[3, 3, 4], [0, 1, 0]], [[1, 2, 3], [4, 0, 4]], [[1, 4, 1], [1, 3, 2]]], [[[0, 1, 1], [2, 4, 3]], [[4, 1, 4], [1, 4, 1]], [[0, 1, 0], [2, 4, 3]]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

当axis=0时,numpy验证第0维的方向来求和,也就是第一个元素值=a0000+a1000+a2000+a3000=11,第二个元素=a0001+a1001+a2001+a3001=5,同理可得最后的结果如下:

>>> data.sum(axis=0)
array([[[11, 5, 6], [ 7, 9, 4]],

       [[ 6, 6, 11], [ 7, 10, 9]],

       [[ 6, 11, 4], [ 7, 12, 8]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

当axis=3时,numpy验证第3维的方向来求和,也就是第一个元素值=a0000+a0001+a0002=5,第二个元素=a0010+a0011+a0012=7,同理可得最后的结果如下:

>>> data.sum(axis=3)
array([[[ 5, 7], [ 7, 7], [10, 9]],

       [[ 5, 3], [ 1, 5], [ 4, 3]],

       [[10, 1], [ 6, 8], [ 6, 6]],

       [[ 2, 9], [ 9, 6], [ 1, 9]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

使用axis的相关函数

在numpy中,使用的axis的地方非常多,处理上文已经提到的average、max、min、sum,比较常见的还有sort和prod,下面分别举几个例子看一下:

  • sort
>>> data = np.random.randint(0, 5, [3,2,3])
>>> data
array([[[4, 2, 0], [0, 0, 4]],

       [[2, 1, 1], [1, 0, 2]],

       [[3, 0, 4], [0, 1, 3]]])
>>> np.sort(data)  ## 默认对最大的axis进行排序,这里即是axis=2
array([[[0, 2, 4], [0, 0, 4]],

       [[1, 1, 2], [0, 1, 2]],

       [[0, 3, 4], [0, 1, 3]]])
>>> np.sort(data, axis=0)  # 沿着第0维进行排序,原先的a000->a100->a200转变为a100->a200->a000
array([[[2, 0, 0], [0, 0, 2]],

       [[3, 1, 1], [0, 0, 3]],

       [[4, 2, 4], [1, 1, 4]]])
>>> np.sort(data, axis=1)  # 沿着第1维进行排序
array([[[0, 0, 0], [4, 2, 4]],

       [[1, 0, 1], [2, 1, 2]],

       [[0, 0, 3], [3, 1, 4]]])
>>> np.sort(data, axis=2)  # 沿着第2维进行排序
array([[[0, 2, 4], [0, 0, 4]],

       [[1, 1, 2], [0, 1, 2]],

       [[0, 3, 4], [0, 1, 3]]])
>>> np.sort(data, axis=None)  # 对全部数据进行排序
array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • prod(即product,乘积)
 >>> np.prod([[1.,2.],[3.,4.]])
 24.0

 >>> np.prod([[1.,2.],[3.,4.]], axis=1)
 array([  2.,  12.])

 >>> np.prod([[1.,2.],[3.,4.]], axis=0)
 array([ 3.,  8.])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

相信通过上面的讲解与例子,你应该对axis有了比较清楚的了解。个人认为,如果没有理解axis的真正含义,很难熟悉的运用numpy进行数据处理

fangjian1204