栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

均值漂移聚类算法

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

均值漂移聚类算法

不调用包实现

在Mean Shift算法中,最关键的就是计算每个点的偏移均值,然后根据新计算的偏移均值更新点的位置。
对于给定的维空间中的个样本点,则对于点,其Mean Shift向量的基本形式为:

基础版本:
import numpy as np
import matplotlib.pyplot as plt

X1, y1 = make_blobs(n_samples=200, n_features=2, centers=2)
#plt.scatter(X1[:, 0],X1[:, 1])

def meanshift(point, X, r, eps):
    pointNeigh = X[np.linalg.norm(X - point, axis=1) <= r]
    shift = np.sum(pointNeigh - point, axis=0) / len(pointNeigh)

    points = [point]
    while np.linalg.norm(shift) > eps:
        point = point + shift
        pointNeigh = X[np.linalg.norm(X - point, axis=1) <= r]
        shift = np.sum(pointNeigh - point, axis=0) / len(pointNeigh)
        points.append(point)
    return points

points = meanshift(np.array([1, 1]), X1, 30, 0.001)
points = np.array(points)
plt.figure(figsize=(10, 6))
plt.scatter(X1[:, 0], X1[:, 1])
plt.scatter(points[-1][0],points[-1][1],c='r')
#plt.plot(points[:, 0], points[:, 1], 'r<--', markersize=8)
plt.show()

分析:
meanshift(point, X, r, eps)
point 起始点坐标
X 数据集
r 质点的半径
eps

1.统计出所有在质点半径内的点

np.linalg.norm(X - point, axis=1) <= r

np.sum(a, axis=0) ------->列求和

np.sum(pointNeigh - point, axis=0) / len(pointNeigh)
求x,y的和,再除以个数求均值
得到半径范围内点对起始点的向量平均和,即为起始点的移动方向和距离

判断否到达最密集点

np.linalg.norm(shift) > eps,当移动的距离小于设定值时,那么就不再移动,默认到达了最密集点,否则就重复上面的步骤。

提升版本,引入高斯核,还没有看懂
def train_mean_shift(points, kenel_bandwidth=2):
    #shift_points = np.array(points)
    # 转换为矩阵
    mean_shift_points = np.mat(points)
    max_min_dist = 1
    iter = 0
    # 获取矩阵的行列值
    m, n = np.shape(mean_shift_points)
    need_shift = [True] * m
    print(n,m)
    #cal the mean shift vector
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iter += 1
        print ("iter : " + str(iter))
        for i in range(0, m):
            #判断每一个样本点是否需要计算偏置均值
            if not need_shift[i]:
                continue
            p_new = mean_shift_points[i]
            p_new_start = p_new
            p_new = shift_point(p_new, points, kenel_bandwidth)
            dist = euclidean_dist(p_new, p_new_start)
            if dist > max_min_dist:#record the max in all points
                max_min_dist = dist
            if dist < MIN_DISTANCE:#no need to move
                need_shift[i] = False

            mean_shift_points[i] = p_new
    #计算最终的group
    group = group_points(mean_shift_points)

    return np.mat(points), mean_shift_points, group

def train_mean_shift(points, kenel_bandwidth=2):
points:所有数据点
kenel_bandwidth:
将数据转换为矩阵类型
mean_shift_points = np.mat(points)

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/883215.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号