请参考
1. k-d tree算法的研究
2. Python手撸机器学习系列(十一):KNN之kd树实现
完整代码: https://github.com/nnzzll/NaiveKDTree
/// A 3-D point that also carries its position in the input data set.
///
/// `index` records which input point this is, so the KD-tree can store and
/// report indices instead of copying points around. It defaults to -1.
template <typename T>
struct Point3D
{
    T x, y, z;
    int index; // index of this point in the input data; used when building the KDTree

    Point3D() : x(0), y(0), z(0), index(-1) {}
    Point3D(T a, T b, T c) : x(a), y(b), z(c), index(-1) {}
    Point3D(T a, T b, T c, int idx) : x(a), y(b), z(c), index(idx) {}

    /// Coordinate access by axis: 0 -> x, 1 -> y, any other value -> z.
    inline T &operator[](int i) { return i == 0 ? x : i == 1 ? y : z; }
    /// Read-only coordinate access (same axis mapping) for const points.
    inline const T &operator[](int i) const { return i == 0 ? x : i == 1 ? y : z; }
};

/// A 2-D point that also carries its position in the input data set.
template <typename T>
struct Point2D
{
    T x, y;
    int index; // index of this point in the input data; -1 when not set

    Point2D() : x(0), y(0), index(-1) {}
    Point2D(T a, T b) : x(a), y(b), index(-1) {}
    Point2D(T a, T b, int idx) : x(a), y(b), index(idx) {}

    /// Coordinate access by axis: 0 -> x, any other value -> y.
    inline T &operator[](int i) { return i == 0 ? x : y; }
    /// Read-only coordinate access (same axis mapping) for const points.
    inline const T &operator[](int i) const { return i == 0 ? x : y; }
};

// ---- KDTree node structure ----
/// One node of the KD-tree. The node stores only the index of its point;
/// the point itself lives in the tree's data vector.
struct KDNode
{
    int index;     // index of the point stored at this node
    int axis;      // dimension this node splits on (build() uses -1 for single-point nodes)
    KDNode *left;  // subtree built from the points before the median
    KDNode *right; // subtree built from the points after the median

    KDNode(int index, int axis, KDNode *left = nullptr, KDNode *right = nullptr)
        : index(index), axis(axis), left(left), right(right)
    {
    }
};
KDTree的结构
templateKDTree的构造函数class KDTree { private: int ndim; KDNode *root; KDNode *build(std::vector &); std::set visited; // 用于搜索时回溯 std::stack queueNode; // 记录搜索路径 std::vector m_data; void release(KDNode *); void printNode(KDNode *); int chooseAxis(std::vector &); void dfs(KDNode *, T); // 点与点之间的距离 inline double distanceT(KDNode *, T); inline double distanceT(int, T); // 点与超平面的距离 inline double distanceP(KDNode *, T); // 检查父节点超平面是否在超球体中 inline bool checkParent(KDNode *, T, double); public: KDTree(std::vector &, int); ~KDTree(); void Print(); int findNearestPoint(T); };
template最近邻搜索KDTree ::KDTree(std::vector &data, int dim) { ndim = dim; m_data = data; // 拷贝一份数据 root = build(data); // 递归地构造二叉树 } template KDNode *KDTree ::build(std::vector &data) { if (data.empty()) return nullptr; std::vector temp = data; int mid_index = static_cast (data.size() / 2); // 二分的索引 int axis = data.size() > 1 ? chooseAxis(temp) : -1; // 根据每个维度的方差大小选择二分的维度,叶子结点无法二分,默认为-1 std::sort(temp.begin(), temp.end(), [axis](T a, T b) { return a[axis] < b[axis]; }); std::vector leftData, rightData; leftData.assign(temp.begin(), temp.begin() + mid_index); rightData.assign(temp.begin() + mid_index + 1, temp.end()); KDNode *leftNode = build(leftData); KDNode *rightNode = build(rightData); KDNode *rootNode; rootNode = new KDNode(temp[mid_index].index, axis, leftNode, rightNode); return rootNode; }
参考[1]
template测试int KDTree ::findNearestPoint(T pt) { while (!queueNode.empty()) queueNode.pop(); double min_dist = DBL_MAX; int resNodeIdx = -1; dfs(root, pt); while (!queueNode.empty()) { KDNode *curNode = queueNode.top(); queueNode.pop(); double dist = distanceT(curNode, pt); if (dist < min_dist) { min_dist = dist; resNodeIdx = curNode->index; } if (!queueNode.empty()) { KDNode *parentNode = queueNode.top(); int parentAxis = parentNode->axis; int parentIndex = parentNode->index; if (checkParent(parentNode, pt, min_dist)) { if (m_data[curNode->index][parentNode->axis] < m_data[parentNode->index][parentNode->axis]) dfs(parentNode->right, pt); else dfs(parentNode->left, pt); } } } return resNodeIdx; } template void KDTree ::dfs(KDNode *node, T pt) { if (node) { if (visited.find(node->index) != visited.end()) return; queueNode.push(node); visited.insert(node->index); if (pt[node->axis] < m_data[node->index][node->axis] && node->left) return dfs(node->left, pt); else if (pt[node->axis] >= m_data[node->index][node->axis] && node->right) return dfs(node->right, pt); } }
与VTK官方样例ClosestNPoints进行验证对比。
点云个数在1000以内时,性能与VTK的KdTree大致相当。(注:下面的计时只有毫秒精度,且两边都只计了一次建树加一次查询,结果仅供粗略参考。)
int main()
{
int N = 500;
// Create some random points
vtkNew pointSource;
pointSource->SetNumberOfPoints(N);
pointSource->Update();
std::vector> datasets;
vtkPoints *randPts = pointSource->GetOutput()->GetPoints();
for (vtkIdType i = 0; i < N; i++)
{
double pts[3];
randPts->GetPoint(i, pts);
datasets.push_back(Point3D(pts[0], pts[1], pts[2], i));
// std::cout << pts[0] << "," << pts[1] << "," << pts[2] << std::endl;
}
auto t1 = std::chrono::duration_cast(
std::chrono::system_clock::now().time_since_epoch())
.count();
// Create the tree
vtkNew pointTree;
pointTree->SetDataSet(pointSource->GetOutput());
pointTree->BuildLocator();
// Find the k closest points to (0,0,0)
unsigned int k = 1;
vtkNew testSource;
testSource->SetNumberOfPoints(1);
testSource->Update();
double testPoint[3];
testSource->GetOutput()->GetPoints()->GetPoint(0, testPoint);
vtkNew result;
std::cout << "Test Point: " << testPoint[0] << "," << testPoint[1] << "," << testPoint[2] << std::endl;
pointTree->FindClosestNPoints(k, testPoint, result);
for (vtkIdType i = 0; i < k; i++)
{
vtkIdType point_ind = result->GetId(i);
double p[3];
pointSource->GetOutput()->GetPoint(point_ind, p);
std::cout << "Closest point " << i << ": Point " << point_ind << ": ("
<< p[0] << ", " << p[1] << ", " << p[2] << ")" << std::endl;
}
auto t2 = std::chrono::duration_cast(
std::chrono::system_clock::now().time_since_epoch())
.count();
// Should return:
// Closest point 0: Point 2: (-0.136162, -0.0276359, 0.0369441)
// std::vector> datasets = {Point2D(7, 2, 0),
// Point2D(5, 4, 1),
// Point2D(9, 6, 2),
// Point2D(2, 3, 3),
// Point2D(4, 7, 4),
// Point2D(8, 1, 5)};
KDTree> tree(datasets, 3);
// tree.Print();
std::cout << tree.findNearestPoint(Point3D(testPoint[0], testPoint[1], testPoint[2])) << std::endl;
auto t3 = std::chrono::duration_cast(
std::chrono::system_clock::now().time_since_epoch())
.count();
std::cout << "VTK Time:" << t2 - t1 << " ms" << std::endl;
std::cout << "MY Time:" << t3 - t2 << " ms" << std::endl;
return EXIT_SUCCESS;
}
Test Point: 0.117163,-0.205549,0.352397
Closest point 0: Point 474: (0.12327, -0.22358, 0.322906)
474
VTK Time:4 ms
MY Time:2 ms



