package com.cnic.algorithm.flink.kmeans001;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.FunctionAnnotation;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
public class KMeansData {
/**
 * A mutable point in the plane.
 *
 * <p>The arithmetic helpers ({@link #add(Point)} and {@link #div(long)}) modify
 * this instance in place and return it, so calls can be chained.
 *
 * <p>NOTE(review): the public fields and the no-argument constructor are
 * presumably kept so the framework can handle this type as a plain data
 * object — confirm before tightening access.
 */
public static class Point implements Serializable {

    /** Horizontal coordinate. */
    public double x;

    /** Vertical coordinate. */
    public double y;

    /** Creates a point whose coordinates keep their default value. */
    public Point() {}

    /**
     * Creates a point at the given coordinates.
     *
     * @param x horizontal coordinate
     * @param y vertical coordinate
     */
    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    /**
     * Adds the coordinates of {@code other} to this point, in place.
     *
     * @param other the point whose coordinates are added
     * @return this point, after the update
     */
    public Point add(Point other) {
        this.x = this.x + other.x;
        this.y = this.y + other.y;
        return this;
    }

    /**
     * Divides both coordinates by {@code val}, in place.
     *
     * @param val the divisor
     * @return this point, after the update
     */
    public Point div(long val) {
        this.x = this.x / val;
        this.y = this.y / val;
        return this;
    }

    /**
     * Computes the Euclidean distance between this point and {@code other}.
     *
     * @param other the point to measure against
     * @return the straight-line distance between the two points
     */
    public double euclideanDistance(Point other) {
        final double dx = x - other.x;
        final double dy = y - other.y;
        return Math.sqrt(dx * dx + dy * dy);
    }

    /** Resets both coordinates to zero. */
    public void clear() {
        y = 0.0;
        x = 0.0;
    }

    /** Returns the coordinates as {@code "x y"}, separated by a single space. */
    @Override
    public String toString() {
        return Double.toString(x) + " " + Double.toString(y);
    }
}
/**
 * A cluster center: a {@link Point} tagged with an integer identifier.
 */
public static class Centroid extends Point {

    /** Identifier of the cluster this centroid represents. */
    public int id;

    /** Creates a centroid whose id and coordinates keep their default values. */
    public Centroid() {}

    /**
     * Creates a centroid from an id and explicit coordinates.
     *
     * @param id cluster identifier
     * @param x horizontal coordinate
     * @param y vertical coordinate
     */
    public Centroid(int id, double x, double y) {
        super(x, y);
        this.id = id;
    }

    /**
     * Creates a centroid that copies the coordinates of an existing point.
     *
     * @param id cluster identifier
     * @param p point whose coordinates are copied
     */
    public Centroid(int id, Point p) {
        this(id, p.x, p.y);
    }

    /** Returns the id followed by the point's own text form, separated by a space. */
    @Override
    public String toString() {
        final String coordinates = super.toString();
        return id + " " + coordinates;
    }
}
/**
 * Built-in starting centroids. Each row is {@code {id, x, y}}: an
 * {@code Integer} id followed by two {@code Double} coordinates.
 */
public static final Object[][] CENTROIDS = {
    {1, -31.85, -44.77},
    {2, 35.16, 17.46},
    {3, -5.16, 21.93},
    {4, -24.06, 6.81}
};
/**
 * Built-in sample points. Each row is {@code {x, y}}, both stored as
 * {@code Double}.
 */
public static final Object[][] POINTS = {
    {-14.22, -48.01},
    {-22.78, 37.10},
    {56.18, -42.99},
    {35.04, 50.29},
    {-9.53, -46.26},
    {-34.35, 48.25},
    {55.82, -57.49},
    {21.03, 54.64},
    {-13.63, -42.26},
    {-36.57, 32.63},
    {50.65, -52.40},
    {24.48, 34.04},
    {-2.69, -36.02},
    {-38.80, 36.58},
    {24.00, -53.74},
    {32.41, 24.96},
    {-4.32, -56.92},
    {-22.68, 29.42},
    {59.02, -39.56},
    {24.47, 45.07},
    {5.23, -41.20},
    {-23.00, 38.15},
    {44.55, -51.50},
    {14.62, 59.06},
    {7.41, -56.05},
    {-26.63, 28.97},
    {47.37, -44.72},
    {29.07, 51.06},
    {0.59, -31.89},
    {-39.09, 20.78},
    {42.97, -48.98},
    {34.36, 49.08},
    {-21.91, -49.01},
    {-46.68, 46.04},
    {48.52, -43.67},
    {30.05, 49.25},
    {4.03, -43.56},
    {-37.85, 41.72},
    {38.24, -48.32},
    {20.83, 57.85}
};
/**
 * Builds the default centroid data set from {@link #CENTROIDS}.
 *
 * @param params parsed command-line parameters; not read by this method
 * @param env execution environment used to create the data set
 * @return a data set holding one {@link Centroid} per row of {@link #CENTROIDS}
 */
public static DataSet<Centroid> getDefaultCentroidDataSet(ParameterTool params, ExecutionEnvironment env) {
    List<Centroid> centroidList = new ArrayList<>(CENTROIDS.length);
    // Each row is {id, x, y}; unbox the entries into a typed Centroid.
    for (Object[] centroid : CENTROIDS) {
        centroidList.add(
            new Centroid((Integer) centroid[0], (Double) centroid[1], (Double) centroid[2]));
    }
    return env.fromCollection(centroidList);
}
/**
 * Builds the default point data set from {@link #POINTS}.
 *
 * @param params parsed command-line parameters; not read by this method
 * @param env execution environment used to create the data set
 * @return a data set holding one {@link Point} per row of {@link #POINTS}
 */
public static DataSet<Point> getDefaultPointDataSet(ParameterTool params, ExecutionEnvironment env) {
    List<Point> pointList = new ArrayList<>(POINTS.length);
    // Each row is {x, y}; unbox the entries into a typed Point.
    for (Object[] point : POINTS) {
        pointList.add(new Point((Double) point[0], (Double) point[1]));
    }
    return env.fromCollection(pointList);
}
// 确定数据点最近的集群中心
@FunctionAnnotation.ForwardedFields("* -> 1")
public static final class SelectNearestCenter extends RichMapFunction>{
private Collection centroids;
// 将广播变量中的质心数据集读到集合中
@Override
public void open(Configuration parameters) throws Exception {
this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
}
public Tuple2 map(Point p) throws Exception {
double minDistance = Double.MAX_VALUE;
int closestCentroidId = -1;
// 遍历所有的簇中心
for (Centroid centroid : centroids) {
// 计算点和簇中心的欧式距离
double distance = p.euclideanDistance(centroid);
// 找到距离点最近的簇中心
if (distance < minDistance) {
minDistance = distance;
closestCentroidId = centroid.id;
}
}
// 输出一条新的记录,由簇中心id和point组成
return new Tuple2(closestCentroidId,p);
}
}
@FunctionAnnotation.ForwardedFields("f0;f1")
public static final class CountAppender implements MapFunction, Tuple3> {
@Override
public Tuple3 map(Tuple2 t) {
// 对簇内点进行计数
return new Tuple3(t.f0, t.f1, 1L);
}
}
@FunctionAnnotation.ForwardedFields("0")
public static final class CentroidAccumulator implements ReduceFunction> {
@Override
public Tuple3 reduce(Tuple3 val1, Tuple3 val2) {
// 这一步逻辑很关键,对簇内点坐标累计,然后对簇内元素个数计数。
return new Tuple3(val1.f0, val1.f1.add(val2.f1), val1.f2 + val2.f2);
}
}
@FunctionAnnotation.ForwardedFields("0->id")
public static final class CentroidAverager implements MapFunction, Centroid> {
@Override
public Centroid map(Tuple3 value) {
// 坐标和/簇内点个数作为新的簇中心
return new Centroid(value.f0, value.f1.div(value.f2));
}
}
/**
 * Runs K-Means on the built-in sample data.
 *
 * <p>Recognized command-line parameters:
 * <ul>
 *   <li>{@code --iteration <n>}: number of bulk iterations (default 10)</li>
 *   <li>{@code --output <path>}: write the clustered points as CSV to this
 *       path; without it the result is printed to stdout</li>
 * </ul>
 *
 * @param args command-line arguments, parsed with {@link ParameterTool}
 * @throws Exception if building or executing the job fails
 */
public static void main(String[] args) throws Exception {
    // 1. Parse the command-line parameters.
    final ParameterTool params = ParameterTool.fromArgs(args);

    // 2. Set up the execution environment.
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // 3. Register the parameters as global job parameters.
    env.getConfig().setGlobalJobParameters(params);

    // 4. Load the input data; only the built-in default data is used here.
    DataSet<Point> points = getDefaultPointDataSet(params, env);
    DataSet<Centroid> centroids = getDefaultCentroidDataSet(params, env);

    // 5. Configure the number of bulk iterations for K-Means.
    IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iteration", 10));

    // 6. One K-Means step.
    DataSet<Centroid> newCentroids = points
        // 6.1 Find the nearest centroid for every point.
        .map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids")
        // 6.2 Sum the coordinates and count the points of every cluster.
        .map(new CountAppender())
        .groupBy(0).reduce(new CentroidAccumulator())
        // 6.3 Derive the new centroids from the sums and counts.
        .map(new CentroidAverager());

    // 7. Feed the new centroids back into the next iteration.
    DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

    // 8. Assign every point to its final cluster.
    DataSet<Tuple2<Integer, Point>> clusteredPoints = points
        .map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

    // 9. Emit the result.
    if (params.has("output")) {
        // writeAsCsv takes (path, rowDelimiter, fieldDelimiter): one record
        // per line, fields separated by a single space.
        clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
        // A file sink does not run the job by itself; execute explicitly.
        env.execute("KMeans Example");
    } else {
        System.out.println("Printing result to stdout. Use --output to specify output path.");
        clusteredPoints.print();
    }
}
}