- 输入:特征向量
- 输出:分类
- 在训练集中寻找与当前点最近的k个点,然后根据例如多数表决等规则进行分类
- Lp distance
import numpy as np
# Two sample feature vectors used to demonstrate the Lp distance.
x_1 = np.array([1, 2, 3, 4])
x_2 = np.array([5, 6, 7, 8])
def Lp_distance(p, x_1, x_2):
    """Minkowski (Lp) distance between two vectors.

    Args:
        p: order of the norm, p >= 1 (1 = Manhattan, 2 = Euclidean).
        x_1: first vector (numpy array).
        x_2: second vector, same shape as x_1.

    Returns:
        The scalar distance (sum(|x_1 - x_2|**p)) ** (1/p).
    """
    diff = np.abs(x_1 - x_2)
    return np.power(np.sum(np.power(diff, p)), 1 / p)
# L1 (Manhattan) distance between x_1 and x_2 — evaluates to 16.0.
Lp_distance(1, x_1, x_2)
16.0

The value of K
- when k is small, the neighbourhood is small -> only samples at a short distance are treated as the same class
- a small k is more sensitive to noise -> if a neighbouring sample is noise, the prediction is wrong
- the smaller k is, the more complex the model becomes
- normally, k is chosen by cross-validation
- when building a kd-tree, at depth j the splitting feature index is l = (j mod k) + 1, where k is the feature dimension
class Node:
    """A node of a kd-tree: one data point plus left/right subtrees."""

    def __init__(self, data, lchild=None, rchild=None):
        self.data = data      # the point stored at this node
        self.lchild = lchild  # subtree with smaller values on the split axis
        self.rchild = rchild  # subtree with larger-or-equal values on the split axis


class KD_Tree:
    """kd-tree supporting nearest-neighbour search (Statistical Learning Methods, ch. 3)."""

    def __init__(self):
        self.kd_tree = None

    def create(self, dataset, depth):
        """Recursively build a kd-tree from `dataset`.

        Args:
            dataset: list of k-dimensional points (list of lists).
            depth: current recursion depth; the split axis is depth % k.

        Returns:
            The root Node of the (sub)tree, or None for an empty dataset.
        """
        if len(dataset) == 0:
            return None
        m, n = np.shape(dataset)  # m: number of points, n: dimensionality
        axis = depth % n          # cycle through the coordinate axes
        sorted_dataset = self.sort(dataset, axis)
        mid = m // 2              # the median point becomes this node
        node = Node(sorted_dataset[mid])
        node.lchild = self.create(sorted_dataset[:mid], depth + 1)
        node.rchild = self.create(sorted_dataset[mid + 1:], depth + 1)
        return node

    def sort(self, dataset, axis):
        """Return a copy of `dataset` sorted ascending by coordinate `axis`.

        The original O(n^2) selection sort is replaced by the built-in
        stable sort; the input list is not modified.
        """
        return sorted(dataset, key=lambda point: point[axis])

    def preOrder(self, node):
        """Print the tree in pre-order (root, left, right) for debugging."""
        if node is not None:
            print("tttt->%s" % node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    def search(self, tree, x):
        """Find the point in `tree` nearest to `x` under Euclidean distance.

        Args:
            tree: root Node returned by create().
            x: query point (sequence of the same dimensionality).

        Returns:
            (nearest_point, nearest_distance); (None, None) for an empty tree.
        """
        self.nearest_point = None
        self.nearest_value = None

        def travel(node, depth=0):
            if node is None:
                return
            axis = depth % len(x)
            # Descend to the leaf region that would contain x.
            if x[axis] < node.data[axis]:
                travel(node.lchild, depth + 1)
            else:
                travel(node.rchild, depth + 1)
            # Unwinding: check whether the current node improves the best match.
            d = self.dist(x, node.data)
            if self.nearest_point is None or self.nearest_value > d:
                self.nearest_point = node.data
                self.nearest_value = d
            # If the splitting hyperplane is closer than the best distance,
            # the other subtree may still contain a closer point — visit it.
            if abs(x[axis] - node.data[axis]) <= self.nearest_value:
                if x[axis] < node.data[axis]:
                    travel(node.rchild, depth + 1)
                else:
                    travel(node.lchild, depth + 1)

        travel(tree)
        return self.nearest_point, self.nearest_value

    def dist(self, x_1, x_2):
        """Euclidean (L2) distance between two points."""
        diff = np.array(x_1) - np.array(x_2)
        return np.sqrt(np.sum(diff ** 2))
# Demo: build a kd-tree from the textbook example and query the
# nearest neighbour of x = [5, 3]; expected result is [5, 4] at distance 1.0.
dataset = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
x = [5, 3]
kd_tree = KD_Tree()
head = kd_tree.create(dataset, 0)
# kd_tree.preOrder(head)
nearest_point, nearest_value = kd_tree.search(head, x)
print(nearest_point)
print(nearest_value)
[5, 4]
1.0
从一位大神的复现看到《统计学习方法》的KD-tree python实现,本篇为记录贴。



