栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

检查两个3D numpy数组是否包含重叠的2D数组

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

检查两个3D numpy数组是否包含重叠的2D数组

我们可以利用

views
我在一些问答中使用过的辅助功能来发挥作用。要获得子数组的存在,我们可以
np.isin
在视图上使用或使用更加费力的视图
np.searchsorted

方法1: 使用

np.isin
-

# https://stackoverflow.com/a/45313353/ @Divakardef view1D(a, b): # a, b are arrays    a = np.ascontiguousarray(a)    b = np.ascontiguousarray(b)    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))    return a.view(void_dt).ravel(),  b.view(void_dt).ravel()def isin_nd(a,b):    # a,b are the 3D input arrays to give us "isin-like" functionality across them    A,B = view1D(a.reshape(a.shape[0],-1),b.reshape(b.shape[0],-1))    return np.isin(A,B)

方法2: 我们也可以利用

np.searchsorted
views
-

def isin_nd_searchsorted(a,b):    # a,b are the 3D input arrays    A,B = view1D(a.reshape(a.shape[0],-1),b.reshape(b.shape[0],-1))    sidx = A.argsort()    sorted_index = np.searchsorted(A,B,sorter=sidx)    sorted_index[sorted_index==len(A)] = len(A)-1    idx = sidx[sorted_index]    return A[idx] == B

因此,这两个解决方案为我们提供了

a
in中每个子数组的存在掩码
b
。因此,为了获得所需的计数,它应该是-
isin_nd(a,b).sum()
isin_nd_searchsorted(a,b).sum()

样品运行-

In [71]: # Setup with 3 common "subarrays"    ...: np.random.seed(0)    ...: a = np.random.randint(0,9,(10,4,5))    ...: b = np.random.randint(0,9,(7,4,5))    ...:     ...: b[1] = a[4]    ...: b[3] = a[2]    ...: b[6] = a[0]In [72]: isin_nd(a,b).sum()Out[72]: 3In [73]: isin_nd_searchsorted(a,b).sum()Out[73]: 3

大型阵列上的时间-

In [74]: # Setup    ...: np.random.seed(0)    ...: a = np.random.randint(0,9,(100,100,100))    ...: b = np.random.randint(0,9,(100,100,100))    ...: idxa = np.random.choice(range(len(a)), len(a)//2, replace=False)    ...: idxb = np.random.choice(range(len(b)), len(b)//2, replace=False)    ...: a[idxa] = b[idxb]# Verify outputIn [82]: np.allclose(isin_nd(a,b),isin_nd_searchsorted(a,b))Out[82]: TrueIn [75]: %timeit isin_nd(a,b).sum()10 loops, best of 3: 31.2 ms per loopIn [76]: %timeit isin_nd_searchsorted(a,b).sum()100 loops, best of 3: 1.98 ms per loop


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/637710.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号