Block-wise dot product

Approach #1

Use np.einsum:

    np.einsum('ijkl,ilm->ijkm', m0, m1)

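Steps involved:

  • Keep the first axes of the two inputs aligned.

  • Lose the last axis of m0 against the second axis of m1 in a sum-reduction.

  • Let the remaining axes of m0 and m1 spread out with elementwise multiplication, in an outer-product fashion.

Put together, with m0 of shape (s0, s1, s2, s3) and m1 of shape (s0, s3, s4), the result r has shape (s0, s1, s2, s4) and r[i, j, k, m] = sum over l of m0[i, j, k, l] * m1[i, l, m].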
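The sketch below makes the three steps concrete. It performs them with explicit broadcasting instead of einsum and checks the result against the einsum call. The small shapes are arbitrary and chosen only for illustration. The sketch builds the full (i, j, k, l, m) product array, so it is meant for understanding, not for large inputs.

```python
import numpy as np

# Arbitrary small shapes (assumed for illustration): m0 is (i, j, k, l), m1 is (i, l, m)
rng = np.random.default_rng(0)
m0 = rng.random((2, 3, 4, 5))
m1 = rng.random((2, 5, 6))

# Steps 1 and 3: keep axis 0 aligned and expand the remaining axes in an
# outer-product fashion, giving an (i, j, k, l, m) array of elementwise products.
#   m0[:, :, :, :, None]    -> (i, j, k, l, 1)
#   m1[:, None, None, :, :] -> (i, 1, 1, l, m)
prod = m0[:, :, :, :, None] * m1[:, None, None, :, :]

# Step 2: sum-reduce the shared l axis (last axis of m0, second axis of m1).
r_broadcast = prod.sum(axis=3)  # -> (i, j, k, m)

r_einsum = np.einsum('ijkl,ilm->ijkm', m0, m1)
print(r_broadcast.shape)                   # (2, 3, 4, 6)
print(np.allclose(r_broadcast, r_einsum))  # True
```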


Approach #2

If you are after performance and the sum-reduction axis is short, you are better off with a single loop that does the matrix multiplication with np.tensordot, like so:

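    import numpy as np

    s0, s1, s2, s3 = m0.shape
    s4 = m1.shape[-1]
    r = np.empty((s0, s1, s2, s4))
    for i in range(s0):
        # Contract axis 2 of m0[i] (s1, s2, s3) with axis 0 of m1[i] (s3, s4)
        r[i] = np.tensordot(m0[i], m1[i], axes=([2], [0]))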
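Each loop iteration above contracts one 3D slice against one 2D matrix. The sketch below uses made-up small shapes and shows what axes=([2], [0]) means for a single such slice. It also checks that the result matches an einsum and a reshape-then-2D-matrix-product formulation. The names a, b, t, e and d are only illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.random((3, 4, 5))  # stands in for m0[i], shape (s1, s2, s3)
b = rng.random((5, 6))     # stands in for m1[i], shape (s3, s4)

# axes=([2], [0]) sums over axis 2 of `a` paired with axis 0 of `b`;
# the free axes of `a` come first in the result, then the free axes of `b`.
t = np.tensordot(a, b, axes=([2], [0]))  # -> (3, 4, 6)

# The same contraction as an einsum, and as a single 2D matrix product
# after merging the first two axes of `a`.
e = np.einsum('jkl,lm->jkm', a, b)
d = (a.reshape(-1, 5) @ b).reshape(3, 4, 6)

print(t.shape)                               # (3, 4, 6)
print(np.allclose(t, e), np.allclose(t, d))  # True True
```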

Approach #3

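Now, np.dot works efficiently on 2D inputs, which allows a further performance boost. Merging the second and third axes of m0 turns each m0[i] into a 2D (s1*s2, s3) matrix that can be multiplied directly with the 2D m1[i]. The modified version is a bit longer, but it is hopefully the most performant one: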

    import numpy as np

    s0, s1, s2, s3 = m0.shape
    s4 = m1.shape[-1]
    m0.shape = s0, s1*s2, s3    # Get m0 as 3D for temporary usage
    r = np.empty((s0, s1*s2, s4))
    for i in range(s0):
        r[i] = m0[i].dot(m1[i])  # 2D (s1*s2, s3) times 2D (s3, s4)
    r.shape = s0, s1, s2, s4
    m0.shape = s0, s1, s2, s3   # Put m0 back to 4D
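Assigning to m0.shape changes the caller's array in place. It also fails when the axes cannot be merged without a copy, for example for some transposed views. One possible variant, sketched here under the hypothetical name dot_app_reshape, uses reshape instead and leaves the input untouched. For a contiguous m0 the reshape is still a view, so no data is copied in that case. This variant has not been benchmarked here.

```python
import numpy as np


def dot_app_reshape(m0, m1):
    """Hypothetical variant of Approach #3 that never modifies m0.shape."""
    s0, s1, s2, s3 = m0.shape
    s4 = m1.shape[-1]
    # reshape returns a view when possible (e.g. contiguous m0) and a copy
    # otherwise; either way the caller's array keeps its original shape.
    m0_3d = m0.reshape(s0, s1 * s2, s3)
    r = np.empty((s0, s1 * s2, s4))
    for i in range(s0):
        r[i] = m0_3d[i].dot(m1[i])  # 2D (s1*s2, s3) times 2D (s3, s4)
    return r.reshape(s0, s1, s2, s4)


# Usage check on small random inputs.
rng = np.random.default_rng(2)
m0 = rng.random((4, 3, 2, 5))
m1 = rng.random((4, 5, 6))
out = dot_app_reshape(m0, m1)
print(out.shape, m0.shape)  # (4, 3, 2, 6) (4, 3, 2, 5)
print(np.allclose(out, np.einsum('ijkl,ilm->ijkm', m0, m1)))  # True
```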

Runtime test

Function definitions:

    import numpy as np

    def original_app(m0, m1):
        s0, s1, s2, s3 = m0.shape
        s4 = m1.shape[-1]
        r = np.empty((s0, s1, s2, s4))
        for i in range(s0):
            for j in range(s1):
                r[i, j] = np.dot(m0[i, j], m1[i])
        return r

    def einsum_app(m0, m1):
        return np.einsum('ijkl,ilm->ijkm', m0, m1)

    def tensordot_app(m0, m1):
        s0, s1, s2, s3 = m0.shape
        s4 = m1.shape[-1]
        r = np.empty((s0, s1, s2, s4))
        for i in range(s0):
            r[i] = np.tensordot(m0[i], m1[i], axes=([2], [0]))
        return r

    def dot_app(m0, m1):
        s0, s1, s2, s3 = m0.shape
        s4 = m1.shape[-1]
        m0.shape = s0, s1*s2, s3    # Get m0 as 3D for temporary usage
        r = np.empty((s0, s1*s2, s4))
        for i in range(s0):
            r[i] = m0[i].dot(m1[i])
        r.shape = s0, s1, s2, s4
        m0.shape = s0, s1, s2, s3   # Put m0 back to 4D
        return r
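np.matmul (the @ operator) broadcasts over leading axes, so the whole computation can also be written without a Python loop. This alternative is swapped in for comparison. It is not part of the original answer and was not included in the timings. The name matmul_app is hypothetical and only chosen to match the other *_app functions. Its speed relative to the approaches above on the benchmark inputs is not established here.

```python
import numpy as np


def matmul_app(m0, m1):
    """Hypothetical loop-free variant using np.matmul broadcasting.

    m0 has shape (s0, s1, s2, s3) and m1 has shape (s0, s3, s4).
    m1[:, None] has shape (s0, 1, s3, s4), so matmul broadcasts it over
    axis 1 of m0 and multiplies the trailing (s2, s3) x (s3, s4) blocks,
    giving a result of shape (s0, s1, s2, s4).
    """
    return np.matmul(m0, m1[:, None])


# Usage check against the einsum formulation on small random inputs.
rng = np.random.default_rng(3)
m0 = rng.random((5, 3, 4, 2))
m1 = rng.random((5, 2, 6))
out = matmul_app(m0, m1)
print(out.shape)  # (5, 3, 4, 6)
print(np.allclose(out, np.einsum('ijkl,ilm->ijkm', m0, m1)))  # True
```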

Timings and verification:

    In [291]: # Inputs
         ...: m0 = np.random.rand(50,30,20,20)
         ...: m1 = np.random.rand(50,20,20)
         ...:

    In [292]: out1 = original_app(m0, m1)
         ...: out2 = einsum_app(m0, m1)
         ...: out3 = tensordot_app(m0, m1)
         ...: out4 = dot_app(m0, m1)
         ...:
         ...: print(np.allclose(out1, out2))
         ...: print(np.allclose(out1, out3))
         ...: print(np.allclose(out1, out4))
         ...:
    True
    True
    True

    In [293]: %timeit original_app(m0, m1)
         ...: %timeit einsum_app(m0, m1)
         ...: %timeit tensordot_app(m0, m1)
         ...: %timeit dot_app(m0, m1)
         ...:
    100 loops, best of 3: 10.3 ms per loop
    10 loops, best of 3: 31.3 ms per loop
    100 loops, best of 3: 5.12 ms per loop
    100 loops, best of 3: 4.06 ms per loop


Reposted from www.mshxw.com; please credit the source when reproducing.
Article URL: https://www.mshxw.com/it/596382.html