栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Java

Java(18)

Java 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Java(18)

学习来源:日撸 Java 三百行(51-60天,kNN 与 NB)

第 51 天: KNN 分类器

K个最近邻法(K-Nearst-Neighbor,KNN),解决监督学习中的分类问题。

算法思想:
如果一个实例在特征空间中的K个最相似(即特征空间中最近邻)的实例中的大多数属于某一个类别,则该实例也属于这个类别。并且所选择的邻居都是已经正确分类的实例。(即未标记样本的类别由距离其最近的K个邻居投票来决定)

距离衡量方法:
曼哈顿距离:

欧式距离:

K值的选择:
K值选择是KNN算法的关键,对近邻算法的结果有重大影响。
K值的具体含义:在决策时通过依据测试样本的K个最近邻"数据样本"做决策判断。
K值一般取较小值,一般采用交叉验证法来选取最优K值,也就是比较不同的K值时的交叉验证平均误差,选择平均误差最小的那个K值。

算法步骤:
1.从iris.arff中读入数据,并划分训练集和测试集。
2.分别计算测试集中的每个样本与训练集中所有样本的距离,选出距离最近的k个邻居。
3.k个邻居根据自身类别进行投票,票数最多的类别就是待预测样本的种类。
4.通过测试集中的样本的预测计算与实际结果来计算预测准确度。

代码:

package machine_learning;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.*;



public class KnnClassification {
	public static final int MANHATTAN = 0;
	public static final int EUCLIDEAN = 1;
	public int distanceMeasure = EUCLIDEAN;
	public static final Random random = new Random();
	
	int numNeighbors = 7;
	Instances dataset;
	int[] trainingSet;
	int[] testingSet;
	int[] predictions;

	public KnnClassification(String paraFilename) {
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			// The last attribute is the decision class.
			dataset.setClassIndex(dataset.numAttributes() - 1);
			fileReader.close();
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read '" + paraFilename+ "' in KnnClassification constructor.rn" + ee);
			System.exit(0);
		} // Of try
	}// Of the first constructor

	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];

		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i

		return resultIndices;
	}// Of getRandomIndices

	public void splitTrainingTesting(double paraTrainingFraction) {
		int tempSize = dataset.numInstances();
		int[] tempIndices = getRandomIndices(tempSize);
		int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

		trainingSet = new int[tempTrainingSize];
		testingSet = new int[tempSize - tempTrainingSize];

		for (int i = 0; i < tempTrainingSize; i++) {
			trainingSet[i] = tempIndices[i];
		} // Of for i

		for (int i = 0; i < tempSize - tempTrainingSize; i++) {
			testingSet[i] = tempIndices[tempTrainingSize + i];
		} // Of for i
	}// Of splitTrainingTesting
	
	public void predict() {
		predictions = new int[testingSet.length];
		for (int i = 0; i < predictions.length; i++) {
			predictions[i] = predict(testingSet[i]);
		} // Of for i
	}// Of predict

	public int predict(int paraIndex) {
		int[] tempNeighbors = computeNearests(paraIndex);
		int resultPrediction = simpleVoting(tempNeighbors);

		return resultPrediction;
	}// Of predict

	public double distance(int paraI, int paraJ) {
		double resultDistance = 0;
		double tempDifference;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				if (tempDifference < 0) {
					resultDistance -= tempDifference;
				} else {
					resultDistance += tempDifference;
				} // Of if
			} // Of for i
			break;

		case EUCLIDEAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				resultDistance += tempDifference * tempDifference;
			} // Of for i
			break;
		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
		}// Of switch

		return resultDistance;
	}// Of distance

	public double getAccuracy() {
		// A double divides an int gets another double.
		double tempCorrect = 0;
		for (int i = 0; i < predictions.length; i++) {
			if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		return tempCorrect / testingSet.length;
	}// Of getAccuracy

	public int[] computeNearests(int paraCurrent) {
		int[] resultNearests = new int[numNeighbors];
		boolean[] tempSelected = new boolean[trainingSet.length];
		double tempMinimalDistance;
		int tempMinimalIndex = 0;

		double[] tempDistances = new double[trainingSet.length];
		for (int i = 0; i < trainingSet.length; i ++) {
			tempDistances[i] = distance(paraCurrent, trainingSet[i]);
		}//Of for i
		
		for (int i = 0; i < numNeighbors; i++) {
			tempMinimalDistance = Double.MAX_VALUE;

			for (int j = 0; j < trainingSet.length; j++) {
				if (tempSelected[j]) {
					continue;
				} // Of if

				if (tempDistances[j] < tempMinimalDistance) {
					tempMinimalDistance = tempDistances[j];
					tempMinimalIndex = j;
				} // Of if
			} // Of for j

			resultNearests[i] = trainingSet[tempMinimalIndex];
			tempSelected[tempMinimalIndex] = true;
		} // Of for i

		System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
		return resultNearests;
	}// Of computeNearests

	public int simpleVoting(int[] paraNeighbors) {
		int[] tempVotes = new int[dataset.numClasses()];
		for (int i = 0; i < paraNeighbors.length; i++) {
			tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
		} // Of for i

		int tempMaximalVotingIndex = 0;
		int tempMaximalVoting = 0;
		for (int i = 0; i < dataset.numClasses(); i++) {
			if (tempVotes[i] > tempMaximalVoting) {
				tempMaximalVoting = tempVotes[i];
				tempMaximalVotingIndex = i;
			} // Of if
		} // Of for i

		return tempMaximalVotingIndex;
	}// Of simpleVoting


	public static void main(String args[]) {
		KnnClassification tempClassifier = new KnnClassification("C:\Users\LXY\Desktop\iris.arff");
		tempClassifier.splitTrainingTesting(0.8);
		tempClassifier.predict();
		System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
	}// Of main

}// Of class KnnClassification

运行截图:

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/859532.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号