numpy 中的Axis(轴)含义 np.newaxis numpy.expand_dims

时间:2022-04-20 23:50:20

理解numpy中的轴:

:表示当前维的所有索引值都取
import numpy as np
t = np.array(
[
[
[
[1,2,3],
[4,5,6]
],
[
[7,8,9],
[10,11,12]
],
[
[13,14,15],
[16,17,18]
]
],
[
[
[19,20,21],
[22,23,24]
],
[
[25,26,27],
[28,29,30]
],
[
[31,32,33],
[34,35,36]
]
]
])
print(t[0,:,:,:])
[[[ 1 2 3]
[ 4 5 6]]


[[ 7 8 9]
[10 11 12]]


[[13 14 15]
[16 17 18]]
]
print(t[:,0,:,:])
[[[ 1 2 3]
[ 4 5 6]]


[[19 20 21]
[22 23 24]]
]
print(t[:,:,0,:])
[[[ 1 2 3]
[ 7 8 9]
[13 14 15]]


[[19 20 21]
[25 26 27]
[31 32 33]]
]
print(t[:,:,:,0])
[[[ 1 4]
[ 7 10]
[13 16]]


[[19 22]
[25 28]
[31 34]]
]

np.newaxis:
np.newaxis相当于新插入一个轴

a=np.array([1,2,3,4,5])
b=a[np.newaxis,:]
print a.shape,b.shape
print a
print b

输出结果:
(5,) (1, 5)
[1 2 3 4 5]
[[1 2 3 4 5]]
a=np.array([1,2,3,4,5])
b=a[:,np.newaxis]
print a.shape,b.shape
print a
print b

输出结果
(5,) (5, 1)
[1 2 3 4 5]
[[1]
[2]
[3]
[4]
[5]]

numpy.expand_dims
numpy.expand_dims同样是用于扩充数组维度

>>> x = np.array([1,2])
>>> x.shape
(2,)

>>> y = np.expand_dims(x, axis=0) #等价于 x[np.newaxis,:]或x[np.newaxis]
>>> y
array([[1, 2]])
>>> y.shape
(1, 2) #看np.newaxis位置(在:之前)可知插入在2之前

>>> y = np.expand_dims(x, axis=1) #等价于x[:,newaxis]
>>> y
array([[1],
[2]]
)
>>> y.shape
(2, 1) #看np.newaxis位置(在:之后)可知插入在2之后

>>> np.newaxis is None
True

二维情况:

x = np.array([[1,2,3],[4,5,6]])
print x
print x.shape

[[1 2 3]
[4 5 6]]

(2, 3)

y = np.expand_dims(x,axis=0)
print y
print "y.shape: ",y.shape
print "y[0][1]: ",y[0][1]

[[[1 2 3]
[4 5 6]]]
y.shape: (1, 2, 3)
y[0][1]: [4 5 6]

y = np.expand_dims(x,axis=1)
print y
print "y.shape: ",y.shape
print "y[1][0]: ",y[1][0]

[[[1 2 3]]

[[4 5 6]]]
y.shape: (2, 1, 3)
y[1][0]: [4 5 6]

y = np.expand_dims(x,axis=3)
print y
print "y.shape: ",y.shape


[[[1]
[2]
[3]]


[[4]
[5]
[6]]
]
y.shape: (2, 3, 1)