Source: 日撸 Java 三百行 (Days 51-60, kNN and NB)
Day 51: KNN classifier
The k-nearest-neighbor method (K-Nearest-Neighbor, KNN) solves classification problems in supervised learning.
Algorithm idea:
If the majority of the K most similar instances of a sample (i.e., its K nearest neighbors in feature space) belong to a certain class, then the sample is assigned to that class; the neighbors considered are all instances that have already been correctly labeled. In other words, the class of an unlabeled sample is decided by a vote among its K nearest neighbors. For example, with K = 7, if 5 of the 7 nearest neighbors of a test sample belong to class A, the sample is predicted to be class A.
Distance measures:
Manhattan distance: $d(x, y) = \sum_{i=1}^{n} |x_i - y_i|$
Euclidean distance: $d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}$
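A quick worked example (the two points below are hypothetical, chosen only to illustrate the formulas): for $x = (5.1, 3.5, 1.4, 0.2)$ and $y = (4.9, 3.0, 1.4, 0.2)$,
$d_{\text{Manhattan}}(x, y) = 0.2 + 0.5 + 0 + 0 = 0.7$, while
$d_{\text{Euclidean}}(x, y) = \sqrt{0.2^2 + 0.5^2 + 0^2 + 0^2} = \sqrt{0.29} \approx 0.539$.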
Choosing K:
The choice of K is the key to the KNN algorithm and has a major impact on its results.
What K means concretely: the decision for a test sample is made from its K nearest "data samples" (neighbors).
K is usually taken to be a small value, and the optimal K is typically chosen by cross-validation: compare the average cross-validation error for different values of K and pick the K with the smallest average error, as sketched below.
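A minimal sketch of this selection (not part of the original code), using the KnnClassification class given below: repeated random 80/20 splits stand in for cross-validation, and the average accuracy is reported for each candidate K. The ARFF path and the direct assignment to the package-private numNeighbors field are assumptions.

package machine_learning;

/**
 * A sketch for choosing K: repeated random 80/20 splits as a rough stand-in
 * for cross-validation. The ARFF path below is an assumption; adjust as needed.
 */
public class KnnKSelection {
    public static void main(String[] args) {
        String tempFilename = "C:\\Users\\LXY\\Desktop\\iris.arff"; // assumed path
        int tempRounds = 10;

        for (int k = 1; k <= 15; k += 2) {
            KnnClassification tempClassifier = new KnnClassification(tempFilename);
            tempClassifier.numNeighbors = k; // package-private field of KnnClassification

            double tempSum = 0;
            for (int r = 0; r < tempRounds; r++) {
                tempClassifier.splitTrainingTesting(0.8);
                tempClassifier.predict();
                tempSum += tempClassifier.getAccuracy();
            } // Of for r

            // Note: computeNearests prints the neighbors of every test sample, so the output is verbose.
            System.out.println("K = " + k + ", average accuracy = " + (tempSum / tempRounds));
        } // Of for k
    }// Of main
}// Of class KnnKSelection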
Algorithm steps:
1. Read the data from iris.arff and split it into a training set and a test set.
2. For each sample in the test set, compute its distance to every sample in the training set and select the k nearest neighbors.
3. The k neighbors vote with their own class labels; the class with the most votes is the predicted class of the sample.
4. Compute the prediction accuracy by comparing the predictions on the test set with the actual labels.
Code:
package machine_learning;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.*;

/**
 * kNN classification on an ARFF data set such as iris.arff.
 */
public class KnnClassification {
    // Supported distance measures and the one currently in use.
    public static final int MANHATTAN = 0;
    public static final int EUCLIDEAN = 1;
    public int distanceMeasure = EUCLIDEAN;

    // A random instance for shuffling the data.
    public static final Random random = new Random();

    // The number of neighbors, i.e., K.
    int numNeighbors = 7;

    // The whole data set, and the index sets of the training/testing instances.
    Instances dataset;
    int[] trainingSet;
    int[] testingSet;

    // Predicted class labels of the testing instances.
    int[] predictions;
    /**
     * The first constructor: read the data from the given ARFF file.
     */
    public KnnClassification(String paraFilename) {
        try {
            FileReader fileReader = new FileReader(paraFilename);
            dataset = new Instances(fileReader);
            // The last attribute is the decision class.
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (Exception ee) {
            System.out.println("Error occurred while trying to read '" + paraFilename
                    + "' in KnnClassification constructor.\r\n" + ee);
            System.exit(0);
        } // Of try
    }// Of the first constructor
    /**
     * Get a random permutation of 0 .. paraLength - 1 by pairwise swaps.
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];

        // Step 1. Initialize as the identity permutation.
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        } // Of for i

        // Step 2. Randomly swap pairs of positions.
        int tempFirst, tempSecond, tempValue;
        for (int i = 0; i < paraLength; i++) {
            tempFirst = random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);

            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempSecond] = tempValue;
        } // Of for i

        return resultIndices;
    }// Of getRandomIndices
    /**
     * Split the data into a training set and a testing set.
     * paraTrainingFraction is the fraction of instances used for training.
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        int tempSize = dataset.numInstances();
        int[] tempIndices = getRandomIndices(tempSize);
        int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];

        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        } // Of for i

        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[tempTrainingSize + i];
        } // Of for i
    }// Of splitTrainingTesting
    public void predict() {
        predictions = new int[testingSet.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        } // Of for i
    }// Of predict

    public int predict(int paraIndex) {
        int[] tempNeighbors = computeNearests(paraIndex);
        int resultPrediction = simpleVoting(tempNeighbors);

        return resultPrediction;
    }// Of predict
    /**
     * Compute the distance between two instances, ignoring the class attribute.
     * For EUCLIDEAN the squared distance is returned; since the square root is
     * monotone, this does not change the neighbor ranking.
     */
    public double distance(int paraI, int paraJ) {
        double resultDistance = 0;
        double tempDifference;
        switch (distanceMeasure) {
        case MANHATTAN:
            for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                if (tempDifference < 0) {
                    resultDistance -= tempDifference;
                } else {
                    resultDistance += tempDifference;
                } // Of if
            } // Of for i
            break;
        case EUCLIDEAN:
            for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                resultDistance += tempDifference * tempDifference;
            } // Of for i
            break;
        default:
            System.out.println("Unsupported distance measure: " + distanceMeasure);
        }// Of switch

        return resultDistance;
    }// Of distance
    /**
     * Compute the classification accuracy on the testing set.
     */
    public double getAccuracy() {
        // A double divided by an int gives another double.
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        return tempCorrect / testingSet.length;
    }// Of getAccuracy
    /**
     * Compute the indices of the numNeighbors training instances nearest to the
     * given instance, by repeatedly selecting the minimum among the unselected.
     */
    public int[] computeNearests(int paraCurrent) {
        int[] resultNearests = new int[numNeighbors];
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempMinimalDistance;
        int tempMinimalIndex = 0;

        // Compute all distances to the training instances once.
        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; i++) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        } // Of for i

        // Pick the nearest unselected instance numNeighbors times.
        for (int i = 0; i < numNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                } // Of if

                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                } // Of if
            } // Of for j

            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        } // Of for i

        System.out.println("The nearest neighbors of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }// Of computeNearests
    public int simpleVoting(int[] paraNeighbors) {
        int[] tempVotes = new int[dataset.numClasses()];
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        } // Of for i

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            } // Of if
        } // Of for i

        return tempMaximalVotingIndex;
    }// Of simpleVoting
    /**
     * The entrance of the program.
     */
    public static void main(String args[]) {
        KnnClassification tempClassifier = new KnnClassification("C:\\Users\\LXY\\Desktop\\iris.arff");
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }// Of main
}// Of class KnnClassification
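The MANHATTAN constant is defined but never used in main. A minimal usage sketch (assuming the statements below are added inside main, after the Euclidean run above): set the public distanceMeasure field and predict again to compare the two measures.

        tempClassifier.distanceMeasure = KnnClassification.MANHATTAN; // switch from (squared) Euclidean to Manhattan
        tempClassifier.predict();
        System.out.println("The accuracy with Manhattan distance is: " + tempClassifier.getAccuracy());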
Run screenshot: (image not included)