栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

手写KMeans(python)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

手写KMeans(python)

KMeans

一种聚类算法,是无监督学习,即没有类标签。
所用数据集就是几个点,点的分布和算的过程 见文末参考博客。

思想:
1. 数据
2. 随机选k个中心点
循环
3. 遍历所有点,算该点 和k点的距离,属于最小的 一类
4. 重新选中心点。新中心点 为该类所有点平均值
5. 结束条件:新中心和 旧中心差值 《 阈值(自己设置)
代码:
import random
def createDataSet():
    dataSet = [[1,4],
            [1,5],
            [2,4],
            [2,5],
            [2,6],
            [4,1],
            [4,2],
            [5,1],
            [5,2],
            [6,2]]
    return dataSet

# 两点间欧氏距离, 参数形式如:pointA = [x, y]
def getDistance(pointA, pointB):
    res = 0
    for i in range(len(pointA)):
        res += (pointA[i] - pointB[i]) ** 2
    print(f'{pointA}和{pointB}两点距离', res ** 0.5)
    return res ** 0.5

# 第一次随机选k个点为 k类中心点
def randomChoice(dataSet, K):
    centerPoint = []# 中心点 
    lastRanIndex = -1
    rList = []
    for i in range(K): # 选k个随机数
        randomIndex = random.randint(0, len(dataSet)-1)
        print('randomIndex =',randomIndex)
        if(randomIndex == lastRanIndex): # 两次随机数取值相同
            i -= 1
            continue
        lastRanIndex = randomIndex
        rList.append(randomIndex)
    rList = sorted(rList) # 下标递增 好一点
    #print(rList)
    for r in rList:
        centerPoint.append(dataSet[r])
    return centerPoint

def KMeans(dataSet):
    typeList = []
    [typeList.append(0) for i in range(len(dataSet))]
    # 默认K为2,两类
    K = 2
    centerPoint = randomChoice(dataSet, K) # 随机选K个中心点
    print('初次随机中心点 =', centerPoint)
    
    flag = 1
    while(flag == 1):
        # 遍历所有点,算该点 和k点的距离,属于最小的 一类
        for i in range(len(dataSet)):# 
            distList = []
            for k in range(K):
                dist = getDistance(dataSet[i], centerPoint[k]) # 第i点和第k中心点距离
                #print('两点距离 =', dist)
                distList.append(dist)
            print(distList)
            for j in range(K):
                if distList[j] == min(distList): # 找到最小的吗,就是类别
                    typeList[i] = j + 1 # j+1为类型
            print(typeList)
            
        # 计算新中心点
        newCenterPoint = renewCenterPoint(typeList, dataSet, K)
        print('新中心点 =', newCenterPoint)
        cntK = 0                # 统计k个中心点合格的个数
        for k in range(K):
            distChange = getDistance(centerPoint[k], newCenterPoint[k]) # 新旧中心点欧式距离
            print(f'第{k}个中心移动了 =', distChange)
            if distChange < 1:  # 中心点改变很小,不用再改变的中心点 +1
                cntK += 1
            if cntK == K:       #k个中心点全部合格才能结束
                flag = 0
                break
        centerPoint = newCenterPoint # 别忘了,更新中心点 
    return typeList

# 重新计算中心点
def renewCenterPoint(typeList, dataSet, K):
    centerPoint = []
    for k in range(K):
        X = Y = cnt = 0
        for j in range(len(typeList)):
            if typeList[j] == k + 1:
                X += dataSet[j][0]
                Y += dataSet[j][1]
                cnt += 1
        print(f'X和值{X}, Y和值{Y}, 该类有{cnt}个')
        centerPoint.append([X/cnt, Y/cnt])
    #print(centerPoint)
    return centerPoint
    
if __name__ == '__main__':
    dataSet = createDataSet()
    print(dataSet)
    typeList = KMeans(dataSet)# 最后结果
    print(typeList)
结果
[[1, 4], [1, 5], [2, 4], [2, 5], [2, 6], [4, 1], [4, 2], [5, 1], [5, 2], [6, 2]]
randomIndex = 0
randomIndex = 7
初次随机中心点 = [[1, 4], [5, 1]]
[1, 4]和[1, 4]两点距离 0.0
[1, 4]和[5, 1]两点距离 5.0
[0.0, 5.0]
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
......
新中心点 = [[1.6, 4.8], [4.8, 1.6]]
[1, 4]和[1.6, 4.8]两点距离 0.9999999999999999
第0个中心移动了 = 0.9999999999999999
[5, 1]和[4.8, 1.6]两点距离 0.632455532033676
第1个中心移动了 = 0.632455532033676
[1, 1, 1, 1, 1, 2, 2, 2, 2, 2]

参考文章

后序可能用plt画kmeans,更直观一点

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/844465.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号