栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

对于numpy包中的axis参数的理解

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

对于numpy包中的axis参数的理解

对于numpy包中的axis参数的理解

在numpy中,对于多维数组进行sum,mean,min,max,sort的操作时,均会涉及到axis这一参数。那么axis具体是什么呢?

我们先引入一个切片(slices)的概念:

  • 如果 k k k维数组 A ∈ R n 1 × n 2 × . . . × n k A in mathbb{R}^{n_1times n_2 times ...times n_k} A∈Rn1​×n2​×...×nk​ ,则 A A A的Slices为固定其中1个索引位置之后形成的 k − 1 k-1 k−1维数组

ok,有了slices的概念,我们如何对高维数组进行sum,mean,min,max,sort之类的运算呢?

以sum函数为例。对于一个 k k k维数组 A ∈ R n 1 × n 2 × . . . × n k A in mathbb{R}^{n_1times n_2 times ...times n_k} A∈Rn1​×n2​×...×nk​,在axis=i上进行运算就是在第i个维度上对相应的slices进行运算。

以sum函数在axis=i上运算为例,这个过程就相当于计算:
n p . s u m ( A , a x i s = i ) = ∑ j = 1 n i x [ : , : , . . . , : , j , . . . , : ] . np.sum(A,axis=i) = sum_{j=1}^{n_i}x[:,:,...,:,j,...,:]. np.sum(A,axis=i)=j=1∑ni​​x[:,:,...,:,j,...,:].

以sort函数在axis=i上运算为例,这个过程就相当于计算:
n p . s o r t ( A , a x i s = i ) = s o r t e d { x [ : , : , . . . , : , j , . . . , : ] , j = 1 , 2 , ⋯   , n i } . np.sort(A,axis=i) = sorted{x[:,:,...,:,j,...,:],j=1,2,cdots,n_i}. np.sort(A,axis=i)=sorted{x[:,:,...,:,j,...,:],j=1,2,⋯,ni​}.

以min函数在axis=i上运算为例,这个过程就相当于计算:
n p . m i n ( A , a x i s = i ) = m i n { x [ : , : , . . . , : , j , . . . , : ] , j = 1 , 2 , ⋯   , n i } . np.min(A,axis=i) = min{x[:,:,...,:,j,...,:],j=1,2,cdots,n_i}. np.min(A,axis=i)=min{x[:,:,...,:,j,...,:],j=1,2,⋯,ni​}.

以mean函数在axis=i上运算为例,这个过程就相当于计算:
n p . m a x ( A , a x i s = i ) = m e a n { x [ : , : , . . . , : , j , . . . , : ] , j = 1 , 2 , ⋯   , n i } . np.max(A,axis=i) = mean{x[:,:,...,:,j,...,:],j=1,2,cdots,n_i}. np.max(A,axis=i)=mean{x[:,:,...,:,j,...,:],j=1,2,⋯,ni​}.

具体操作

我们定义一个 2 × 3 × 4 2times 3 times 4 2×3×4的变量data,打印一下它在各个axis上的slices:

import numpy as np

np.random.seed(1)
data = np.random.randint(0,24,(2,3,4))
# data变量在axis=0上的全部slices:
print('data在axis=0上的全部slices:')
for slices in range(data.shape[0]):  
    print(data[slices,:,:])
    print('---------------')
print('data在axis=1上的全部slices:')
for slices in range(data.shape[1]):  
    print(data[:,slices,:])
    print('---------------')
print('data在axis=2上的全部slices:')
for slices in range(data.shape[2]):  
    print(data[:,:,slices])
    print('---------------')

输出:

data在axis=0上的全部slices:
[[ 5 11 12  8]
 [ 9 11  5 15]
 [ 0 16  1 12]]
---------------
[[ 7 13  6 18]
 [20  5 18 20]
 [11 10 14 18]]
---------------
data在axis=1上的全部slices:
[[ 5 11 12  8]
 [ 7 13  6 18]]
---------------
[[ 9 11  5 15]
 [20  5 18 20]]
---------------
[[ 0 16  1 12]
 [11 10 14 18]]
---------------
data在axis=2上的全部slices:
[[ 5  9  0]
 [ 7 20 11]]
---------------
[[11 11 16]
 [13  5 10]]
---------------
[[12  5  1]
 [ 6 18 14]]
---------------
[[ 8 15 12]
 [18 20 18]]
---------------
np.sum函数

假设现在我们对data变量在axis=0上做sum运算:知道了data在axis=0上的slices,一个for循环就可以解决。

sum_axis0 = 0
for slices in range(data.shape[0]):  
    sum_axis0+=data[slices,:,:]
print(sum_axis0)
print(sum_axis0==data.sum(axis=0))

输出:

[[12 24 18 26]
 [29 16 23 35]
 [11 26 15 30]]
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]

从上面可以看出,sum作用的基本单元其实就是data在axis=0上的各个slices。同理我们可以得到axis=1,axis=2上面的sum。这里就不写了。

np.mean函数

有了sum的结果,计算mean也是easy了:

sum_axis0 = 0
for slices in range(data.shape[0]):  
    sum_axis0+=data[slices,:,:]
mean_data = sum_axis0/2
print(mean_data)
print(mean_data==data.mean(axis=0))

输出:

[[ 6.  12.   9.  13. ]
 [14.5  8.  11.5 17.5]
 [ 5.5 13.   7.5 15. ]]
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]
np.sort函数

现在来试一下sort函数。以axis=0为例,data在axis=0上的slices为:

[[ 5 11 12  8]
 [ 9 11  5 15]
 [ 0 16  1 12]]
---------------
[[ 7 13  6 18]
 [20  5 18 20]
 [11 10 14 18]]

那么sort函数的作用过程就是(以升序排列为例):依次比较上面两个slices的各个元素。具体过程为:

  • 对slices(0)与slices(1)的(0,0)元素进行排序:由于5<7,故排序后的array的(0,0,0)元素为5,(1,0,0)元素为7
  • 对slices(0)与slices(1)的(0,1)元素进行排序:由于11<13,故排序后的array的(0,0,1)元素为11,(1,0,1)元素为13
  • 对slices(0)与slices(1)的(0,2)元素进行排序:由于6<12,故排序后的array的(0,0,2)元素为6,(1,0,2)元素为12
  • 对slices(0)与slices(1)的(2,3)元素进行排序:由于8<18,故排序后的array的(0,2,3)元素为12,(1,2,3)元素为18

具体代码为:

slices0 = data[0,:,:]
slices1 = data[1,:,:]
sorted_data  =  np.zeros((2,3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
    for j in range(col):
        sorted_data[0,i,j],sorted_data[1,i,j] = sorted([slices0[i,j],slices1[i,j]])

print(sorted_data)
default_sort_data = np.sort(data,axis=0)
print(sorted_data==default_sort_data)

输出

[[[ 5 11  6  8]
  [ 9  5  5 15]
  [ 0 10  1 12]]

 [[ 7 13 12 18]
  [20 11 18 20]
  [11 16 14 18]]]
[[[ True  True  True  True]
  [ True  True  True  True]
  [ True  True  True  True]]

 [[ True  True  True  True]
  [ True  True  True  True]
  [ True  True  True  True]]]
np.min与np.max函数

好哒,那么min和max函数也是手到擒来了

slices0 = data[0,:,:]
slices1 = data[1,:,:]
min_data  =  np.zeros((3,4),dtype=int)
_,row,col = data.shape
for i in range(row):
    for j in range(col):
        min_data[i,j]= min([slices0[i,j],slices1[i,j]])

print(min_data)
default_min_data = np.min(data,axis=0)
print(min_data==default_min_data)

输出:

[[ 5 11  6  8]
 [ 9  5  5 15]
 [ 0 10  1 12]]
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]
总结

从以上各个例子中我们还可以发现,sum、mean、min和max函数作用之后得到的结果与axis=i上的slices的维度相同,但是sort函数作用之后的结果是与原始数组的维度相同的。具体原因通过代码也是可以看出来的,简单提一下就是了,sort函数结果与原始数组相同是因为它是先对axis=i下的全部slices进行一个排序操作,然后再把这些排序后的结果合在一起,所以sort函数的结果维度与原始数组维度保持一致。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/286180.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号