在Mean Shift算法中,最关键的就是计算每个点的偏移均值,然后根据新计算的偏移均值更新点的位置。
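对于给定的 $d$ 维空间 $R^d$ 中的 $n$ 个样本点 $x_i$ ($i = 1, \dots, n$),则对于点 $x$,其Mean Shift向量的基本形式为: $M_h(x) = \frac{1}{k} \sum_{x_i \in S_h(x)} (x_i - x)$,其中 $S_h(x)$ 是以 $x$ 为中心、半径为 $h$ 的邻域,$k$ 为落入该邻域的样本点个数。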
import numpy as np
import matplotlib.pyplot as plt
# Build the demo dataset (requested with n_samples=200, n_features=2, centers=2).
# NOTE(review): `make_blobs` is not imported anywhere in the visible code --
# presumably `from sklearn.datasets import make_blobs` is required; confirm.
X1, y1 = make_blobs(n_samples=200, n_features=2, centers=2)
# Optional quick look at the raw samples:
#plt.scatter(X1[:, 0],X1[:, 1])
def _mean_offset(center, X, r):
    """Return the mean offset from ``center`` to the samples within ``r`` of it.

    Returns ``None`` when no sample lies within the radius, so the caller
    never divides by zero.
    """
    neighbours = X[np.linalg.norm(X - center, axis=1) <= r]
    if len(neighbours) == 0:
        return None
    return np.sum(neighbours - center, axis=0) / len(neighbours)


def meanshift(point, X, r, eps):
    """Follow the flat-kernel mean shift trajectory that starts at ``point``.

    At each step the current position is moved by the mean of the offsets to
    all samples lying within distance ``r`` of it. Iteration stops once the
    length of that mean offset is no longer greater than ``eps``.

    Args:
        point: Starting coordinates (array-like with one entry per feature).
        X: Samples, one row per sample (array-like, 2-D).
        r: Radius of the neighbourhood used to select samples.
        eps: Convergence threshold on the norm of the shift vector.

    Returns:
        A list of the visited positions, beginning with the starting point
        and ending with the converged position. If the starting point has no
        sample within ``r``, the list contains only the starting point.
    """
    # Accept plain lists as well as arrays.
    X = np.asarray(X)
    point = np.asarray(point)

    points = [point]
    shift = _mean_offset(point, X, r)
    # `shift is None` means an empty neighbourhood: there is nothing to move
    # towards, so stop instead of propagating a NaN produced by 0 / 0.
    while shift is not None and np.linalg.norm(shift) > eps:
        point = point + shift
        points.append(point)
        shift = _mean_offset(point, X, r)
    return points
# Run mean shift from the seed (1, 1) with radius 30 and tolerance 0.001;
# every visited position becomes one row of `points`.
points = np.array(meanshift(np.array([1, 1]), X1, 30, 0.001))

# Draw the samples, then mark the final position in red.
plt.figure(figsize=(10, 6))
plt.scatter(X1[:, 0], X1[:, 1])
plt.scatter(points[-1, 0], points[-1, 1], c='r')
# Uncomment to draw the whole trajectory as well:
#plt.plot(points[:, 0], points[:, 1], 'r<--', markersize=8)
plt.show()
分析:
meanshift(point, X, r, eps)
point 起始点坐标
X 数据集
r 邻域(搜索窗口)的半径,只有与当前点距离不超过 r 的样本才参与计算
eps 收敛阈值,当偏移向量的模不大于 eps 时停止迭代
np.linalg.norm(X - point, axis=1) <= r
np.sum(pointNeigh - point, axis=0) / len(pointNeigh)
对邻域内各点相对当前点的偏移量按坐标(x、y)分别求和,再除以邻域内点的个数,得到平均偏移向量
得到半径范围内点对起始点的向量平均和,即为起始点的移动方向和距离
np.linalg.norm(shift) > eps,当移动的距离小于设定值时,那么就不再移动,默认到达了最密集点,否则就重复上面的步骤。
# Improved version that introduces a Gaussian kernel
# (author's note: "not fully understood yet").
def train_mean_shift(points, kenel_bandwidth=2):
    """Shift every sample until none moves by MIN_DISTANCE or more, then group.

    Each pass moves every still-active sample with ``shift_point``. A sample
    is frozen once a single move is shorter than ``MIN_DISTANCE``. The passes
    stop when the largest move in a pass is no longer greater than
    ``MIN_DISTANCE``.

    Args:
        points: All samples, one per row (anything ``np.matrix`` accepts).
        kenel_bandwidth: Forwarded unchanged to ``shift_point``; presumably
            the kernel bandwidth -- the helper is not visible here, confirm.

    Returns:
        A tuple ``(original, shifted, group)``: the input samples as a
        matrix, the matrix of shifted samples, and whatever
        ``group_points(shifted)`` returns.
    """
    # Work on a private float copy. `np.mat` (removed in NumPy 2.0) was an
    # alias of `asmatrix`, which shares memory with ndarray input, so the row
    # updates below would overwrite the caller's data and the `points` handed
    # to `shift_point`. A float dtype also keeps shifted rows from being
    # truncated when the input holds integers.
    mean_shift_points = np.matrix(points, dtype=float)
    max_min_dist = 1
    iteration = 0
    # Number of rows (samples) and columns (features).
    m, n = np.shape(mean_shift_points)
    need_shift = [True] * m
    print(n, m)

    # Compute the mean shift vectors until every move is small enough.
    while max_min_dist > MIN_DISTANCE:
        max_min_dist = 0
        iteration += 1
        print("iter : " + str(iteration))
        for i in range(m):
            # Skip samples that have already stopped moving.
            if not need_shift[i]:
                continue
            p_new_start = mean_shift_points[i]
            p_new = shift_point(p_new_start, points, kenel_bandwidth)
            dist = euclidean_dist(p_new, p_new_start)

            if dist > max_min_dist:
                # Track the largest move seen in this pass.
                max_min_dist = dist
            if dist < MIN_DISTANCE:
                # This sample no longer needs to move.
                need_shift[i] = False

            mean_shift_points[i] = p_new

    # Compute the final groups from the shifted samples.
    group = group_points(mean_shift_points)
    return np.asmatrix(points), mean_shift_points, group
def train_mean_shift(points, kenel_bandwidth=2):
points:所有数据点
kenel_bandwidth:核(高斯核)的带宽,默认值为 2,会原样传给 shift_point
将数据转换为矩阵类型
mean_shift_points = np.mat(points)



