栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

随机森林算法及其实现(2)

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

随机森林算法及其实现(2)

随机森林算法及其实现 算法实现
  1. 先实现随机化:有放回地抽取样本,以及无放回地随机抽取属性
/**
 * Draws trainingSet->getRows() row indices uniformly at random with
 * replacement, then returns the distinct indices that were drawn, in
 * ascending order. Repeated draws are collapsed, so the result holds at most
 * getRows() entries. The selected indices are also printed to std::cout.
 *
 * @return a newly allocated IntArray; the caller must delete it
 *         (buildTree does so after training).
 */
IntArray* RandomForestClassifier::bootStrap()
{
	const int numRows = trainingSet->getRows();

	// picked[r] becomes 1 once row r has been drawn at least once.
	int* picked = new int[numRows];
	memset(picked, 0, numRows * sizeof(int));

	int numDistinct = 0;
	for (int draw = 0; draw < numRows; ++draw)
	{
		const int row = rand() % numRows;
		if (picked[row] != 0)
		{
			continue;
		}// Of if
		picked[row] = 1;
		++numDistinct;
	}// Of for

	IntArray* sampled = new IntArray(numDistinct);
	int slot = 0;
	for (int row = 0; row < numRows; ++row)
	{
		if (picked[row] != 0)
		{
			sampled->setValue(slot, row);
			++slot;
		}// Of if
	}// Of for

	// Debug output of the selected row indices.
	std::cout << sampled->toString() << std::endl;

	delete[] picked;
	return sampled;
}// Of bootStrap



/**
 * Chooses floor(sqrt(numAttributes)) distinct attribute indices uniformly at
 * random (sampling without replacement) and returns them in ascending order.
 *
 * @return a newly allocated IntArray; the caller must delete it
 *         (buildTree does so after training).
 */
IntArray* RandomForestClassifier::getAttributes()
{
	const int numChosen = sqrt(numAttributes);
	// chosen[a] is 1 once attribute a has been selected.
	int* chosen = new int[numAttributes];
	memset(chosen, 0, numAttributes * sizeof(int));
	IntArray* result = new IntArray(numChosen);

	for (int picked = 0; picked < numChosen; ++picked)
	{
		// Rejection sampling: redraw until an unselected attribute is hit.
		int candidate;
		do {
			candidate = rand() % numAttributes;
		} while (chosen[candidate] == 1);
		chosen[candidate] = 1;
	}// Of for

	int slot = 0;
	for (int attr = 0; attr < numAttributes; ++attr)
	{
		if (chosen[attr] == 1)
		{
			result->setValue(slot++, attr);
		}// Of if
	}// Of for

	delete[] chosen;
	return result;
}// Of getAttributes
  2. 构造多棵决策树
/**
 * Builds and trains one tree of the forest. The tree sees a random subset of
 * training rows (from bootStrap) and a random subset of attributes (from
 * getAttributes).
 *
 * @return a newly allocated Tree on which train() has already been called.
 */
Tree* RandomForestClassifier::buildTree()
{
	// bootStrap must run before getAttributes so the rand() sequence is unchanged.
	IntArray* rows = bootStrap();
	IntArray* attributes = getAttributes();

	Tree* resTree = new Tree(trainingSet, trainingLables, rows, attributes, numClasses);
	resTree->train();

	// The index arrays are released as soon as training is done.
	// NOTE(review): if Tree keeps these pointers (e.g. for predict()), they
	// dangle after this point — confirm against Tree's implementation.
	delete rows;
	delete attributes;

	return resTree;
}// Of buildTree


void RandomForestClassifier::train()
{
	trees = new Tree * [numTrees];
	for (int i = 0; i < numTrees; i++)
	{
		trees[i] = buildTree();
	}// Of for
}// Of train
  3. 预测以及投票
/**
 * Majority vote over the labels predicted by the individual trees.
 *
 * @param paraLabels one predicted class label per tree. Labels outside
 *                   [0, numClasses) are ignored rather than written out of
 *                   bounds.
 * @return the class with the most votes. On a tie, the class that first
 *         reached the winning count wins. If no valid label is present, the
 *         result is class 0.
 */
int RandomForestClassifier::vote(IntArray* paraLabels)
{
	int* tempCountClasses = new int[numClasses];
	memset(tempCountClasses, 0, numClasses * sizeof(int));

	int resLabel = 0;
	for (int i = 0; i < paraLabels->getLength(); i++)
	{
		const int tempLabel = paraLabels->getValue(i);
		// An out-of-range label would index past the counting array.
		if (tempLabel < 0 || tempLabel >= numClasses)
		{
			continue;
		}// Of if

		tempCountClasses[tempLabel]++;
		// Only tempLabel's count changed, so it is the only class that can
		// overtake the current leader. The strict '>' keeps the earlier leader on ties.
		if (tempCountClasses[tempLabel] > tempCountClasses[resLabel])
		{
			resLabel = tempLabel;
		}// Of if
	}// Of for

	delete[] tempCountClasses;
	// Return the winning label (not the function name `vote`).
	return resLabel;
}// Of vote


/**
 * Predicts the class of one instance by letting every tree in the forest vote.
 *
 * @param paraInstance the instance to classify; it is passed unchanged to each
 *                     tree's predict().
 * @return the label chosen by vote() from the per-tree predictions.
 */
int RandomForestClassifier::predict(DoubleMatrix* paraInstance)
{
	// ballots holds one prediction per tree, indexed by tree position.
	IntArray* ballots = new IntArray(numTrees);

	int treeIndex = 0;
	while (treeIndex < numTrees)
	{
		ballots->setValue(treeIndex, trees[treeIndex]->predict(paraInstance));
		++treeIndex;
	}// Of while

	const int resLable = vote(ballots);
	delete ballots;
	return resLable;
}// Of predict
实现过程出现的问题以及解决方式

实现之后发现准确率不太理想。可能是算法中还有小问题;同时我怀疑这也和数据集有关:这里测试的数据集只有 17 个样本,划分成训练集和测试集后就更少了。

数据集:

0, 0, 0, 0, 0, 0, 1
1, 0, 1, 0, 0, 0, 1
1, 0, 0, 0, 0, 0, 1
0, 0, 1, 0, 0, 0, 1
2, 0, 0, 0, 0, 0, 1
0, 1, 0, 0, 1, 1, 1
1, 1, 0, 1, 1, 1, 1
1, 1, 0, 0, 1, 0, 1
1, 1, 1, 1, 1, 0, 0
0, 2, 2, 0, 2, 1, 0
2, 2, 2, 2, 2, 0, 0
2, 0, 0, 2, 2, 1, 0
0, 1, 0, 1, 0, 0, 0
2, 1, 1, 1, 0, 0, 0
1, 1, 0, 0, 1, 1, 0
2, 0, 0, 2, 2, 0, 0
0, 0, 1, 1, 1, 0, 0



然后又测试了那个weather数据集,结果也是不太理想,如下图:

数据集:

0, 0, 0, 0, 0
0, 0, 0, 1, 0
1, 0, 0, 0, 1
2, 1, 0, 0, 1
2, 2, 1, 0, 1
2, 2, 1, 1, 0
1, 2, 1, 1, 1
0, 1, 0, 0, 0
0, 2, 1, 0, 1
2, 1, 1, 0, 1
0, 1, 1, 1, 1
1, 1, 0, 1, 1
1, 0, 1, 0, 1
2, 1, 0, 1, 0

应该是决策树中的某个部分有偏差。我也参考了导师的文章《日撸 Java 三百行(61-70天,决策树与集成学习)》,估计是在转写成 C++ 的过程中疏忽了一些细节,因此还得继续找 bug……

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/270567.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号