栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

快速组合,无需替换数组-NumPy / Python

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

快速组合,无需替换数组-NumPy / Python

I.成对组合

一种方法是

numba
获取内存,从而提高性能-

from numba import njit@njitdef pairwise_combs_numba(a):    n = len(a)    L = n*(n-1)//2    out = np.empty((L,2),dtype=a.dtype)    iterID = 0    for i in range(n):        for j in range(i+1,n): out[iterID,0] = a[i] out[iterID,1] = a[j] iterID += 1    return out

另一个基于NumPy的

np.broadcast_to
控件将用于获取网格视图,然后进行遮罩-

def pairwise_combs_mask(a):    n = len(a)    L = n*(n-1)//2    out = np.empty((L,2),dtype=a.dtype)    m = ~np.tri(len(a),dtype=bool)    out[:,0] = np.broadcast_to(a[:,None],(n,n))[m]    out[:,1] = np.broadcast_to(a,(n,n))[m]    return out

二。三联体组合

我们将扩展相同的方法,以使自己成为三元组合-

@njitdef triplet_combs_numba(a):    n = len(a)    L = n*(n-1)*(n-2)//6    out = np.empty((L,3),dtype=a.dtype)    iterID = 0    for i in range(n):        for j in range(i+1,n): for k in range(j+1,n):     out[iterID,0] = a[i]     out[iterID,1] = a[j]     out[iterID,2] = a[k]     iterID += 1    return outdef triplet_combs_mask(a):    n = len(a)    L = n*(n-1)*(n-2)//6    out = np.empty((L,3),dtype=a.dtype)    r = np.arange(n)    m = (r[:,None,None]<r[:,None]) & (r[:,None]<r)    out[:,0] = np.broadcast_to(a[:,None,None],(n,n,n))[m]    out[:,1] = np.broadcast_to(a[None,:,None],(n,n,n))[m]    out[:,2] = np.broadcast_to(a[None,None,:],(n,n,n))[m]    return out

高阶组合将同样扩展。

样品运行-

In [54]: a = np.array([3,9,4,1,7])In [55]: pairwise_combs_numba(a)Out[55]: array([[3, 9],       [3, 4],       [3, 1],       [3, 7],       [9, 4],       [9, 1],       [9, 7],       [4, 1],       [4, 7],       [1, 7]])In [56]: triplet_combs_numba(a)Out[56]: array([[3, 9, 4],       [3, 9, 1],       [3, 9, 7],       [3, 4, 1],       [3, 4, 7],       [3, 1, 7],       [9, 4, 1],       [9, 4, 7],       [9, 1, 7],       [4, 1, 7]])

时间(包括Python的内置-

itertools.combinations
)-

In [68]: a = np.random.rand(4000)In [69]: %timeit pairwise_combs_numba(a)    ...: %timeit pairwise_combs_mask(a)    ...: %timeit list(itertools.combinations(a, 2))10 loops, best of 3: 52.2 ms per loop10 loops, best of 3: 146 ms per loop1 loop, best of 3: 597 ms per loopIn [70]: a = np.random.rand(400)In [71]: %timeit triplet_combs_numba(a)    ...: %timeit triplet_combs_mask(a)    ...: %timeit list(itertools.combinations(a, 3))10 loops, best of 3: 98.5 ms per loop1 loop, best of 3: 352 ms per loop1 loop, best of 3: 795 ms per loop


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/626602.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号