在numpy中,对于多维数组进行sum,mean,min,max,sort的操作时,均会涉及到axis这一参数。那么axis具体是什么呢?
我们先引入一个切片(slices)的概念:
- 如果 k k k维数组 A ∈ R n 1 × n 2 × . . . × n k A in mathbb{R}^{n_1times n_2 times ...times n_k} A∈Rn1×n2×...×nk ,则 A A A的Slices为固定其中1个索引位置之后形成的 k − 1 k-1 k−1维数组
ok,有了slices的概念,我们如何对高维数组进行sum,mean,min,max,sort之类的运算呢?
以sum函数为例。对于一个 k k k维数组 A ∈ R n 1 × n 2 × . . . × n k A in mathbb{R}^{n_1times n_2 times ...times n_k} A∈Rn1×n2×...×nk,在axis=i上进行运算就是在第i个维度上对相应的slices进行运算。
以sum函数在axis=i上运算为例,这个过程就相当于计算:
n
p
.
s
u
m
(
A
,
a
x
i
s
=
i
)
=
∑
j
=
1
n
i
x
[
:
,
:
,
.
.
.
,
:
,
j
,
.
.
.
,
:
]
.
np.sum(A,axis=i) = sum_{j=1}^{n_i}x[:,:,...,:,j,...,:].
np.sum(A,axis=i)=j=1∑nix[:,:,...,:,j,...,:].
以sort函数在axis=i上运算为例,这个过程就相当于计算:
n
p
.
s
o
r
t
(
A
,
a
x
i
s
=
i
)
=
s
o
r
t
e
d
{
x
[
:
,
:
,
.
.
.
,
:
,
j
,
.
.
.
,
:
]
,
j
=
1
,
2
,
⋯
,
n
i
}
.
np.sort(A,axis=i) = sorted{x[:,:,...,:,j,...,:],j=1,2,cdots,n_i}.
np.sort(A,axis=i)=sorted{x[:,:,...,:,j,...,:],j=1,2,⋯,ni}.
以min函数在axis=i上运算为例,这个过程就相当于计算:
n
p
.
m
i
n
(
A
,
a
x
i
s
=
i
)
=
m
i
n
{
x
[
:
,
:
,
.
.
.
,
:
,
j
,
.
.
.
,
:
]
,
j
=
1
,
2
,
⋯
,
n
i
}
.
np.min(A,axis=i) = min{x[:,:,...,:,j,...,:],j=1,2,cdots,n_i}.
np.min(A,axis=i)=min{x[:,:,...,:,j,...,:],j=1,2,⋯,ni}.
以mean函数在axis=i上运算为例,这个过程就相当于计算:
n
p
.
m
a
x
(
A
,
a
x
i
s
=
i
)
=
m
e
a
n
{
x
[
:
,
:
,
.
.
.
,
:
,
j
,
.
.
.
,
:
]
,
j
=
1
,
2
,
⋯
,
n
i
}
.
np.max(A,axis=i) = mean{x[:,:,...,:,j,...,:],j=1,2,cdots,n_i}.
np.max(A,axis=i)=mean{x[:,:,...,:,j,...,:],j=1,2,⋯,ni}.
我们定义一个 2 × 3 × 4 2times 3 times 4 2×3×4的变量data,打印一下它在各个axis上的slices:
import numpy as np
np.random.seed(1)
data = np.random.randint(0,24,(2,3,4))
# data变量在axis=0上的全部slices:
print('data在axis=0上的全部slices:')
for slices in range(data.shape[0]):
print(data[slices,:,:])
print('---------------')
print('data在axis=1上的全部slices:')
for slices in range(data.shape[1]):
print(data[:,slices,:])
print('---------------')
print('data在axis=2上的全部slices:')
for slices in range(data.shape[2]):
print(data[:,:,slices])
print('---------------')
输出:
data在axis=0上的全部slices: [[ 5 11 12 8] [ 9 11 5 15] [ 0 16 1 12]] --------------- [[ 7 13 6 18] [20 5 18 20] [11 10 14 18]] --------------- data在axis=1上的全部slices: [[ 5 11 12 8] [ 7 13 6 18]] --------------- [[ 9 11 5 15] [20 5 18 20]] --------------- [[ 0 16 1 12] [11 10 14 18]] --------------- data在axis=2上的全部slices: [[ 5 9 0] [ 7 20 11]] --------------- [[11 11 16] [13 5 10]] --------------- [[12 5 1] [ 6 18 14]] --------------- [[ 8 15 12] [18 20 18]] ---------------np.sum函数
假设现在我们对data变量在axis=0上做sum运算:知道了data在axis=0上的slices,一个for循环就可以解决。
sum_axis0 = 0
for slices in range(data.shape[0]):
sum_axis0+=data[slices,:,:]
print(sum_axis0)
print(sum_axis0==data.sum(axis=0))
输出:
[[12 24 18 26] [29 16 23 35] [11 26 15 30]] [[ True True True True] [ True True True True] [ True True True True]]
从上面可以看出,sum作用的基本单元其实就是data在axis=0上的各个slices。同理我们可以得到axis=1,axis=2上面的sum。这里就不写了。
np.mean函数有了sum的结果,计算mean也是easy了:
sum_axis0 = 0
for slices in range(data.shape[0]):
sum_axis0+=data[slices,:,:]
mean_data = sum_axis0/2
print(mean_data)
print(mean_data==data.mean(axis=0))
输出:
[[ 6. 12. 9. 13. ] [14.5 8. 11.5 17.5] [ 5.5 13. 7.5 15. ]] [[ True True True True] [ True True True True] [ True True True True]]np.sort函数
现在来试一下sort函数。以axis=0为例,data在axis=0上的slices为:
[[ 5 11 12 8] [ 9 11 5 15] [ 0 16 1 12]] --------------- [[ 7 13 6 18] [20 5 18 20] [11 10 14 18]]
那么sort函数的作用过程就是(以升序排列为例):依次比较上面两个slices的各个元素。具体过程为:
- 对slices(0)与slices(1)的(0,0)元素进行排序:由于5<7,故排序后的array的(0,0,0)元素为5,(1,0,0)元素为7
- 对slices(0)与slices(1)的(0,1)元素进行排序:由于11<13,故排序后的array的(0,0,1)元素为11,(1,0,1)元素为13
- 对slices(0)与slices(1)的(0,2)元素进行排序:由于6<12,故排序后的array的(0,0,2)元素为6,(1,0,2)元素为12
- …
- 对slices(0)与slices(1)的(2,3)元素进行排序:由于8<18,故排序后的array的(0,2,3)元素为12,(1,2,3)元素为18
具体代码为:
slices0 = data[0,:,:]
slices1 = data[1,:,:]
sorted_data = np.zeros((2,3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
for j in range(col):
sorted_data[0,i,j],sorted_data[1,i,j] = sorted([slices0[i,j],slices1[i,j]])
print(sorted_data)
default_sort_data = np.sort(data,axis=0)
print(sorted_data==default_sort_data)
输出
[[[ 5 11 6 8] [ 9 5 5 15] [ 0 10 1 12]] [[ 7 13 12 18] [20 11 18 20] [11 16 14 18]]] [[[ True True True True] [ True True True True] [ True True True True]] [[ True True True True] [ True True True True] [ True True True True]]]np.min与np.max函数
好哒,那么min和max函数也是手到擒来了
slices0 = data[0,:,:]
slices1 = data[1,:,:]
min_data = np.zeros((3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
for j in range(col):
min_data[i,j]= min([slices0[i,j],slices1[i,j]])
print(min_data)
default_min_data = np.min(data,axis=0)
print(min_data==default_min_data)
输出:
[[ 5 11 6 8] [ 9 5 5 15] [ 0 10 1 12]] [[ True True True True] [ True True True True] [ True True True True]]总结
从以上各个例子中我们还可以发现,sum、mean、min和max函数作用之后得到的结果与axis=i上的slices的维度相同,但是sort函数作用之后的结果是与原始数组的维度相同的。具体原因通过代码也是可以看出来的,简单提一下就是了,sort函数结果与原始数组相同是因为它是先对axis=i下的全部slices进行一个排序操作,然后再把这些排序后的结果合在一起,所以sort函数的结果维度与原始数组维度保持一致。



