栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 前沿技术 > 大数据 > 大数据系统

spark中调用xgboost实现2分类时需对rawPrediction修改

spark中调用xgboost实现2分类时需对rawPrediction修改

1、修改代码


  private def amendXGBPred(res: Dataframe): Dataframe = {
    val columns = res.columns

    if (columns.contains("rawPrediction")) {
      val aRes  = res.withColumnRenamed("rawPrediction", "rawPrediction_Ori")
      val code = (arg: Vector) => {//这个函数使原来的vector,变成新的vector
        val rawPre = arg.apply(0)
        new DenseVector(Array(-1.0 * rawPre, rawPre))
      }
      val addCol = udf(code)
      val columns = aRes.columns
      aRes.selectExpr(columns:_*).withColumn("rawPrediction", addCol(aRes("rawPrediction_Ori")))
    } else {
      res
    }

2、原由

一般的2分类模型,rawPrediction有2列,一列是分类为0的原始预测数值、一列是1的;

但,XGBoost 0.81 Java 开源版本有bug,二分类预测结果rawPrediction只有一列数据,是分类为1的预测数值;

而,spark内置的交叉验证源代码评估时使用的是rawPrediction列,因此对XGBoost算法产生的rawPrediction列进行下调整修改;

3、图片

 

 

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/688130.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号