- 1. K-Means Algorithm Principle
- 2. Watermelon Dataset Example Results
- 3. Java Code
- 4. Code Improvement Details
- 5. Some Reflections
[Clustering 2] Prototype Clustering: the K-Means Algorithm
2. Watermelon Dataset Example Results
- Watermelon dataset
- Results
This dataset should yield 3 clusters:
1. C1 = {6,7,8,10,11,12,15,18,19,20}
2. C2 = {1,2,4,22,23,24,25,26,27,28,29,30}
3. C3 = {3,5,9,13,14,16,17,21}
3. Java Code
- xigua.arff
@relation xigua
@attribute 密度 numeric
@attribute 含糖量 numeric
@attribute 好瓜 {是,否}
@data
0.697,0.460,是
0.774,0.376,是
0.634,0.264,是
0.608,0.318,是
0.556,0.215,是
0.403,0.237,是
0.481,0.149,是
0.437,0.211,否
0.666,0.091,否
0.243,0.267,否
0.245,0.057,否
0.343,0.099,否
0.639,0.161,否
0.657,0.198,否
0.360,0.370,否
0.593,0.042,否
0.719,0.103,否
0.359,0.188,否
0.339,0.241,否
0.282,0.257,否
0.748,0.232,是
0.714,0.346,是
0.483,0.312,是
0.478,0.437,是
0.525,0.369,是
0.751,0.489,是
0.532,0.472,是
0.473,0.376,是
0.725,0.445,是
0.446,0.459,是
- KMeans.java
package cluster;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;
public class KMeans {
Instances dataset;
int k;
int[][] clusters;
public KMeans(String paraFilename, int numClusters) {
dataset = null;
k = numClusters;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception e) {
System.out.println("Cannot read the file: " + paraFilename + "\r\n" + e);
System.exit(0);
}
}
public static int[] getRandomIndices(int paraLength) {
Random random = new Random();
int[] resultIndices = new int[paraLength];
// Step 1. Initialize.
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
} // Of for i
// Step 2. Randomly swap.
int tempFirst, tempSecond, tempValue;
for (int i = 0; i < paraLength; i++) {
// Generate two random indices.
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
// Swap.
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
} // Of for i
return resultIndices;
}// Of getRandomIndices
public double distance(int paraI, double[] paraArray) {
double resultDistance = 0.0;
double tempDifference;
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
resultDistance += tempDifference * tempDifference;
} // Of for i
return Math.sqrt(resultDistance);
}// Of distance
public void clustering() {
int[] tempOldClusterArray = new int[dataset.numInstances()];
tempOldClusterArray[0] = -1;
int[] tempClusterArray = new int[dataset.numInstances()];
Arrays.fill(tempClusterArray, 0);
double[][] tempCenters = new double[k][dataset.numAttributes() - 1];
// Step 1. Initialize centers.
int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
for (int i = 0; i < k; i++) {
for (int j = 0; j < tempCenters[0].length; j++) {
tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
} // Of for j
} // Of for i
int[] tempClusterLengths = null;
while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
System.out.println("New loop ...");
tempOldClusterArray = tempClusterArray;
tempClusterArray = new int[dataset.numInstances()];
// Step 2.1 Minimization. Assign cluster to each instance.
int tempNearestCenter;
double tempNearestDistance;
double tempDistance;
for (int i = 0; i < dataset.numInstances(); i++) {
tempNearestCenter = -1;
tempNearestDistance = Double.MAX_VALUE;
for (int j = 0; j < k; j++) {
tempDistance = distance(i, tempCenters[j]);
if (tempNearestDistance > tempDistance) {
tempNearestDistance = tempDistance;
tempNearestCenter = j;
} // Of if
} // Of for j
tempClusterArray[i] = tempNearestCenter;
} // Of for i
// Step 2.2 Mean. Find new centers.
tempClusterLengths = new int[k];
Arrays.fill(tempClusterLengths, 0);
double[][] tempNewCenters = new double[k][dataset.numAttributes() - 1];
// Arrays.fill(tempNewCenters, 0);
for (int i = 0; i < dataset.numInstances(); i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
} // Of for j
tempClusterLengths[tempClusterArray[i]]++;
} // Of for i
// Step 2.3 Now average
for (int i = 0; i < tempNewCenters.length; i++) {
for (int j = 0; j < tempNewCenters[0].length; j++) {
tempNewCenters[i][j] /= tempClusterLengths[i];
} // Of for j
} // Of for i
System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
tempCenters = tempNewCenters;
} // Of while
// Step 3. Form clusters.
clusters = new int[k][];
int[] tempCounters = new int[k];
for (int i = 0; i < k; i++) {
clusters[i] = new int[tempClusterLengths[i]];
} // Of for i
for (int i = 0; i < tempClusterArray.length; i++) {
clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i + 1;
tempCounters[tempClusterArray[i]]++;
} // Of for i
System.out.println("The clusters are: " + Arrays.deepToString(clusters));
}// Of clustering
public static void main(String[] args) {
KMeans tempKMeans = new KMeans("D:/data/xigua.arff", 3);
tempKMeans.clustering();
}
}// KMeans
- Output
New loop ...
Now the new centers are: [[0.48975, 0.20060000000000003], [0.744, 0.361], [0.5783749999999999, 0.43837499999999996]]
New loop ...
Now the new centers are: [[0.4591176470588236, 0.18811764705882353], [0.7204285714285713, 0.3731428571428571], [0.5103333333333333, 0.4051666666666667]]
New loop ...
Now the new centers are: [[0.4296923076923077, 0.17038461538461538], [0.7027, 0.3231], [0.47100000000000003, 0.3992857142857143]]
New loop ...
Now the new centers are: [[0.38918181818181813, 0.17845454545454545], [0.6943333333333334, 0.29025], [0.47100000000000003, 0.3992857142857143]]
New loop ...
Now the new centers are: [[0.3725, 0.1748], [0.6836923076923077, 0.2844615384615385], [0.47100000000000003, 0.3992857142857143]]
New loop ...
Now the new centers are: [[0.3725, 0.1748], [0.6836923076923077, 0.2844615384615385], [0.47100000000000003, 0.3992857142857143]]
The clusters are: [[6, 7, 8, 10, 11, 12, 16, 18, 19, 20], [1, 2, 3, 4, 5, 9, 13, 14, 17, 21, 22, 26, 29], [15, 23, 24, 25, 27, 28, 30]]
4. Code Improvement Details
- Randomly generate 3 distinct integers in the range 1~30
Set-based method:
package test;
import java.util.HashSet;
import java.util.Set;
public class TestRandom {
private final static int K = 3;
public static void main(String[] args) {
System.out.println(getRandomIndices(30));
}
public static Set<Integer> getRandomIndices(int length) {
// The Set container guarantees that its elements are distinct.
Set<Integer> set = new HashSet<>();
// Keep drawing until the set holds exactly K elements.
while (set.size() < K) {
set.add((int) (Math.random() * length) + 1);
}
return set;
}
}
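Since Java 8, the same distinct-index requirement can also be met with an `IntStream`: `Random.ints(origin, bound)` produces an unbounded stream of values in the range, and `distinct().limit(num)` keeps the first `num` distinct ones. A minimal sketch (the class name `DistinctIndices` is just for illustration):

```java
import java.util.Arrays;
import java.util.Random;

public class DistinctIndices {
    // Draw num distinct indices from [0, length) via an IntStream.
    public static int[] getRandomIndices(int length, int num) {
        return new Random().ints(0, length).distinct().limit(num).toArray();
    }

    public static void main(String[] args) {
        // Three distinct indices in [0, 30), e.g. for k = 3 prototypes.
        System.out.println(Arrays.toString(getRandomIndices(30, 3)));
    }
}
```

This removes both the explicit `while` loop and the manual distinctness check.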
Swap method:
public static int[] getRandomIndices(int paraLength) {
Random random = new Random();
int[] resultIndices = new int[paraLength];
// Step 1. Initialize.
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
} // Of for i
// Step 2. Randomly swap.
int tempFirst, tempSecond, tempValue;
for (int i = 0; i < paraLength; i++) {
// Generate two random indices.
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
// Swap.
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
} // Of for i
return resultIndices;
}// Of getRandomIndices
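One caveat about the swap method above: it performs `paraLength` random transpositions, which mixes the array well but does not yield a strictly uniform permutation. The standard Fisher-Yates shuffle does, at the same cost; a sketch (the class name `FisherYates` is illustrative):

```java
import java.util.Arrays;
import java.util.Random;

public class FisherYates {
    // Fisher-Yates: at step i, swap position i with a uniformly random
    // position in [i, n); the result is a uniform random permutation.
    public static int[] getRandomIndices(int paraLength) {
        Random random = new Random();
        int[] resultIndices = new int[paraLength];
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }
        for (int i = 0; i < paraLength - 1; i++) {
            int j = i + random.nextInt(paraLength - i);
            int tempValue = resultIndices[i];
            resultIndices[i] = resultIndices[j];
            resultIndices[j] = tempValue;
        }
        return resultIndices;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(getRandomIndices(10)));
    }
}
```

As before, taking the first k entries of the shuffled array gives k distinct indices.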
- Standardized Euclidean distance
package cluster;
import java.io.FileReader;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import weka.core.Instances;
public class KMeans {
Instances dataset;
int k;
int[][] clusters;
public KMeans(String fileURL, int numClusters) {
dataset = null;
k = numClusters;
try {
FileReader fileReader = new FileReader(fileURL);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception e) {
System.out.println("Cannot read the file: " + fileURL + "\r\n" + e);
System.exit(0);
}
}
public int[] getRandomIndices(int length, int num) {
// The Set container guarantees that the indices are distinct.
Set<Integer> set = new HashSet<>();
// Keep drawing until the set holds exactly num elements.
while (set.size() < num) {
set.add((int) (Math.random() * length));
}
int[] result = new int[num];
Iterator<Integer> it = set.iterator();
for (int i = 0; i < num; i++) {
result[i] = it.next();
}
return result;
}
public double eudistance(int prototypeIndex, int otherIndex) {
double result = 0.0;
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
double difference = dataset.instance(prototypeIndex).value(i) - dataset.instance(otherIndex).value(i);
result += difference * difference;
}
return Math.sqrt(result);
}
public void clustering() {
int length = dataset.numInstances();
int[] prototypeIndex = getRandomIndices(length, k);
double[][] tempDistances = getDistances(prototypeIndex);
clusters = selectCluster(tempDistances);
System.out.println(Arrays.deepToString(tempDistances));
System.out.println(Arrays.deepToString(clusters));
}
public int[][] selectCluster(double[][] distances) {
int length = dataset.numInstances();
int[][] result = new int[k][];
// 1. Find the nearest prototype for each instance.
int[] minIndex = new int[length];
for (int i = 0; i < length; i++) {
double min = Double.MAX_VALUE;
for (int j = 0; j < k; j++) {
if (min > distances[i][j]) {
min = distances[i][j];
minIndex[i] = j;
}
}
}
// 2. Count how many instances fall into each cluster.
int[] count = new int[k];
for (int i = 0; i < k; i++) {
for (int j = 0; j < length; j++) {
if (minIndex[j] == i) {
count[i]++;
}
}
}
// 3. Allocate an array for each cluster.
for (int i = 0; i < k; i++) {
result[i] = new int[count[i]];
}
// 4. Fill each cluster with its instance indices.
int[] counters = new int[k];
for (int i = 0; i < length; i++) {
result[minIndex[i]][counters[minIndex[i]]++] = i;
}
System.out.println(Arrays.toString(minIndex));
System.out.println(Arrays.toString(count));
return result;
}
public double[][] getDistances(int[] prototypeIndex) {
int length = dataset.numInstances();
double[][] distances = new double[length][k];
for (int i = 0; i < length; i++) {
for (int j = 0; j < k; j++) {
distances[i][j] = eudistance(i, prototypeIndex[j]);
}
}
return distances;
}
public static void main(String[] args) {
KMeans kmeans = new KMeans("D:/data/xigua.arff", 3);
kmeans.clustering();
}
}// KMeans
- Improved code
package cluster;
import java.io.FileReader;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import weka.core.Instances;
public class KMeans {
Instances dataset;
int k;
int[][] clusters;
public KMeans(String paraFilename, int numClusters) {
dataset = null;
k = numClusters;
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
fileReader.close();
} catch (Exception e) {
System.out.println("Cannot read the file: " + paraFilename + "\r\n" + e);
System.exit(0);
}
}
public int[] getRandomIndices(int length, int num) {
// The Set container guarantees that the indices are distinct.
Set<Integer> set = new HashSet<>();
// Keep drawing until the set holds exactly num elements.
while (set.size() < num) {
set.add((int) (Math.random() * length));
}
int[] result = new int[num];
Iterator<Integer> it = set.iterator();
for (int i = 0; i < num; i++) {
result[i] = it.next();
}
return result;
}
public double eudistance(int prototypeIndex, int otherIndex) {
double result = 0.0;
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
double difference = dataset.instance(prototypeIndex).value(i) - dataset.instance(otherIndex).value(i);
result += difference * difference;
}
return Math.sqrt(result);
}
public void clustering() {
// getDistances() picks the random prototypes itself, so no separate
// getRandomIndices call is needed here.
double[][] tempDistances = getDistances();
clusters = selectCluster(tempDistances);
System.out.println(Arrays.deepToString(tempDistances));
}
public int[][] selectCluster(double[][] distances) {
int length = dataset.numInstances();
int[][] result = new int[k][];
// Assign each instance to its nearest prototype.
int[] minIndex = new int[length];
for (int i = 0; i < length; i++) {
for (int j = 0; j < k; j++) {
if (distances[i][j] < distances[i][minIndex[i]]) {
minIndex[i] = j;
}
}
}
// Count cluster sizes, allocate each cluster, then fill it.
int[] count = new int[k];
for (int i = 0; i < length; i++) {
count[minIndex[i]]++;
}
for (int i = 0; i < k; i++) {
result[i] = new int[count[i]];
}
int[] counters = new int[k];
for (int i = 0; i < length; i++) {
result[minIndex[i]][counters[minIndex[i]]++] = i;
}
return result;
}
public double[][] getDistances() {
int length = dataset.numInstances();
int[] prototypeIndex = getRandomIndices(length, k);
double[][] distances = new double[length][k];
for (int i = 0; i < length; i++) {
for (int j = 0; j < k; j++) {
distances[i][j] = eudistance(i, prototypeIndex[j]);
}
}
return distances;
}
public static void main(String[] args) {
KMeans kmeans = new KMeans("D:/data/xigua.arff", 3);
kmeans.clustering();
}
}// KMeans
5. Some Reflections
The principle and code of k-means are both fairly simple, but a solid implementation still takes some work, and there is room for better algorithms and further improvements.

K-means belongs to prototype clustering: the prototype vectors are chosen at random, and k is the number of clusters, so k random samples must be picked as the initial prototypes. This raises a question: how do you pick 3 numbers in the range 1~30 and guarantee that all 3 are distinct? One elegant idea: initialize an array of capacity 30 holding 1~30 (or 0~29), then repeatedly generate two random numbers in 1~30 that decide which two positions of the array to swap. With k = 3, simply take the first 3 elements of the shuffled array; they are guaranteed to be distinct. Random initialization still has drawbacks, though, so there is another algorithm, Gaussian mixture clustering, which selects prototype vectors through a probability function instead of picking them purely at random.

When computing distances, density and sugar content have different units; did you think about that before adding them together? To address this we can use the standardized Euclidean distance: while computing the distance between samples, divide each attribute difference by that attribute's standard deviation (also called the root-mean-square deviation). For example, take (1,2) and (3,5). For the first attribute, compute the mean (1+3)/2 = 2; then (2-1)^2 + (2-3)^2 = 2; 2/2 = 1; sqrt(1) = 1. This 1 is the standard deviation. After the division, each attribute is measured in units of its own standard deviation, so the units match and the terms can safely be added.

Updating the mean vectors is the key step of this algorithm.
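The standardization just described can be sketched in Java as follows; `StandardizedDistance` and its method names are illustrative, and the standard deviation is the population version used in the (1,2)/(3,5) arithmetic above:

```java
public class StandardizedDistance {
    // Population standard deviation of one attribute (column).
    static double std(double[][] data, int col) {
        double mean = 0.0;
        for (double[] row : data) {
            mean += row[col];
        }
        mean /= data.length;
        double sum = 0.0;
        for (double[] row : data) {
            double d = row[col] - mean;
            sum += d * d;
        }
        return Math.sqrt(sum / data.length);
    }

    // Standardized Euclidean distance: each attribute difference is
    // divided by that attribute's standard deviation before squaring.
    static double distance(double[][] data, int i, int j) {
        double result = 0.0;
        for (int col = 0; col < data[0].length; col++) {
            double d = (data[i][col] - data[j][col]) / std(data, col);
            result += d * d;
        }
        return Math.sqrt(result);
    }

    public static void main(String[] args) {
        double[][] data = { { 1, 2 }, { 3, 5 } };
        // First attribute: mean (1+3)/2 = 2, variance ((1-2)^2+(3-2)^2)/2 = 1.
        System.out.println(std(data, 0)); // 1.0
        System.out.println(distance(data, 0, 1)); // sqrt(8)
    }
}
```

Note that with only two samples, every standardized difference works out to exactly 2 (the standard deviation of two points is half their gap); with a full dataset such as xigua.arff, each attribute gets its own weight.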



