栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Python实现简单K-Means分类

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Python实现简单K-Means分类

通过Python写了个简单的K-Means分类

具体方法其实很简单:

  1. 生成几类随机数据点points
  2. 随机生成K个中心点centers
  3. 对每个点point求取距离最近的中心点center 即分类
    • 对每个分类集中的数据点求取平均点作为新的中心点坐标
    • 如果所有新的中心点 和 旧中心点 的距离都小于一定阈值 说明分类完成;否则迭代
import matplotlib.pyplot as plt
import numpy as np
import random
from icecream import ic
from collections import defaultdict
from matplotlib.colors import base_COLORS

def random_centers(k, points):
    for i in range(k):
        #在原本的可能坐标中随机生成k个中心点
        yield random.choice(points[:, 0]), random.choice(points[:, 1])

def mean(points):
    #all_x,all_y都是列表
    all_x, all_y = [x for x, y in points], [y for x, y in points]
    return np.mean(all_x), np.mean(all_y)


def distance(p1, p2):
    #求取两点之间的距离
    x1, y1 = p1
    x2, y2 = p2
    return np.sqrt((x1 - x2) ** 2 + (y1 - y2)**2)

def draw_points(centers,centers_neighbor,colors):
    #遍历每个中心点
    for i, c in enumerate(centers):
        #获取该中心点 所涵盖的point集合
        _points = centers_neighbor[c]
        all_x, all_y = [x for x, y in _points], [y for x, y in _points]
        #将对应点绘制颜色
        plt.scatter(all_x, all_y, c=colors[i])
    plt.show()

def kmeans(k, points, centers=None):
    #获取一个代表颜色信息值的列表
    colors = list(base_COLORS.values())
    #如果没有生成centers,则随机生成一个
    if not centers:
        centers = list(random_centers(k=k, points=points))
    #方便调试
    ic(centers)
    for i, c in enumerate(centers):#enumerate() 将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列
        plt.scatter([c[0]], [c[1]], s=90, marker='*', c=colors[i])#绘制散点图

    plt.scatter(*zip(*points), c='black')
    #defaultdict的作用是在于,当字典里的key不存在但被查找时,返回的不是keyError而是一个默认值  set对应set( ),即没有key时返回一个空集合
    centers_neighbor = defaultdict(set)

    for p in points:
        #min函数返回的是一个 中心点坐标
        closet_c = min(centers, key=lambda c: distance(p, c))
        #将points加入最近的中心点集合
        centers_neighbor[closet_c].add(tuple(p))

    #ic(centers_neighbor)

    draw_points(centers,centers_neighbor,colors)

    new_centers = []

    for c in centers_neighbor:
        #对每个中心点所包含的所有点求其平均值,作为新的中心点
        new_c = mean(centers_neighbor[c])
        new_centers.append(new_c)

    threshold = 0.1
    distances_old_and_new = [distance(c_old, c_new) for c_old, c_new in zip(centers, new_centers)]
    #ic(distances_old_and_new)
    if all(c < threshold for c in distances_old_and_new):
        return centers_neighbor
    else:
        kmeans(k, points, new_centers)

if __name__ == '__main__':
    #随机生成四组数据
    points0 = np.random.normal(loc=1, size=(100,2))
    points1 = np.random.normal(loc=2, size=(100, 2))
    points2 = np.random.normal(loc=4, size=(100, 2))
    points3 = np.random.normal(loc=5, size=(100, 2))

    points = np.concatenate([points0, points1, points2, points3])

    kmeans(3,points=points,centers=None)

效果图:

第一次迭代:

第二次迭代:

第三次迭代:

第四次迭代:

第五次迭代:

分类完成!

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/303150.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号