
[Clustering 4] K-Means



Contents
  • 1. K-Means Algorithm Principle
  • 2. Results on the Watermelon Dataset
  • 3. Java Code
  • 4. Code Improvement Details
  • 5. Some Reflections

1. K-Means Algorithm Principle

    [Clustering 2] Prototype Clustering: the K-Means Algorithm
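Before the Weka-based code below, the core loop can be sketched over a plain double[][] dataset. This is a minimal, self-contained sketch; the class and method names are illustrative, not from the original code:

```java
public class KMeansSketch {
	// Minimal K-Means loop: assign each point to its nearest center,
	// recompute each center as the mean of its cluster, and repeat
	// until no assignment changes.
	public static int[] cluster(double[][] data, double[][] centers) {
		int n = data.length, k = centers.length, d = data[0].length;
		int[] labels = new int[n];
		boolean changed = true;
		while (changed) {
			changed = false;
			// Assignment step: nearest center by squared Euclidean distance.
			for (int i = 0; i < n; i++) {
				int best = 0;
				double bestDist = Double.MAX_VALUE;
				for (int j = 0; j < k; j++) {
					double dist = 0.0;
					for (int a = 0; a < d; a++) {
						double diff = data[i][a] - centers[j][a];
						dist += diff * diff;
					}
					if (dist < bestDist) { bestDist = dist; best = j; }
				}
				if (labels[i] != best) { labels[i] = best; changed = true; }
			}
			// Update step: each center becomes the mean of its cluster.
			double[][] sums = new double[k][d];
			int[] counts = new int[k];
			for (int i = 0; i < n; i++) {
				counts[labels[i]]++;
				for (int a = 0; a < d; a++) sums[labels[i]][a] += data[i][a];
			}
			for (int j = 0; j < k; j++) {
				if (counts[j] > 0) {
					for (int a = 0; a < d; a++) centers[j][a] = sums[j][a] / counts[j];
				}
			}
		}
		return labels;
	}
}
```

The full implementation below follows the same two alternating steps, just reading its data from an ARFF file through Weka.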

2. Results on the Watermelon Dataset
  • The watermelon dataset

  • Result
"This dataset should yield 3 clusters:"
		1. C1 = {6,7,8,10,11,12,15,18,19,20}
		2. C2 = {1,2,4,22,23,24,25,26,27,28,29,30}
		3. C3 = {3,5,9,13,14,16,17,21}
3. Java Code
  • xigua.arff
@relation xigua

@attribute 密度 numeric
@attribute 含糖量 numeric
@attribute 好瓜 {是,否}

@data
0.697,0.460,是
0.774,0.376,是
0.634,0.264,是
0.608,0.318,是
0.556,0.215,是
0.403,0.237,是
0.481,0.149,是
0.437,0.211,否
0.666,0.091,否
0.243,0.267,否
0.245,0.057,否
0.343,0.099,否
0.639,0.161,否
0.657,0.198,否
0.360,0.370,否
0.593,0.042,否
0.719,0.103,否
0.359,0.188,否
0.339,0.241,否
0.282,0.257,否
0.748,0.232,是
0.714,0.346,是
0.483,0.312,是
0.478,0.437,是
0.525,0.369,是
0.751,0.489,是
0.532,0.472,是
0.473,0.376,是
0.725,0.445,是
0.446,0.459,是

  • KMeans.java
package cluster;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.Instances;

public class KMeans {
	Instances dataset;
	int k;
	int[][] clusters;

	public KMeans(String paraFilename, int numClusters) {
		dataset = null;
		k = numClusters;
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			fileReader.close();
		} catch (Exception e) {
			System.out.println("Cannot read the file: " + paraFilename + "\r\n" + e);
			System.exit(0);
		}
	}

	public static int[] getRandomIndices(int paraLength) {
		Random random = new Random();
		int[] resultIndices = new int[paraLength];

		// Step 1. Initialize.
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		// Step 2. Randomly swap.
		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			// Generate two random indices.
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			// Swap.
			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i

		return resultIndices;
	}// Of getRandomIndices

	public double distance(int paraI, double[] paraArray) {
		double resultDistance = 0.0;
		double tempDifference;
		for (int i = 0; i < dataset.numAttributes() - 1; i++) {
			tempDifference = dataset.instance(paraI).value(i) - paraArray[i];
			resultDistance += tempDifference * tempDifference;
		} // Of for i

		return Math.sqrt(resultDistance);
	}// Of distance

	
	public void clustering() {
		int[] tempOldClusterArray = new int[dataset.numInstances()];
		tempOldClusterArray[0] = -1;
		int[] tempClusterArray = new int[dataset.numInstances()];
		Arrays.fill(tempClusterArray, 0);
		double[][] tempCenters = new double[k][dataset.numAttributes() - 1];

		// Step 1. Initialize centers.
		int[] tempRandomOrders = getRandomIndices(dataset.numInstances());
		for (int i = 0; i < k; i++) {
			for (int j = 0; j < tempCenters[0].length; j++) {
				tempCenters[i][j] = dataset.instance(tempRandomOrders[i]).value(j);
			} // Of for j
		} // Of for i

		int[] tempClusterLengths = null;
		while (!Arrays.equals(tempOldClusterArray, tempClusterArray)) {
			System.out.println("New loop ...");
			tempOldClusterArray = tempClusterArray;
			tempClusterArray = new int[dataset.numInstances()];

			// Step 2.1 Minimization. Assign cluster to each instance.
			int tempNearestCenter;
			double tempNearestDistance;
			double tempDistance;

			for (int i = 0; i < dataset.numInstances(); i++) {
				tempNearestCenter = -1;
				tempNearestDistance = Double.MAX_VALUE;

				for (int j = 0; j < k; j++) {
					tempDistance = distance(i, tempCenters[j]);
					if (tempNearestDistance > tempDistance) {
						tempNearestDistance = tempDistance;
						tempNearestCenter = j;
					} // Of if
				} // Of for j
				tempClusterArray[i] = tempNearestCenter;
			} // Of for i

			// Step 2.2 Mean. Find new centers.
			tempClusterLengths = new int[k];
			Arrays.fill(tempClusterLengths, 0);
			double[][] tempNewCenters = new double[k][dataset.numAttributes() - 1];
			// Arrays.fill(tempNewCenters, 0);
			for (int i = 0; i < dataset.numInstances(); i++) {
				for (int j = 0; j < tempNewCenters[0].length; j++) {
					tempNewCenters[tempClusterArray[i]][j] += dataset.instance(i).value(j);
				} // Of for j
				tempClusterLengths[tempClusterArray[i]]++;
			} // Of for i

			// Step 2.3 Now average
			for (int i = 0; i < tempNewCenters.length; i++) {
				for (int j = 0; j < tempNewCenters[0].length; j++) {
					tempNewCenters[i][j] /= tempClusterLengths[i];
				} // Of for j
			} // Of for i

			System.out.println("Now the new centers are: " + Arrays.deepToString(tempNewCenters));
			tempCenters = tempNewCenters;
		} // Of while

		// Step 3. Form clusters.
		clusters = new int[k][];
		int[] tempCounters = new int[k];
		for (int i = 0; i < k; i++) {
			clusters[i] = new int[tempClusterLengths[i]];
		} // Of for i

		for (int i = 0; i < tempClusterArray.length; i++) {
			clusters[tempClusterArray[i]][tempCounters[tempClusterArray[i]]] = i + 1;
			tempCounters[tempClusterArray[i]]++;
		} // Of for i

		System.out.println("The clusters are: " + Arrays.deepToString(clusters));
	}// Of clustering

	public static void main(String[] args) {
		KMeans tempKMeans = new KMeans("D:/data/xigua.arff", 3);
		tempKMeans.clustering();
	}

}// KMeans

  • Output
New loop ...
Now the new centers are: [[0.48975, 0.20060000000000003], [0.744, 0.361], [0.5783749999999999, 0.43837499999999996]]
New loop ...
Now the new centers are: [[0.4591176470588236, 0.18811764705882353], [0.7204285714285713, 0.3731428571428571], [0.5103333333333333, 0.4051666666666667]]
New loop ...
Now the new centers are: [[0.4296923076923077, 0.17038461538461538], [0.7027, 0.3231], [0.47100000000000003, 0.3992857142857143]]
New loop ...
Now the new centers are: [[0.38918181818181813, 0.17845454545454545], [0.6943333333333334, 0.29025], [0.47100000000000003, 0.3992857142857143]]
New loop ...
Now the new centers are: [[0.3725, 0.1748], [0.6836923076923077, 0.2844615384615385], [0.47100000000000003, 0.3992857142857143]]
New loop ...
Now the new centers are: [[0.3725, 0.1748], [0.6836923076923077, 0.2844615384615385], [0.47100000000000003, 0.3992857142857143]]
The clusters are: [[6, 7, 8, 10, 11, 12, 16, 18, 19, 20], [1, 2, 3, 4, 5, 9, 13, 14, 17, 21, 22, 26, 29], [15, 23, 24, 25, 27, 28, 30]]

4. Code Improvement Details
  • Randomly generate 3 distinct integers in the range 1~30

Set-based approach

 
package test;
 
import java.util.HashSet;
import java.util.Set;
 
public class TestRandom {
 
    private final static Integer K = 3;
 
    public static void main(String[] args) {
        System.out.println(getRandomIndices(30));
    }
    
    public static Set<Integer> getRandomIndices(int length) {
        // The Set container guarantees the elements are distinct.
        Set<Integer> set = new HashSet<>(K);
        // Keep adding until the set holds exactly K elements.
        while (set.size() != K) {
            for (int i = 0; i < K; i++) {
                set.add((int) (Math.random() * length + 1));
            }
        }
        return set;
    }
}
 

Swap approach:

	public static int[] getRandomIndices(int paraLength) {
		Random random = new Random();
		int[] resultIndices = new int[paraLength];

		// Step 1. Initialize.
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		// Step 2. Randomly swap.
		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			// Generate two random indices.
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			// Swap.
			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i

		return resultIndices;
	}// Of getRandomIndices
  • Standardized Euclidean distance
package cluster;

import java.io.FileReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

import weka.core.Instances;

public class KMeans {
	Instances dataset;
	int k;
	int[][] clusters;

	public KMeans(String fileURL, int numClusters) {
		dataset = null;
		k = numClusters;
		try {
			FileReader fileReader = new FileReader(fileURL);
			dataset = new Instances(fileReader);
			fileReader.close();
		} catch (Exception e) {
			System.out.println("Cannot read the file: " + fileURL + "\r\n" + e);
			System.exit(0);
		}
	}

	public int[] getRandomIndices(int length, int num) {
		// The Set container guarantees the elements are distinct.
		Set<Integer> set = new HashSet<>(num);
		// Keep adding until the set holds exactly num elements.
		while (set.size() != num) {
			for (int i = 0; i < num; i++) {
				set.add((int) (Math.random() * length));
			}
		}

		int[] result = new int[num];
		Iterator<Integer> it = set.iterator();
		for (int i = 0; i < num; i++) {
			result[i] = it.next();
		}

		return result;
	}

	public double eudistance(int prototypeIndex, int otherIndex) {
		double result = 0.0;
		for (int i = 0; i < dataset.numAttributes() - 1; i++) {
			double p1 = dataset.instance(prototypeIndex).value(i);
			double p2 = dataset.instance(otherIndex).value(i);
			result += (p1 - p2) * (p1 - p2);
		}
		return Math.sqrt(result);
	}

	public void clustering() {
		int length = dataset.numInstances();
		int[] prototypeIndex = getRandomIndices(length, k);

		double[][] tempDistances = getDistances(prototypeIndex);
		clusters = selectCluster(tempDistances);

		System.out.println(Arrays.deepToString(tempDistances));
		System.out.println(Arrays.deepToString(clusters));
	}

	public int[][] selectCluster(double[][] distances) {
		int length = dataset.numInstances();
		int[][] result = null;
		// 1
		int[] minIndex = new int[length];
		for (int i = 0; i < length; i++) {
			double min = Double.MAX_VALUE;
			for (int j = 0; j < k; j++) {
				if (min > distances[i][j]) {
					min = distances[i][j];
					minIndex[i] = j;
				}
			}
		}
		// 2 statistic: count the size of each cluster.
		int[] count = new int[k];
		for (int i = 0; i < k; i++) {
			for (int j = 0; j < length; j++) {
				if (minIndex[j] == i) {
					count[i]++;
				}
			}
		}
		// 3 allocate one row per cluster.
		result = new int[k][];
		for (int i = 0; i < k; i++) {
			result[i] = new int[count[i]];
		}
		// 4 fill each cluster, tracking a write position per cluster.
		int[] pos = new int[k];
		for (int i = 0; i < length; i++) {
			result[minIndex[i]][pos[minIndex[i]]++] = i;
		}

		System.out.println(Arrays.toString(minIndex));
		System.out.println(Arrays.toString(count));
		return result;
	}

	public double[][] getDistances(int[] prototypeIndex) {
		int length = dataset.numInstances();
		double[][] distances = new double[length][k];
		for (int i = 0; i < length; i++) {
			for (int j = 0; j < k; j++) {
				distances[i][j] = eudistance(i, prototypeIndex[j]);
			}
		}
		return distances;
	}

	public static void main(String[] args) {
		KMeans kmeans = new KMeans("D:/data/xigua.arff", 3);
		kmeans.clustering();
	}

}// KMeans

  • 代码改进
package cluster;

import java.io.FileReader;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

import weka.core.Instances;

public class KMeans {
	Instances dataset;
	int k;
	int[][] clusters;

	public KMeans(String paraFilename, int numClusters) {
		dataset = null;
		k = numClusters;
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			fileReader.close();
		} catch (Exception e) {
			System.out.println("Cannot read the file: " + paraFilename + "\r\n" + e);
			System.exit(0);
		}
	}

	public int[] getRandomIndices(int length, int num) {
		// The Set container guarantees the elements are distinct.
		Set<Integer> set = new HashSet<>(num);
		// Keep adding until the set holds exactly num elements.
		while (set.size() != num) {
			for (int i = 0; i < num; i++) {
				set.add((int) (Math.random() * length));
			}
		}

		int[] result = new int[num];
		Iterator<Integer> it = set.iterator();
		for (int i = 0; i < num; i++) {
			result[i] = it.next();
		}

		return result;
	}

	public double eudistance(int prototypeIndex, int otherIndex) {
		double result = 0.0;
		for (int i = 0; i < dataset.numAttributes() - 1; i++) {
			double p1 = dataset.instance(prototypeIndex).value(i);
			double p2 = dataset.instance(otherIndex).value(i);
			result += (p1 - p2) * (p1 - p2);
		}
		return Math.sqrt(result);
	}

	public void clustering() {
		// getDistances() picks its own random prototypes, so no index
		// array needs to be computed here.
		double[][] tempDistances = getDistances();
		clusters = selectCluster(tempDistances);
		
		System.out.println(Arrays.deepToString(tempDistances));
	}
	
	public int[][] selectCluster(double[][] distances) {
		int length = dataset.numInstances();
		int[][] result = new int[k][];
		// Assign each instance to its nearest prototype.
		int[] minIndex = new int[length];
		for (int i = 0; i < length; i++) {
			double min = Double.MAX_VALUE;
			for (int j = 0; j < k; j++) {
				if (min > distances[i][j]) { min = distances[i][j]; minIndex[i] = j; }
			}
		}
		// Count cluster sizes, allocate each row, then fill the clusters.
		int[] count = new int[k];
		for (int i = 0; i < length; i++) {
			count[minIndex[i]]++;
		}
		for (int i = 0; i < k; i++) {
			result[i] = new int[count[i]];
		}
		int[] pos = new int[k];
		for (int i = 0; i < length; i++) {
			result[minIndex[i]][pos[minIndex[i]]++] = i;
		}
		return result;
	}
	
	public double[][] getDistances() {
		int length = dataset.numInstances();
		int[] prototypeIndex = getRandomIndices(length, k);
		double[][] distances = new double[length][k];
		for (int i = 0; i < length; i++) {
			for (int j = 0; j < k; j++) {
				distances[i][j] = eudistance(i, prototypeIndex[j]);
			}
		}
		return distances;
	}

	public static void main(String[] args) {
		KMeans kmeans = new KMeans("D:/data/xigua.arff", 3);
		kmeans.clustering();
	}

}// KMeans

5. Some Reflections
	The K-Means principle and code are both fairly simple, but a solid implementation still takes some effort, and there is room for improved algorithms. K-Means is a prototype-based clustering method: the prototype vectors are chosen at random, and k is the number of clusters, so choosing k clusters means randomly picking k samples as prototype vectors.
	This raises a question: for example, how do you pick 3 numbers in the range 1~30 and guarantee they are all different? There is a clever idea: initialize an array of 30 entries holding 1~30 (or 0~29), then repeatedly generate two random positions in the range that decide which two entries of the array to swap. With k = 3, just take the first 3 entries of the shuffled array; they are guaranteed to be distinct. This approach still has drawbacks, which motivates another algorithm, Gaussian mixture clustering, where the prototype vectors are chosen through a probability model rather than at random.
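The shuffle-then-take-first-k idea can be sketched with a Fisher-Yates shuffle, which gives a uniform permutation (the class and method names here are illustrative, not from the original code):

```java
import java.util.Arrays;
import java.util.Random;

public class DistinctPicker {
	// Shuffle 0..n-1 and take the first k entries. The k values are
	// guaranteed distinct because each index appears exactly once.
	public static int[] pickDistinct(int n, int k, Random random) {
		int[] indices = new int[n];
		for (int i = 0; i < n; i++) {
			indices[i] = i;
		}
		// Fisher-Yates shuffle: swap each position with a uniformly
		// chosen position at or before it.
		for (int i = n - 1; i > 0; i--) {
			int j = random.nextInt(i + 1);
			int tmp = indices[i];
			indices[i] = indices[j];
			indices[j] = tmp;
		}
		return Arrays.copyOf(indices, k);
	}
}
```

Unlike the Set-based approach, this runs in a bounded number of steps and never loops waiting for a duplicate-free draw.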
	When computing distances, density and sugar content have different units; have you thought about what it means to add them? That is why we can use the standardized Euclidean distance: when computing the distance between samples, divide each attribute difference by that attribute's standard deviation (also called the root mean square deviation). Take the points (1, 2) and (3, 5): first compute the mean of the first attribute, (1 + 3) / 2 = 2; then (2 - 1)^2 + (2 - 3)^2 = 2, 2 / 2 = 1, and sqrt(1) = 1. This 1 is the standard deviation. After dividing by it, every attribute is measured in units of its own standard deviation, so the terms share a unit and can safely be added.
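The worked example above can be checked in code. This is a small sketch with hypothetical helper names; the standard deviation uses the population formula (dividing by n), matching the arithmetic above:

```java
public class StdEuclidean {
	// Standardized Euclidean distance: divide each coordinate difference
	// by that attribute's standard deviation before summing squares.
	public static double standardizedDistance(double[] x, double[] y, double[] std) {
		double sum = 0.0;
		for (int i = 0; i < x.length; i++) {
			double diff = (x[i] - y[i]) / std[i];
			sum += diff * diff;
		}
		return Math.sqrt(sum);
	}

	// Population standard deviation of one attribute across the samples.
	public static double populationStd(double[] values) {
		double mean = 0.0;
		for (double v : values) {
			mean += v;
		}
		mean /= values.length;
		double var = 0.0;
		for (double v : values) {
			var += (v - mean) * (v - mean);
		}
		return Math.sqrt(var / values.length);
	}
}
```

For the two points above, the first attribute's standard deviation comes out to 1 and the second's to 1.5, so the standardized distance is sqrt(((1-3)/1)^2 + ((2-5)/1.5)^2) = sqrt(8).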
	Updating the mean vectors is the key step of this algorithm.
Please credit the source when reposting: this article is from www.mshxw.com
Original URL: https://www.mshxw.com/it/865090.html