栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

聚类--KMeans算法

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

聚类--KMeans算法

算法流程

从数据集中随机选取k个聚类样本作为初始的聚类中心,然后计算数据集中每个样本到这k个聚类中心的距离(一般为欧氏距离),选取距离最小的聚类中心所对应的类别作为该样本点的类别;将所有样本点归类后,重新计算每个类别的聚类中心(取每类别样本集的均值)。重复上述过程,直到聚类中心不再更新或者达到阈值(如最大迭代次数)。

算法缺点
    k-means是局部最优的,容易受到初始聚类中心(即质心)的影响,造成次优的聚类结果;k值的选取也会影响聚类结果,最优聚类的k值应与样本数据本身的结构信息相吻合,而这种结构信息是很难去掌握,因此选取最优k值是非常困难的。
python代码实现

假设对如下数据进行聚类,首先可视化数据:

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
%matplotlib notebook
path='D:codepythondatabaseex7data2'
data=sio.loadmat(path)
X=data['X']
plt.scatter(X[:,0],X[:,1])


定义模型并训练

class KMeans:
    #初始化模型参数
    def __init__(self,k,max_iter=50):
        self.k = k  #聚类簇数
        self.max_iter = max_iter #最大迭代次数
        self.all_centroids = []  #存放迭代的聚类中心

    #初始化聚类中心,从样本中随机选取
    def initCentroids(self,X):
        m,n = X.shape
        centroids = np.zeros((self.k,n))
        index = np.random.randint(0,m,self.k)
        for i in range(len(centroids)):
            centroids[i] = X[index[i]]
        return centroids

    #寻找每个样本点的最近簇
    def findClosestCentroids(self,X,centroids):
        m = X.shape[0]
        label = np.zeros(m)
        for i in range(m):
            min_dist = np.inf
            for j in range(len(centroids)):
                dist = np.sum((X[i]-centroids[j])**2)
                if min_dist > dist:
                    min_dist = dist
                    label[i] = j
        return label

    #计算聚类中心
    def getCentroids(self,X,label):
        centroids = np.zeros((self.k,X.shape[1]))
        for i in range(self.k):
            centroids[i] = X[label==i].mean(axis=0)
        return centroids

    #训练模型
    def fit(self,X):
        #centroids = self.initCentroids(X) #初始化聚类中心
        centroids = np.array([[3,3],[6,2],[8,5]])#以该聚类中心展示效果
        for i in range(self.max_iter):
            self.all_centroids.append(centroids)
            label = self.findClosestCentroids(X,centroids)  #每个样本点所属类别
            centroids = self.getCentroids(X,label)  #更新聚类中心
        self.label = self.findClosestCentroids(X,centroids)

    #绘制KMeans算法结果
    def plotKMeans(self,X):
        x, y = [], []  
        #绘制聚类结果
        plt.scatter(X[:, 0], X[:, 1], c=self.label, cmap='rainbow')
        for i in range(self.max_iter):
            x.append(self.all_centroids[i][:,0])
            y.append(self.all_centroids[i][:,1])
        #绘制聚类中心迭代路径
        plt.plot(x,y,marker='*')
        plt.show()

    #计算模型平方误差
    def squareError(self,X):
        error = 0
        final_centroids = self.all_centroids[-1]
        for i in range(X.shape[0]):
            j = int(self.label[i])
            error += np.sum((X[i]-final_centroids[j])**2)
        return error
    
model = KMeans(3,10)
model.fit(X)
model.plotKMeans(X)

利用肘部法则选取k的值,选k=3。

error = []
for k in range(1,9):
    model = KMeans(k,10)
    model.fit(X)
    error.append(model.squareError(X))
plt.xlabel('k')
plt.ylabel('error')
plt.plot(range(1,9),error)
D:apythonanaconda.installlibsite-packagesipykernel_launcher.py:35: RuntimeWarning: Mean of empty slice.
D:apythonanaconda.installlibsite-packagesnumpycore_methods.py:73: RuntimeWarning: invalid value encountered in true_divide
  ret, rcount, out=ret, casting='unsafe', subok=False)




[]
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/740471.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号