TensorFlow、Numpy中的axis的理解

时间:2021-06-27 21:25:34

TensorFlow中有很多函数涉及到axis,比如tf.reduce_mean(),其函数原型如下:

1 def reduce_mean(input_tensor,
2                 axis=None,
3                 keepdims=None,
4                 name=None,
5                 reduction_indices=None,
6                 keep_dims=None):

其中axis表示的是,对该维度进行求均值(默认情况下,是对所有值求均值)。
除了TensorFlow中,numpy中也经常遇到很多对矩阵操作的函数会涉及axis操作。比如np.mean(),其函数原型如下:

1 def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue):

想要弄清楚如何处理涉及axis(维度)的操作,必须先明白axis是什么。
首先axis是维度,如果axis=0则对应着高; 如果axis=1则对应着行处理;如果axis=2则对应着列;如果axis=3…n(无法用直观的图来表示)。我相信很多人看到这还是会一头雾水。什么是高,行还有列。为了说明这个问题,我举个列子:

data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
data_np=np.array(data)
print(data_np)
[[[  1   2   3]
  [ 11  22  33]]

 [[  4   5   6]
  [ 44  55  66]]

 [[ 10  11  12]
  [100 110 120]]

 [[  7   8   9]
  [ 77  88  99]]]
  
如上面,可以将最外层[ ]去掉,可以发现有4组元素(这里的元素是矩阵),你可以将其理解为高。
再从这3组元素中选取一组,比如选择的是
[[  1   2   3]
  [ 11  22  33]]
然后将该组的最外层[ ]去掉,可以发现有2组元素分别为[  1   2   3]和 [ 11  22  33],此时对应的是行。
在从这两组元素中选组一组,比如选择的是
 [ 11  22  33]
 现在无需去掉最外层的[ ]了,一眼就能看出里面有3个元素。这就是对应的列。
 理解了上面的分析后,很容易就知道(高,行,列)对应的其实就是改矩阵的shape.
print(data_np.shape):
(4,2,3)

现在弄清楚了axis的值与(高,行,列)的关系后,再来分析tf.reduce_mean()或者np.mean()等函数是如何对axis进行操作的。

 1 data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
 2 
 3 data_tensor=tf.constant(data,dtype=tf.float32)
 4 
 5 mean_axis0=tf.reduce_mean(data_tensor,axis=0)
 6 mean_axis1=tf.reduce_mean(data_tensor,axis=1)
 7 mean_axis2=tf.reduce_mean(data_tensor,axis=2)
 8 
 9 with tf.Session() as sess:
10     print(sess.run(mean_axis0))
11     print(sess.run(mean_axis1))
12     print(sess.run(mean_axis2))

针对上述代码,我们先对axis=0维度的数据处理进行分析。
首先对上述data数据进行立体化变换,如下图(本人本想用软件来绘制3D的矩阵叠加效果,可惜找了很多软件都不适合,也许是本人寻找的还不够,欢迎有知道可以绘制3D的矩阵叠加效果的朋友们,能够分享一下。感激…)

TensorFlow、Numpy中的axis的理解

如上如,axis=0的维度数据求均值,

[[(1+4+10+7)/4         (2+5+11+8)/4       (3+6+12+9)/4]
[(11+44+100+77)/4      (22+55+110+88)/4   (33+66+120+99)/4]]
=
[[ 5.5   6.5   7.5 ]
 [58.   68.75 79.5 ]]

同理,对axis=1的维度数据求均值,

[[(1+11)/2    (2+22)/2    (3+33)/2]
 [(4+44)/2    (5+55)/2    (6+66)/2]
 [(10+100)/2  (11+110)/2  (12+120)/2]
 [(7+77)/2    (8+88)/2    (9+99)/2]]
 =
 [[ 6.  12.  18. ]
 [24.  30.  36. ]
 [55.  60.5 66. ]
 [42.  48.  54. ]]

同理可得axis=2维度的数据平均值为(过程留给读者去推,运算结果如下):

[[  2.  22.]
 [  5.  55.]
 [ 11. 110.]
 [  8.  88.]]

在python的世界里,有很多时候都需要对数据进行维度的操作,如果对axis理解的不透的话,很容易找不着方向。

更多干货请关注:

TensorFlow、Numpy中的axis的理解