I. Pairwise combinations
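One approach uses numba. It fills a preallocated output array in a compiled loop, so it is memory-efficient and therefore fast as well: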
    import numpy as np
    from numba import njit

    @njit
    def pairwise_combs_numba(a):
        n = len(a)
        L = n*(n-1)//2                       # number of pairs, C(n, 2)
        out = np.empty((L,2),dtype=a.dtype)
        iterID = 0                           # next output row to fill
        for i in range(n):
            for j in range(i+1,n):
                out[iterID,0] = a[i]
                out[iterID,1] = a[j]
                iterID += 1
        return out
Another, NumPy-based approach uses np.broadcast_to to get gridded views of the input and then masks them:
    def pairwise_combs_mask(a):
        n = len(a)
        L = n*(n-1)//2
        out = np.empty((L,2),dtype=a.dtype)
        m = ~np.tri(len(a),dtype=bool)       # True strictly above the diagonal (i < j)
        out[:,0] = np.broadcast_to(a[:,None],(n,n))[m]
        out[:,1] = np.broadcast_to(a,(n,n))[m]
        return out
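To see what the mask selects, here is a minimal sketch that builds the same pairs from explicit index arrays. It swaps in `np.triu_indices` for the boolean mask, so it is an alternative formulation rather than the method above, and `pairwise_combs_triu` is a hypothetical name used only for illustration.

```python
import numpy as np

def pairwise_combs_triu(a):
    # Hypothetical helper, not part of the answer above.
    # np.triu_indices(n, k=1) returns the (row, col) indices strictly
    # above the diagonal, i.e. every index pair with i < j, in row-major order.
    a = np.asarray(a)
    i, j = np.triu_indices(len(a), k=1)
    return np.stack((a[i], a[j]), axis=1)

a = np.array([3, 9, 4, 1, 7])

# The boolean mask used by the masking approach: True where i < j.
mask = ~np.tri(len(a), dtype=bool)

pairs = pairwise_combs_triu(a)
print(mask.sum(), pairs.shape)
```

The index arrays cost extra memory compared with a single boolean mask, so this form is mainly useful for checking results on small inputs.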
II. Triplet combinations
We can extend the same two approaches to get triplet combinations:
    @njit
    def triplet_combs_numba(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6                 # number of triplets, C(n, 3)
        out = np.empty((L,3),dtype=a.dtype)
        iterID = 0
        for i in range(n):
            for j in range(i+1,n):
                for k in range(j+1,n):
                    out[iterID,0] = a[i]
                    out[iterID,1] = a[j]
                    out[iterID,2] = a[k]
                    iterID += 1
        return out

    def triplet_combs_mask(a):
        n = len(a)
        L = n*(n-1)*(n-2)//6
        out = np.empty((L,3),dtype=a.dtype)
        r = np.arange(n)
        m = (r[:,None,None]<r[:,None]) & (r[:,None]<r)   # True where i < j < k
        out[:,0] = np.broadcast_to(a[:,None,None],(n,n,n))[m]
        out[:,1] = np.broadcast_to(a[None,:,None],(n,n,n))[m]
        out[:,2] = np.broadcast_to(a[None,None,:],(n,n,n))[m]
        return out
Higher-order combinations extend in the same way.
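As a rough illustration of that extension, the masking approach can be written for an arbitrary order k by chaining the index comparisons across k axes. This is a sketch under stated assumptions: `combs_mask_generic` is a hypothetical name, k is assumed to be at least 1, and the boolean mask has n**k entries, so it is only practical for small n and k.

```python
from itertools import combinations

import numpy as np

def combs_mask_generic(a, k):
    # Hypothetical generalization of the masking approach (assumes k >= 1).
    a = np.asarray(a)
    n = len(a)
    r = np.arange(n)

    # Build the mask i_0 < i_1 < ... < i_{k-1} over a k-dimensional grid.
    m = np.ones((n,) * k, dtype=bool)
    for d in range(k - 1):
        lo = [1] * k
        hi = [1] * k
        lo[d] = n
        hi[d + 1] = n
        m &= r.reshape(lo) < r.reshape(hi)

    # Column d holds the elements of `a` laid out along axis d of the grid.
    out = np.empty((m.sum(), k), dtype=a.dtype)
    for d in range(k):
        shape = [1] * k
        shape[d] = n
        out[:, d] = np.broadcast_to(a.reshape(shape), m.shape)[m]
    return out

a = np.array([3, 9, 4, 1, 7])
quads = combs_mask_generic(a, 4)

# Cross-check against itertools.combinations, which yields the same order.
expected = [list(c) for c in combinations(a.tolist(), 4)]
print(quads.tolist() == expected)
```

For the numba approach, the analogous extension is simply one more nested loop per additional element in the combination.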
Sample run:
    In [54]: a = np.array([3,9,4,1,7])

    In [55]: pairwise_combs_numba(a)
    Out[55]:
    array([[3, 9],
           [3, 4],
           [3, 1],
           [3, 7],
           [9, 4],
           [9, 1],
           [9, 7],
           [4, 1],
           [4, 7],
           [1, 7]])

    In [56]: triplet_combs_numba(a)
    Out[56]:
    array([[3, 9, 4],
           [3, 9, 1],
           [3, 9, 7],
           [3, 4, 1],
           [3, 4, 7],
           [3, 1, 7],
           [9, 4, 1],
           [9, 4, 7],
           [9, 1, 7],
           [4, 1, 7]])
Timings (including Python's built-in itertools.combinations):
    In [68]: a = np.random.rand(4000)

    In [69]: %timeit pairwise_combs_numba(a)
        ...: %timeit pairwise_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 2))
    10 loops, best of 3: 52.2 ms per loop
    10 loops, best of 3: 146 ms per loop
    1 loop, best of 3: 597 ms per loop

    In [70]: a = np.random.rand(400)

    In [71]: %timeit triplet_combs_numba(a)
        ...: %timeit triplet_combs_mask(a)
        ...: %timeit list(itertools.combinations(a, 3))
    10 loops, best of 3: 98.5 ms per loop
    1 loop, best of 3: 352 ms per loop
    1 loop, best of 3: 795 ms per loop
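Outside IPython, a comparison along these lines could be run with the standard `timeit` module. The sketch below is an assumption-laden stand-in for the `%timeit` session, not a reproduction of it: `best_of` is a hypothetical helper, the input is smaller (500 elements) to keep the run short, and the numba version is left out to keep the example NumPy-only. If you add the numba function, call it once before timing so that JIT compilation is not measured.

```python
import itertools
import timeit

import numpy as np

def pairwise_combs_mask(a):
    # Same NumPy masking approach as in section I.
    n = len(a)
    L = n * (n - 1) // 2
    out = np.empty((L, 2), dtype=a.dtype)
    m = ~np.tri(n, dtype=bool)
    out[:, 0] = np.broadcast_to(a[:, None], (n, n))[m]
    out[:, 1] = np.broadcast_to(a, (n, n))[m]
    return out

def best_of(func, repeat=3, number=5):
    # Hypothetical helper: best average time per call, in seconds.
    return min(timeit.repeat(func, repeat=repeat, number=number)) / number

a = np.random.rand(500)  # smaller than the 4000 used above, to keep this quick

timings = {
    "mask": best_of(lambda: pairwise_combs_mask(a)),
    "itertools": best_of(lambda: list(itertools.combinations(a, 2))),
}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.3f} ms per call")
```

Absolute numbers depend on the machine and library versions, so only the relative ordering is meaningful.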



