Solution code:
    import numpy as np

    # Inputs (assumed already defined): a is an N-D array and b is a 1D array
    # whose length equals a.shape[given_axis]

    # Axis along which elementwise multiplication with broadcasting
    # is to be performed
    given_axis = 1

    # Build the shape used to reshape the 1D array b: singleton dimensions
    # everywhere except at the given axis, where -1 tells reshape to use
    # all of b's elements along that axis
    dim_array = np.ones((1,a.ndim),int).ravel()
    dim_array[given_axis] = -1

    # Reshape b with dim_array and multiply elementwise; broadcasting
    # stretches b along the singleton dimensions to produce the final output
    b_reshaped = b.reshape(dim_array)
    mult_out = a*b_reshaped
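For reuse, the same steps can be wrapped in a small function. The name `multiply_along_axis` and the explicit length check are illustrative additions, not part of the original answer; this is a minimal sketch, assuming `b` is 1D (or can be flattened to 1D) and has as many elements as `a` has along `axis`:

```python
import numpy as np


def multiply_along_axis(a, b, axis):
    """Multiply N-D array `a` by 1D array `b` along `axis` via broadcasting.

    Hypothetical helper wrapping the steps above; assumes b.size == a.shape[axis].
    """
    a = np.asarray(a)
    b = np.asarray(b).ravel()
    if b.size != a.shape[axis]:
        raise ValueError("b must have as many elements as a has along axis")
    # Singleton dimensions everywhere except `axis`, where -1 lets reshape
    # infer the length from b's size
    dim_array = np.ones(a.ndim, dtype=int)
    dim_array[axis] = -1
    return a * b.reshape(dim_array)


a = np.arange(24).reshape(4, 2, 3)
b = np.array([10, 100])
out = multiply_along_axis(a, b, axis=1)
print(out.shape)  # (4, 2, 3)
```

Because `dim_array[axis] = -1` uses ordinary indexing, a negative `axis` such as `-2` should behave the same way in this sketch.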
Sample run to demonstrate the steps:
    In [149]: import numpy as np

    In [150]: a = np.random.randint(0,9,(4,2,3))

    In [151]: b = np.random.randint(0,9,(2,1)).ravel()

    In [152]: whos
    Variable   Type       Data/Info
    -------------------------------
    a          ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
    b          ndarray    2: 2 elems, type `int32`, 8 bytes

    In [153]: given_axis = 1
Now, we want to perform the elementwise multiplication along `given_axis = 1`. Let's create `dim_array`:
    In [154]: dim_array = np.ones((1,a.ndim),int).ravel()
         ...: dim_array[given_axis] = -1
         ...:

    In [155]: dim_array
    Out[155]: array([ 1, -1,  1])
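The `-1` entry is a placeholder that `reshape` resolves from the array's size. The small check below uses a 2-element `b` like the one in the session (the concrete values are assumed). It also uses `np.ones(a.ndim, dtype=int)`, which yields the same flat vector of ones as `np.ones((1,a.ndim),int).ravel()`:

```python
import numpy as np

a = np.zeros((4, 2, 3))
b = np.array([7, 8])
given_axis = 1

# Same flat vector of ones as np.ones((1, a.ndim), int).ravel()
dim_array = np.ones(a.ndim, dtype=int)
dim_array[given_axis] = -1
print(dim_array)         # [ 1 -1  1]

# reshape infers the -1 entry as b.size // (product of the other entries) = 2
b_reshaped = b.reshape(dim_array)
print(b_reshaped.shape)  # (1, 2, 1)
```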
Finally, reshape `b` and perform the elementwise multiplication:
    In [156]: b_reshaped = b.reshape(dim_array)
         ...: mult_out = a*b_reshaped
         ...:
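To confirm that the broadcast product really scales each slice along `given_axis` by the matching element of `b`, one can compare it with an explicit loop. This is a sanity check using the same shapes as the session, not part of the solution. The random generator seed is arbitrary, and the loop's indexing is written specifically for `given_axis = 1`:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(0, 9, size=(4, 2, 3))
b = rng.integers(0, 9, size=2)
given_axis = 1

dim_array = np.ones(a.ndim, dtype=int)
dim_array[given_axis] = -1
mult_out = a * b.reshape(dim_array)

# Reference: multiply each slice a[:, j, :] by the scalar b[j]
expected = np.empty_like(a)
for j in range(a.shape[given_axis]):
    expected[:, j, :] = a[:, j, :] * b[j]

print(np.array_equal(mult_out, expected))  # True
```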
Check the `whos` info again, paying particular attention to `b_reshaped` and `mult_out`:
    In [157]: whos
    Variable     Type       Data/Info
    ---------------------------------
    a            ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
    b            ndarray    2: 2 elems, type `int32`, 8 bytes
    b_reshaped   ndarray    1x2x1: 2 elems, type `int32`, 8 bytes
    dim_array    ndarray    3: 3 elems, type `int32`, 12 bytes
    given_axis   int        1
    mult_out     ndarray    4x2x3: 24 elems, type `int32`, 96 bytes
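In the `whos` listing, `b_reshaped` is 1x2x1 while `mult_out` is 4x2x3. Under NumPy's broadcasting rules, each size-1 dimension of `b_reshaped` is stretched to match `a`. The sketch below uses `np.broadcast_shapes` (available in NumPy 1.20 and later) to compute that result shape directly. It also shows why the reshape is needed: a plain 1D `b` of shape `(2,)` aligns with `a`'s last axis, which has length 3, so the shapes are incompatible.

```python
import numpy as np

# Shapes taken from the whos output above
a_shape = (4, 2, 3)
b_reshaped_shape = (1, 2, 1)

# Dimensions of size 1 are stretched; all other dimensions must match exactly
print(np.broadcast_shapes(a_shape, b_reshaped_shape))  # (4, 2, 3)

# Without the reshape, b's shape (2,) is aligned with a's LAST axis (3)
try:
    np.broadcast_shapes(a_shape, (2,))
except ValueError as exc:
    print("cannot broadcast:", exc)
```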



