我认为您就在那儿。您需要将数据集放到一个数组或结构中,以供您使用一个单独的全局目标函数,该函数提供给minimum()并使用所有数据集的单个参数集来适合所有数据集。您可以根据需要在数据集之间共享此集。稍微扩展一下示例,下面的代码确实可以对5种不同的高斯函数进行一次拟合。对于跨数据集绑定参数的示例,我对5个数据集的sigma使用了几乎相同的值。我创建了5个不同的sigma参数(“
sig_1”,“ sig_2”,…,“ sig_5”),然后使用数学约束将其设置为相同的值。因此,问题中有11个变量,而不是15个。
import numpy as npimport matplotlib.pyplot as pltfrom lmfit import minimize, Parameters, report_fitdef gauss(x, amp, cen, sigma): "basic gaussian" return amp*np.exp(-(x-cen)**2/(2.*sigma**2))def gauss_dataset(params, i, x): """calc gaussian from params for data set i using simple, hardwired naming convention""" amp = params['amp_%i' % (i+1)].value cen = params['cen_%i' % (i+1)].value sig = params['sig_%i' % (i+1)].value return gauss(x, amp, cen, sig)def objective(params, x, data): """ calculate total residual for fits to several data sets held in a 2-D array, and modeled by Gaussian functions""" ndata, nx = data.shape resid = 0.0*data[:] # make residual per data set for i in range(ndata): resid[i, :] = data[i, :] - gauss_dataset(params, i, x) # now flatten this to a 1D array, as minimize() needs return resid.flatten()# create 5 datasetsx = np.linspace( -1, 2, 151)data = []for i in np.arange(5): params = Parameters() amp = 0.60 + 9.50*np.random.rand() cen = -0.20 + 1.20*np.random.rand() sig = 0.25 + 0.03*np.random.rand() dat = gauss(x, amp, cen, sig) + np.random.normal(size=len(x), scale=0.1) data.append(dat)# data has shape (5, 151)data = np.array(data)assert(data.shape) == (5, 151)# create 5 sets of parameters, one per data setfit_params = Parameters()for iy, y in enumerate(data): fit_params.add( 'amp_%i' % (iy+1), value=0.5, min=0.0, max=200) fit_params.add( 'cen_%i' % (iy+1), value=0.4, min=-2.0, max=2.0) fit_params.add( 'sig_%i' % (iy+1), value=0.3, min=0.01, max=3.0)# but now constrain all values of sigma to have the same value# by assigning sig_2, sig_3, .. sig_5 to be equal to sig_1for iy in (2, 3, 4, 5): fit_params['sig_%i' % iy].expr='sig_1'# run the global fit to all the data setsresult = minimize(objective, fit_params, args=(x, data))report_fit(result.fit_params)# plot the data sets and fitsplt.figure()for i in range(5): y_fit = gauss_dataset(result.fit_params, i, x) plt.plot(x, data[i, :], 'o', x, y_fit, '-')plt.show()
对于它的价值,我会考虑将多个数据集保存在字典或DataSet类列表中,而不是多维数组。无论如何,我希望这能帮助您进入真正需要做的事情。



