栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

用Python从零复现A星寻路算法 | 手撕代码#1

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

用Python从零复现A星寻路算法 | 手撕代码#1

用Python从零复现A星寻路算法 | 手撕代码#1 Ⅰ.实现目标

关于A星算法的代码实现在各大厂人工智能岗位面试中均出现过,需要面试者有较高的手撕代码能力。

这篇文章即在python中从零复现A星寻路算法,并利用matplot将计算所得路线可视化画出,适合新手练习代码能力,也适合准备面试的朋友复习。

本文不注重讲解A星算法的具体内容,着重于通过代码的实现!

*参考资料:https://www.laurentluce.com/posts/solving-mazes-using-python-simple-recursivity-and-a-search/


Ⅱ.开始代码寿司

为了简化,这里不考虑路途权重问题,每个方块只考虑是否可达(即是否为墙)。

我们这里定义两个类,一个为cell,用于记录每个cell的信息,如坐标、是否可达及gfh等信息(若需要可以记录权重);另一个为AStar类,用于实现算法。

一、Cell类

很简单,第一个cell类定义如下:

class cell(object):
    def __init__(self, x, y, reachable:bool) -> None:
        self.x = x
        self.y = y
        self.isReachable = reachable
        self.g = 0
        self.h = 0
        self.f = self.g + self.h
        self.parent = None
        self.weight = 0
        
    def __lt__(self, other):
        return self.f < other.f

其中weights作为可扩展项,此算法测试中不考虑每个cell权重大小来更新,只考虑该cell是否可达来更新。

二、AStar类

第二个AStar类需要实现的功能很多,我们先一一列出来,然后再逐一讲解填充。

class AStar(object):
    def __init__(self) -> None:
        '''
        初始化一些数据结构:堆、集合帮助实现算法
        '''
        pass
    
    def init_grid(self, width, height, cells, walls):
        '''
        初始化地图,导入walls
        '''
        pass
    
    def get_cell(self, x, y):
        '''
        因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell
        '''
        pass
    
    def caculate_one_way(self, start, end):
        '''
        在地图确定不变的情况下,每次传入不同的起点和终点,计算返回路径
        '''
        pass
    
    def caculate_heuristic(self, cell):
        '''
        计算启发式距离h值
        '''
        pass
    
    def get_adjacent_cell(self, cell):
        '''
        返回cell周围的cell,这里的周围指八个方向
        '''
        pass
    
    def get_updated(self, adj, cell):
        '''
        用于每次更新cell信息
        '''
        pass
    
    def save_path(self):
        '''
        保存计算路径
        '''
        pass
    
    def solve(self):
        '''
        代码核心,实现逻辑
        '''
        pass
    

接下来逐一填充讲解

①初始化数据结构

了解A星算法的都知道其实现需要一些数据结构的帮忙,这一版本的实现需要用到堆、列表与集合

import heapq

class AStar(object):
    def __init__(self) -> None:
        '''
        初始化一些数据结构:堆、集合帮助实现算法
        '''
        self.closed = set()
        self.open = []
        heapq.heapify(self.open)
        self.cells = []
②初始化地图

导入宽度、高度及cells

    def init_grid(self, width, height, walls):
        '''
        初始化地图,导入cells
        '''
        self.grid_width = width
        self.grid_height = height
        
        for i in range(self.grid_height):
            for j in range(self.grid_width):
                if (i,j) in walls:
                    reachable = false
                else:
                    reachable = true
                self.cells.append(cell(i, j, reachable))
③重定位cells
    def get_cell(self, x, y):
        '''
        因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell
        '''
        return self.cells[ x*self.grid_width + y ]
④设置起终点
	def caculate_one_way(self, start, end):
        '''
        在地图确定不变的情况下,每次传入不同的起点和终点
        '''
        self.start = self.get_cell(*start)
        self.end = self.get_cell(*end)
        
⑤计算启发式距离
    def caculate_heuristic(self, cell):
        '''
        计算启发式距离h值,这里采用曼哈顿距离
        '''
        return 10 * ( abs(self.end.x - cell.x) + abs(self.end.y - cell.y) )
⑥计算临近点
    def get_adjacent_cell(self, cell):
        '''
        返回cell周围的cell,这里的周围指八个方向
        '''
        adj_cells = []
        for dx, dy in [ (1, 0), (0, 1), (-1, 0), (0, -1), (1, -1), (-1, 1), (-1, -1), (1, 1) ]:
            x2 = cell.x + dx
            y2 = cell.y + dy
            if x2>0 and x20 and y2 
⑦更新点 
    def get_updated(self, adj, cell):
        '''
        用于每次更新cell信息
        '''
        adj.g = cell.g + 10
        adj.parent = cell
        adj.h = self.caculate_heuristic(adj)
        adj.f = self.g + self.h
⑧保存路径
    def save_path(self):
        '''
        保存计算路径
        '''
        cell = self.end
        path = [(cell.x, cell.y)]
        while cell.parent is not self.start:
            cell = cell.parent
            path.append((cell.x, cell.y))
        path.append((self.start.x, self.start.y))
        path.reverse()
        return path
⑨逻辑实现

借助上图,我们很容易写出python对应的逻辑代码:

    def solve(self):
        '''
        代码核心,实现逻辑
        '''
        heapq.heappush(self.openlist, (self.start.f, self.start))
        while len(self.openlist):
            f, cell = heapq.heappop(self.openlist)
            self.closed.add(cell)
            if cell is self.end:
                return self.save_path()
            
            adj_cells = self.get_adjacent_cell(cell)
            for adj_cell in adj_cells:
                if adj_cell.isReachable and adj_cell not in self.closed:
                    if ( adj_cell.f, adj_cell ) in self.openlist:
                        if adj_cell.g > cell.g + 10:
                            self.get_updated(adj_cell, cell)
                    else:
                        self.get_updated(adj_cell, cell)
                        heapq.heappush(self.openlist, ( adj_cell.f, adj_cell ))
                            
        raise RuntimeError("A* failed to find a solution")
三、可视化

借助matplotlib将结果可视化

def draw_result(result_path, walls, start, end):
    plt.plot([v[0] for v in result_path], [v[1] for v in result_path])
    plt.plot([v[0] for v in result_path], [v[1] for v in result_path], 'o', color='lightblue')
    plt.plot([start[0], end[0]], [start[1], end[1]], 'o', color='red')
    plt.plot([barrier[0] for barrier in walls ], [barrier[1] for barrier in walls], 's', color='m')
    plt.xlim(-1, 8)
    plt.ylim(-1, 8)
    plt.show()
四、测试

若路径不可行会报错

raise RuntimeError("A* failed to find a solution")

下面测试可行路径

if __name__ == '__main__':
    a = AStar()
    walls = ((2, 5), (2, 6), (3, 6), (4, 6), (5, 6), (5, 5), (5, 4), 
    		(5, 3), (5, 2), (4, 2), (3, 2), (7, 1), (6, 4), (1, 5), (7, 6))
    a.init_grid(8, 8, walls)
    a.caculate_one_way((0, 0), (7, 7))
    path = a.solve()
    print(path)
    draw_result(path,walls,(0, 0), (7,7))

结果如下,紫色代表墙体,蓝色为计算路径


Ⅲ.完整代码
from __future__ import print_function
import heapq
import matplotlib.pyplot as plt
import unittest

class cell(object):
    def __init__(self, x, y, reachable:bool) -> None:
        self.x = x
        self.y = y
        self.isReachable = reachable
        self.g = 0
        self.h = 0
        self.f = self.g + self.h
        self.parent = None
        self.weight = 0
        
    def __lt__(self, other):
        return self.f < other.f
        
class AStar(object):
    def __init__(self) -> None:
        '''
        初始化一些数据结构:堆、集合帮助实现算法
        '''
        self.closed = set()
        self.openlist = []
        heapq.heapify(self.openlist)
        self.cells = []
    
    def init_grid(self, width, height, walls):
        '''
        初始化地图,导入cells
        '''
        self.grid_width = width
        self.grid_height = height
        
        for i in range(self.grid_height):
            for j in range(self.grid_width):
                if (i,j) in walls:
                    reachable = False
                else:
                    reachable = True
                self.cells.append(cell(i, j, reachable))
    
    def get_cell(self, x, y):
        '''
        因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell
        '''
        return self.cells[ x*self.grid_width + y ]
    
    def caculate_one_way(self, start, end):
        '''
        在地图确定不变的情况下,每次传入不同的起点和终点
        '''
        self.start = self.get_cell(*start)
        self.end = self.get_cell(*end)
        
    def caculate_heuristic(self, cell):
        '''
        计算启发式距离h值,这里采用曼哈顿距离
        '''
        return 10 * ( abs(self.end.x - cell.x) + abs(self.end.y - cell.y) )
        
    def get_adjacent_cell(self, cell):
        '''
        返回cell周围的cell,这里的周围指八个方向
        '''
        adj_cells = []
        for dx, dy in [ (1, 0), (0, 1), (-1, 0), (0, -1), (1, -1), (-1, 1), (-1, -1), (1, 1) ]:
            x2 = cell.x + dx
            y2 = cell.y + dy
            if x2>0 and x20 and y2 cell.g + 10:
                            self.get_updated(adj_cell, cell)
                    else:
                        self.get_updated(adj_cell, cell)
                        heapq.heappush(self.openlist, ( adj_cell.f, adj_cell ))
                            
        raise RuntimeError("A* failed to find a solution")


def draw_result(result_path, walls, start, end):
    plt.plot([v[0] for v in result_path], [v[1] for v in result_path])
    plt.plot([v[0] for v in result_path], [v[1] for v in result_path], 'o', color='lightblue')
    plt.plot([start[0], end[0]], [start[1], end[1]], 'o', color='red')
    plt.plot([barrier[0] for barrier in walls ], [barrier[1] for barrier in walls], 's', color='m')
    plt.xlim(-1, 8)
    plt.ylim(-1, 8)
    plt.show()

if __name__ == '__main__':
    a = AStar()
    walls = ((2, 5), (2, 6), (3, 6), (4, 6), (5, 6), (5, 5), 
             (5, 4), (5, 3), (5, 2), (4, 2), (3, 2), (7, 1), (6, 4), (1, 5), (7, 6))
    a.init_grid(8, 8, walls)
    a.caculate_one_way((0, 0), (7, 7))
    path = a.solve()
    print(path)
    draw_result(path,walls,(0, 0), (7,7))

END

祝大家学习面试顺利,OFFER多多。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/739238.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号