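What we would really like to do is use np.in1d, except that np.in1d only works with 1-dimensional arrays, and our arrays are multi-dimensional. However, we can view each array as a 1-dimensional array of strings, with one byte string per row: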
    arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))
For example,
    In [15]: arr = np.array([[1, 2], [2, 3], [1, 3]])

    In [16]: arr = arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))

    In [30]: arr.dtype
    Out[30]: dtype('V16')

    In [31]: arr.shape
    Out[31]: (3, 1)

    In [37]: arr
    Out[37]:
    array([[b'\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00'],
           [b'\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00'],
           [b'\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00']],
          dtype='|V16')

This makes each row of arr a single string. The view has shape (3, 1), but np.in1d flattens its inputs, so that is not a problem. Now it is just a matter of hooking this up to np.in1d:
    import numpy as np


    def asvoid(arr):
        """
        Based on http://stackoverflow.com/a/16973510/190597 (Jaime, 2013-06)
        View the array as dtype np.void (bytes). The items along the last axis
        are viewed as one value. This allows comparisons to be performed on the
        entire row.
        """
        arr = np.ascontiguousarray(arr)
        if np.issubdtype(arr.dtype, np.floating):
            # Care needs to be taken here since
            # np.array([-0.]).view(np.void) != np.array([0.]).view(np.void)
            # Adding 0. converts -0. to 0. The addition is done out of place
            # so that the caller's array is not modified.
            arr = arr + 0.
        return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))


    def inNd(a, b, assume_unique=False):
        # Compare whole rows by viewing each one as a single void value.
        a = asvoid(a)
        b = asvoid(b)
        return np.in1d(a, b, assume_unique)


    # Each test is (a, b, expected result of inNd(b, a)).
    tests = [
        (np.array([[1, 2], [2, 3], [1, 3]]),
         np.array([[2, 2], [3, 3], [4, 4]]),
         np.array([False, False, False])),
        (np.array([[1, 2], [2, 2], [1, 3]]),
         np.array([[2, 2], [3, 3], [4, 4]]),
         np.array([True, False, False])),
        (np.array([[1, 2], [3, 4], [5, 6]]),
         np.array([[1, 2], [3, 4], [7, 8]]),
         np.array([True, True, False])),
        (np.array([[1, 2], [5, 6], [3, 4]]),
         np.array([[1, 2], [5, 6], [7, 8]]),
         np.array([True, True, False])),
        (np.array([[-0.5, 2.5, -2, 100, 2], [5, 6, 7, 8, 9], [3, 4, 5, 6, 7]]),
         np.array([[1.0, 2, 3, 4, 5], [5, 6, 7, 8, 9], [-0.5, 2.5, -2, 100, 2]]),
         np.array([False, True, True])),
    ]

    for a, b, answer in tests:
        result = inNd(b, a)
        try:
            assert np.all(answer == result)
        except AssertionError:
            print('a:\n{a}\nb:\n{b}\nanswer: {answer}\nresult: {result}'.format(
                a=a, b=b, answer=answer, result=result))
            raise
    else:
        # The else clause of the for loop runs only if no test raised.
        print('Success!')

yields

    Success!
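The reason asvoid adds 0. to floating-point input is that the void view compares raw bytes, and -0. and 0. have different bit patterns even though they compare equal as floats. The following is a minimal sketch of that effect; the variable names are illustrative and not part of the answer above.

```python
import numpy as np

neg = np.array([[-0.0, 1.0]])
pos = np.array([[0.0, 1.0]])

# As floats, the two rows compare equal ...
floats_equal = bool(np.array_equal(neg, pos))

# ... but their raw bytes differ in the sign bit of the first element.
bytes_equal_before = neg.tobytes() == pos.tobytes()

# Adding 0. turns -0. into +0., so the byte patterns now match.
normalized = neg + 0.0
bytes_equal_after = normalized.tobytes() == pos.tobytes()

print(floats_equal, bytes_equal_before, bytes_equal_after)  # True False True
```

Without the normalization step, a row containing -0. would not be matched against an otherwise identical row containing 0.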

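For small inputs, the row-membership mask can also be cross-checked with plain broadcasting instead of the void view. This is a different technique from the one in the answer above, shown only as a sanity check. The helper name rows_in_bruteforce is hypothetical, and the approach needs memory proportional to len(a) * len(b) * ncols.

```python
import numpy as np


def rows_in_bruteforce(a, b):
    """Return a boolean mask; mask[i] is True if row a[i] equals some row of b.

    Hypothetical helper for cross-checking only: it compares every row of
    `a` against every row of `b`.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    # Shape (len(a), len(b), ncols): element-wise equality for all row pairs.
    pairwise = a[:, None, :] == b[None, :, :]
    # A pair matches when all columns agree; a row is "in b" if any pair matches.
    return pairwise.all(axis=-1).any(axis=-1)


a = np.array([[1, 2], [3, 4], [7, 8]])
b = np.array([[1, 2], [3, 4], [5, 6]])
mask = rows_in_bruteforce(a, b)
print(mask.tolist())     # [True, True, False]
print(a[mask].tolist())  # [[1, 2], [3, 4]]
```

Because this version compares values with == rather than bytes, it treats -0. and 0. as equal without any normalization.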


