进一步提高性能
起初,一般的经验法则。您正在使用数值数组,因此请使用数组而不是列表。列表看起来可能有点像一个通用数组,但是在后端却完全不同,并且对于大多数数值计算来说绝对是不可行的。
如果您使用Numpy-Arrays编写简单的代码,则可以通过简单地按如下所示的方式添加代码来获得性能。如果使用列表,则可以或多或少地重写代码。
import numpy as npimport numba as nb@nb.njit(fastmath=True)def prod(array): assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance) res=np.empty(array.shape[0],dtype=array.dtype) for i in range(array.shape[0]): res[i]=array[i,0]*array[i,1]*array[i,2] return res
使用
np.prod(a,axis=1)并不是一个坏主意,但是性能并不是很好。对于只有1000x3的数组,函数调用开销非常大。在另一个jitted函数中使用jitted
prod函数时,可以完全避免这种情况。
基准测试
# The first call to the jitted function takes about 200ms compilation overhead. #If you use @nb.njit(fastmath=True,cache=True) you can cache the compilation result for every successive call.n=999prod1 = 795 µsprod2 = 187 µsnp.prod = 7.42 µsprod 0.85 µsn=9990prod1 = 7863 µsprod2 = 1810 µsnp.prod = 50.5 µsprod 2.96 µs



