关于A星算法的代码实现在各大厂人工智能岗位面试中均出现过,需要面试者有较高的手撕代码能力。
这篇文章即在python中从零复现A星寻路算法,并利用matplot将计算所得路线可视化画出,适合新手练习代码能力,也适合准备面试的朋友复习。
本文不注重讲解A星算法的具体内容,着重于通过代码的实现!
*参考资料:https://www.laurentluce.com/posts/solving-mazes-using-python-simple-recursivity-and-a-search/
Ⅱ.开始代码寿司
为了简化,这里不考虑路途权重问题,每个方块只考虑是否可达(即是否为墙)。
我们这里定义两个类,一个为cell,用于记录每个cell的信息,如坐标、是否可达及gfh等信息(若需要可以记录权重);另一个为AStar类,用于实现算法。
一、Cell类很简单,第一个cell类定义如下:
class cell(object):
def __init__(self, x, y, reachable:bool) -> None:
self.x = x
self.y = y
self.isReachable = reachable
self.g = 0
self.h = 0
self.f = self.g + self.h
self.parent = None
self.weight = 0
def __lt__(self, other):
return self.f < other.f
其中weights作为可扩展项,此算法测试中不考虑每个cell权重大小来更新,只考虑该cell是否可达来更新。
二、AStar类第二个AStar类需要实现的功能很多,我们先一一列出来,然后再逐一讲解填充。
class AStar(object):
def __init__(self) -> None:
'''
初始化一些数据结构:堆、集合帮助实现算法
'''
pass
def init_grid(self, width, height, cells, walls):
'''
初始化地图,导入walls
'''
pass
def get_cell(self, x, y):
'''
因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell
'''
pass
def caculate_one_way(self, start, end):
'''
在地图确定不变的情况下,每次传入不同的起点和终点,计算返回路径
'''
pass
def caculate_heuristic(self, cell):
'''
计算启发式距离h值
'''
pass
def get_adjacent_cell(self, cell):
'''
返回cell周围的cell,这里的周围指八个方向
'''
pass
def get_updated(self, adj, cell):
'''
用于每次更新cell信息
'''
pass
def save_path(self):
'''
保存计算路径
'''
pass
def solve(self):
'''
代码核心,实现逻辑
'''
pass
接下来逐一填充讲解
①初始化数据结构了解A星算法的都知道其实现需要一些数据结构的帮忙,这一版本的实现需要用到堆、列表与集合
import heapq
class AStar(object):
def __init__(self) -> None:
'''
初始化一些数据结构:堆、集合帮助实现算法
'''
self.closed = set()
self.open = []
heapq.heapify(self.open)
self.cells = []
②初始化地图
导入宽度、高度及cells
def init_grid(self, width, height, walls):
'''
初始化地图,导入cells
'''
self.grid_width = width
self.grid_height = height
for i in range(self.grid_height):
for j in range(self.grid_width):
if (i,j) in walls:
reachable = false
else:
reachable = true
self.cells.append(cell(i, j, reachable))
③重定位cells
def get_cell(self, x, y):
'''
因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell
'''
return self.cells[ x*self.grid_width + y ]
④设置起终点
def caculate_one_way(self, start, end):
'''
在地图确定不变的情况下,每次传入不同的起点和终点
'''
self.start = self.get_cell(*start)
self.end = self.get_cell(*end)
⑤计算启发式距离
def caculate_heuristic(self, cell):
'''
计算启发式距离h值,这里采用曼哈顿距离
'''
return 10 * ( abs(self.end.x - cell.x) + abs(self.end.y - cell.y) )
⑥计算临近点
def get_adjacent_cell(self, cell):
'''
返回cell周围的cell,这里的周围指八个方向
'''
adj_cells = []
for dx, dy in [ (1, 0), (0, 1), (-1, 0), (0, -1), (1, -1), (-1, 1), (-1, -1), (1, 1) ]:
x2 = cell.x + dx
y2 = cell.y + dy
if x2>0 and x20 and y2
⑦更新点
def get_updated(self, adj, cell):
'''
用于每次更新cell信息
'''
adj.g = cell.g + 10
adj.parent = cell
adj.h = self.caculate_heuristic(adj)
adj.f = self.g + self.h
⑧保存路径
def save_path(self):
'''
保存计算路径
'''
cell = self.end
path = [(cell.x, cell.y)]
while cell.parent is not self.start:
cell = cell.parent
path.append((cell.x, cell.y))
path.append((self.start.x, self.start.y))
path.reverse()
return path
⑨逻辑实现
借助上图,我们很容易写出python对应的逻辑代码:
def solve(self):
'''
代码核心,实现逻辑
'''
heapq.heappush(self.openlist, (self.start.f, self.start))
while len(self.openlist):
f, cell = heapq.heappop(self.openlist)
self.closed.add(cell)
if cell is self.end:
return self.save_path()
adj_cells = self.get_adjacent_cell(cell)
for adj_cell in adj_cells:
if adj_cell.isReachable and adj_cell not in self.closed:
if ( adj_cell.f, adj_cell ) in self.openlist:
if adj_cell.g > cell.g + 10:
self.get_updated(adj_cell, cell)
else:
self.get_updated(adj_cell, cell)
heapq.heappush(self.openlist, ( adj_cell.f, adj_cell ))
raise RuntimeError("A* failed to find a solution")
三、可视化
借助matplotlib将结果可视化
def draw_result(result_path, walls, start, end):
plt.plot([v[0] for v in result_path], [v[1] for v in result_path])
plt.plot([v[0] for v in result_path], [v[1] for v in result_path], 'o', color='lightblue')
plt.plot([start[0], end[0]], [start[1], end[1]], 'o', color='red')
plt.plot([barrier[0] for barrier in walls ], [barrier[1] for barrier in walls], 's', color='m')
plt.xlim(-1, 8)
plt.ylim(-1, 8)
plt.show()
四、测试
若路径不可行会报错
raise RuntimeError("A* failed to find a solution")
下面测试可行路径
if __name__ == '__main__':
a = AStar()
walls = ((2, 5), (2, 6), (3, 6), (4, 6), (5, 6), (5, 5), (5, 4),
(5, 3), (5, 2), (4, 2), (3, 2), (7, 1), (6, 4), (1, 5), (7, 6))
a.init_grid(8, 8, walls)
a.caculate_one_way((0, 0), (7, 7))
path = a.solve()
print(path)
draw_result(path,walls,(0, 0), (7,7))
结果如下,紫色代表墙体,蓝色为计算路径
Ⅲ.完整代码
from __future__ import print_function
import heapq
import matplotlib.pyplot as plt
import unittest
class cell(object):
def __init__(self, x, y, reachable:bool) -> None:
self.x = x
self.y = y
self.isReachable = reachable
self.g = 0
self.h = 0
self.f = self.g + self.h
self.parent = None
self.weight = 0
def __lt__(self, other):
return self.f < other.f
class AStar(object):
def __init__(self) -> None:
'''
初始化一些数据结构:堆、集合帮助实现算法
'''
self.closed = set()
self.openlist = []
heapq.heapify(self.openlist)
self.cells = []
def init_grid(self, width, height, walls):
'''
初始化地图,导入cells
'''
self.grid_width = width
self.grid_height = height
for i in range(self.grid_height):
for j in range(self.grid_width):
if (i,j) in walls:
reachable = False
else:
reachable = True
self.cells.append(cell(i, j, reachable))
def get_cell(self, x, y):
'''
因为输入cells信息时为一维信息,这里需要通过width和height检索到相应位置的cell
'''
return self.cells[ x*self.grid_width + y ]
def caculate_one_way(self, start, end):
'''
在地图确定不变的情况下,每次传入不同的起点和终点
'''
self.start = self.get_cell(*start)
self.end = self.get_cell(*end)
def caculate_heuristic(self, cell):
'''
计算启发式距离h值,这里采用曼哈顿距离
'''
return 10 * ( abs(self.end.x - cell.x) + abs(self.end.y - cell.y) )
def get_adjacent_cell(self, cell):
'''
返回cell周围的cell,这里的周围指八个方向
'''
adj_cells = []
for dx, dy in [ (1, 0), (0, 1), (-1, 0), (0, -1), (1, -1), (-1, 1), (-1, -1), (1, 1) ]:
x2 = cell.x + dx
y2 = cell.y + dy
if x2>0 and x20 and y2 cell.g + 10:
self.get_updated(adj_cell, cell)
else:
self.get_updated(adj_cell, cell)
heapq.heappush(self.openlist, ( adj_cell.f, adj_cell ))
raise RuntimeError("A* failed to find a solution")
def draw_result(result_path, walls, start, end):
plt.plot([v[0] for v in result_path], [v[1] for v in result_path])
plt.plot([v[0] for v in result_path], [v[1] for v in result_path], 'o', color='lightblue')
plt.plot([start[0], end[0]], [start[1], end[1]], 'o', color='red')
plt.plot([barrier[0] for barrier in walls ], [barrier[1] for barrier in walls], 's', color='m')
plt.xlim(-1, 8)
plt.ylim(-1, 8)
plt.show()
if __name__ == '__main__':
a = AStar()
walls = ((2, 5), (2, 6), (3, 6), (4, 6), (5, 6), (5, 5),
(5, 4), (5, 3), (5, 2), (4, 2), (3, 2), (7, 1), (6, 4), (1, 5), (7, 6))
a.init_grid(8, 8, walls)
a.caculate_one_way((0, 0), (7, 7))
path = a.solve()
print(path)
draw_result(path,walls,(0, 0), (7,7))
END
祝大家学习面试顺利,OFFER多多。



