栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > C/C++/C#

KDTree的C++实现

C/C++/C# 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

KDTree的C++实现

KDTree原理:

请参考
1. k-d tree算法的研究
2. Python手撸机器学习系列(十一):KNN之kd树实现
完整代码: https://github.com/nnzzll/NaiveKDTree

C++实现

点的结构
/// A 3-D point that also carries its position in the source dataset.
///
/// @tparam T  Scalar coordinate type; must be constructible from 0.
template <typename T>
struct Point3D
{
    T x, y, z;
    int index; // index of this point in the dataset (-1 when unset); stored on the point to make building the KDTree easier

    Point3D() : x(0), y(0), z(0), index(-1) {}
    Point3D(T a, T b, T c) : x(a), y(b), z(c), index(-1) {}
    Point3D(T a, T b, T c, int idx) : x(a), y(b), z(c), index(idx) {}

    /// Coordinate access by axis: 0 -> x, 1 -> y, any other value -> z.
    T &operator[](int i) { return i == 0 ? x : (i == 1 ? y : z); }

    /// Read-only coordinate access so that const points can be indexed too.
    const T &operator[](int i) const { return i == 0 ? x : (i == 1 ? y : z); }
};

/// A 2-D point that also carries its position in the source dataset.
///
/// @tparam T  Scalar coordinate type; must be constructible from 0.
template <typename T>
struct Point2D
{
    T x, y;
    int index; // index of this point in the dataset (-1 when unset)

    Point2D() : x(0), y(0), index(-1) {}
    Point2D(T a, T b) : x(a), y(b), index(-1) {}
    Point2D(T a, T b, int idx) : x(a), y(b), index(idx) {}

    /// Coordinate access by axis: 0 -> x, any other value -> y.
    T &operator[](int i) { return i == 0 ? x : y; }

    /// Read-only coordinate access so that const points can be indexed too.
    const T &operator[](int i) const { return i == 0 ? x : y; }
};
KDTree结点的结构
// One node of the KD-tree. It does not hold a point itself, only the index of
// the point it represents, the dimension it splits on, and its two children.
struct KDNode
{
	int index;     // index of the point stored at this node
	int axis;      // dimension along which this node splits the data
	KDNode *left;  // left child, or nullptr
	KDNode *right; // right child, or nullptr

	// Children default to nullptr, so KDNode(index, axis) creates a leaf.
	KDNode(int index, int axis, KDNode *left = nullptr, KDNode *right = nullptr)
		: index(index), axis(axis), left(left), right(right)
	{
	}
};
KDTree的结构
template 
class KDTree
{
private:
	int ndim;
	KDNode *root;
	KDNode *build(std::vector &);
	std::set visited; // 用于搜索时回溯
	std::stack queueNode; // 记录搜索路径
	std::vector m_data;

	void release(KDNode *);
	void printNode(KDNode *);
	int chooseAxis(std::vector &);
	void dfs(KDNode *, T);
	// 点与点之间的距离
	inline double distanceT(KDNode *, T);
	inline double distanceT(int, T);
	// 点与超平面的距离
	inline double distanceP(KDNode *, T);
	// 检查父节点超平面是否在超球体中
	inline bool checkParent(KDNode *, T, double);

public:
	KDTree(std::vector &, int);
	~KDTree();
	void Print();
	int findNearestPoint(T);
};
KDTree的构造函数
template 
KDTree::KDTree(std::vector &data, int dim)
{
	ndim = dim;
	m_data = data; // 拷贝一份数据
	root = build(data); // 递归地构造二叉树
}

template 
KDNode *KDTree::build(std::vector &data)
{
	if (data.empty())
		return nullptr;
	std::vector temp = data;
	int mid_index = static_cast(data.size() / 2); // 二分的索引
	int axis = data.size() > 1 ? chooseAxis(temp) : -1; // 根据每个维度的方差大小选择二分的维度,叶子结点无法二分,默认为-1
	std::sort(temp.begin(), temp.end(), [axis](T a, T b)
			  { return a[axis] < b[axis]; });
			  
	std::vector leftData, rightData;
	leftData.assign(temp.begin(), temp.begin() + mid_index);
	rightData.assign(temp.begin() + mid_index + 1, temp.end());
	
	KDNode *leftNode = build(leftData);
	KDNode *rightNode = build(rightData);
	KDNode *rootNode;
	rootNode = new KDNode(temp[mid_index].index, axis, leftNode, rightNode);
	return rootNode;
}
最近邻搜索

参考[1]

template 
int KDTree::findNearestPoint(T pt)
{
	while (!queueNode.empty())
		queueNode.pop();
	double min_dist = DBL_MAX;
	int resNodeIdx = -1;
	dfs(root, pt);
	while (!queueNode.empty())
	{
		KDNode *curNode = queueNode.top();
		queueNode.pop();
		double dist = distanceT(curNode, pt);
		if (dist < min_dist)
		{
			min_dist = dist;
			resNodeIdx = curNode->index;
		}

		if (!queueNode.empty())
		{
			KDNode *parentNode = queueNode.top();
			int parentAxis = parentNode->axis;
			int parentIndex = parentNode->index;
			if (checkParent(parentNode, pt, min_dist))
			{
				if (m_data[curNode->index][parentNode->axis] < m_data[parentNode->index][parentNode->axis])
					dfs(parentNode->right, pt);
				else
					dfs(parentNode->left, pt);
			}
		}
	}
	return resNodeIdx;
}

template 
void KDTree::dfs(KDNode *node, T pt)
{
	if (node)
	{
		if (visited.find(node->index) != visited.end())
			return;
		queueNode.push(node);
		visited.insert(node->index);
		if (pt[node->axis] < m_data[node->index][node->axis] && node->left)
			return dfs(node->left, pt);
		else if (pt[node->axis] >= m_data[node->index][node->axis] && node->right)
			return dfs(node->right, pt);
	}
}
测试

与VTK官方样例ClosestNPoints进行验证对比。
当点云规模在1000个点以内时,本实现的性能与VTK的KdTree大致相当。

int main()
{
    int N = 500;
    // Create some random points
    vtkNew pointSource;
    pointSource->SetNumberOfPoints(N);
    pointSource->Update();

    std::vector> datasets;
    vtkPoints *randPts = pointSource->GetOutput()->GetPoints();
    for (vtkIdType i = 0; i < N; i++)
    {
        double pts[3];
        randPts->GetPoint(i, pts);
        datasets.push_back(Point3D(pts[0], pts[1], pts[2], i));
        // std::cout << pts[0] << "," << pts[1] << "," << pts[2] << std::endl;
    }

    auto t1 = std::chrono::duration_cast(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();
    
    // Create the tree
    vtkNew pointTree;
    pointTree->SetDataSet(pointSource->GetOutput());
    pointTree->BuildLocator();

    // Find the k closest points to (0,0,0)
    unsigned int k = 1;
    vtkNew testSource;
    testSource->SetNumberOfPoints(1);
    testSource->Update();
    double testPoint[3];
    testSource->GetOutput()->GetPoints()->GetPoint(0, testPoint);
    vtkNew result;
    std::cout << "Test Point: " << testPoint[0] << "," << testPoint[1] << "," << testPoint[2] << std::endl;

    pointTree->FindClosestNPoints(k, testPoint, result);

    for (vtkIdType i = 0; i < k; i++)
    {
        vtkIdType point_ind = result->GetId(i);
        double p[3];
        pointSource->GetOutput()->GetPoint(point_ind, p);
        std::cout << "Closest point " << i << ": Point " << point_ind << ": ("
                  << p[0] << ", " << p[1] << ", " << p[2] << ")" << std::endl;
    }
    auto t2 = std::chrono::duration_cast(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();

    // Should return:
    // Closest point 0: Point 2: (-0.136162, -0.0276359, 0.0369441)

    // std::vector> datasets = {Point2D(7, 2, 0),
    //                                       Point2D(5, 4, 1),
    //                                       Point2D(9, 6, 2),
    //                                       Point2D(2, 3, 3),
    //                                       Point2D(4, 7, 4),
    //                                       Point2D(8, 1, 5)};
    KDTree> tree(datasets, 3);
    // tree.Print();
    std::cout << tree.findNearestPoint(Point3D(testPoint[0], testPoint[1], testPoint[2])) << std::endl;
    auto t3 = std::chrono::duration_cast(
                  std::chrono::system_clock::now().time_since_epoch())
                  .count();
    std::cout << "VTK Time:" << t2 - t1 << " ms" << std::endl;
    std::cout << "MY Time:" << t3 - t2 << " ms" << std::endl;
    return EXIT_SUCCESS;
}
Test Point: 0.117163,-0.205549,0.352397
Closest point 0: Point 474: (0.12327, -0.22358, 0.322906)
474
VTK Time:4 ms
MY Time:2 ms
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/879280.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号