栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

对应python中curve

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

对应python中curve

对应python中curve_fit的多元线性回归java实现
  • python中的拟合方法
  • java中实现多元线性拟合方法
    • 参考文章
    • 源代码及说明
    • 关于代码中一些参数的说明

python中的拟合方法

在python中实现拟合很方便,使用curve_fit,填好公式,样本数据和结果集,初始猜想和边界,很快就能实现,如下示例:
curve_fit(fit_function, a_tuple, b, p0=init_guess, bounds=(lb_tuple, ub_tuple), maxfev=1000)
fit_function如:theta1* x1 + theta2 * x2 + …
根据需要调整公式,可以实现多元线性拟合或非线性拟合。

java中实现多元线性拟合方法

在java中没有找到像python那样便捷地实现拟合的方法。针对多元线性拟合,经过多方面查找和实验,有一些小小的经验,记录备查。

参考文章

着重感谢下面文章的作者,基于他们的文章学到了很多知识:
https://zhuanlan.zhihu.com/p/25765735
https://www.cnblogs.com/donaldlee2008/p/5861796.html
https://my.oschina.net/u/1778239/blog/1858397

后面2篇文章的代码都能正常运行,得到的结果类似,测试过程中对代码做了的一些调整,记录备查。

源代码及说明

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;

public class LinearRegression {


private double [][] trainData;//训练数据,一行一个数据,每一行最后一个数据为 y
private int row;//训练数据  行数
private int column;//训练数据 列数
private double [] theta;//参数theta
private double alpha;//训练步长
private int iteration;//迭代次数

public LinearRegression(String fileName)
{
    int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
    int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数

    trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于1
    this.row=rowoffile;
    this.column=columnoffile+1;

    this.alpha = 0.001;//步长默认为0.001
    this.iteration=100000;//迭代次数默认为 100000

    theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
    initialize_theta(0.5);

    loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}
public LinearRegression(String fileName,double alpha,int iteration)
{
    int rowoffile=getRowNumber(fileName);//获取输入训练数据文本的   行数
    int columnoffile = getColumnNumber(fileName);//获取输入训练数据文本的   列数

    trainData = new double[rowoffile][columnoffile+1];//这里需要注意,为什么要+1,因为为了使得公式整齐,我们加了一个特征x0,x0恒等于0
    this.row=rowoffile;
    this.column=columnoffile+1;

    this.alpha = alpha;
    this.iteration=iteration;

    theta = new double [column-1];//h(x)=theta0 * x0 + theta1* x1 + theta2 * x2 + .......
    initialize_theta(1.0/3.0);

    loadTrainDataFromFile(fileName,rowoffile,columnoffile);
}

private int getRowNumber(String fileName)
{
    int count =0;
    File file = new File(fileName);
    BufferedReader reader = null;
    try {
        reader = new BufferedReader(new FileReader(file));
        while ( reader.readLine() != null)
            count++;
        reader.close();
    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        if (reader != null) {
            try {
                reader.close();
            } catch (IOException e1) {
            }
        }
    }
    return count;

}

private int getColumnNumber(String fileName)
{
    int count =0;
    File file = new File(fileName);
    BufferedReader reader = null;
    try {
        reader = new BufferedReader(new FileReader(file));
        String tempString = reader.readLine();
        if(tempString!=null)
            count = tempString.split(",").length;
        reader.close();
    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        if (reader != null) {
            try {
                reader.close();
            } catch (IOException e1) {
            }
        }
    }
    return count;
}

private void initialize_theta(double init_guess)//将theta各个参数全部初始化为init_guess
{
    for(int i=0;i0 )
    {
        //对每个theta i 求 偏导数
        double [] partial_derivative = compute_partial_derivative();//偏导数
        theta[0] = 0.0;//如果不需要第一个theta0,将这个值设置为0
        //更新每个theta
        for(int i =0; i< theta.length;i++) {
            double tmpTheta = theta[i]-alpha * partial_derivative[i];
            //加入了边界值,超过边界值则不再处理。如果不需要边界值,直接赋值即可。
            if(tmpTheta <1 && tmpTheta >-1){
                theta[i] = tmpTheta;
            }
        }
    }

    double[] thetaResult = new double[theta.length-1];
    for(int i=1; i 

}

关于代码中一些参数的说明

去除theta0的影响:
由于不需要theta0,所以在初始化时,将样本数据的第一列设置为0,将theta0设置为0,每次求偏导后再次将theta0设置为0,去除该值的影响。
如果实际公式中需要theta0,则把样本数据的第一列设置为1.0,再把theta0赋0的代码去掉即可。

步长alpha和迭代次数iteration:
步长和迭代次数很重要,需要根据实际情况进行调整。
三元时,测试结果如下:
如果使用0.01的步长,大概100W次,才能得到与python类似的结果;
如果使用0.1的步长,大概10W次,可以得到与python类似的结果。

关于初始猜想值:
初始猜想值,原始代码设置为了1.0,根据实际需要,代码中调整为了 1/变量个数。

关于边界:
在python中,可以设置数据的上限和下限,在该代码中没有包含。
调整代码后,比较简单粗暴地进行了范围判断,超过范围后就不再赋值。如果哪位大侠有更好的判断数据范围的方法,请告知,不尽感谢!

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/423047.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号