- 算法实现:
import heapq
class Node:
    """A node in the A* search tree: a grid position plus its path costs.

    ``f`` is computed once in the constructor as ``g + h``; it is not kept in
    sync automatically if ``g`` or ``h`` are reassigned later, so callers that
    change them must update ``f`` themselves.
    """

    def __init__(self, position, g=0, h=0):
        # Location of the node on the grid.
        self.position = position
        # Cost accumulated from the start node to this node.
        self.g = g
        # Heuristic estimate of the remaining cost from this node to the goal.
        self.h = h
        # Total estimated cost of a path through this node.
        self.f = self.g + self.h
        # Predecessor on the path; None until a parent is assigned.
        self.parent = None

    def __lt__(self, other):
        """Order nodes by total cost ``f`` so a min-heap pops the cheapest first."""
        return other.f > self.f
def heuristic(a, b):
    """Return the Manhattan (L1) distance between 2-D points ``a`` and ``b``.

    Only the first two coordinates of each point are used.
    """
    row_gap = a[0] - b[0]
    col_gap = a[1] - b[1]
    return abs(row_gap) + abs(col_gap)
def a_star(start, goal, grid):
    """Find a shortest path between two grid cells using A* search.

    Movement is 4-connected (right, down, left, up) and every step costs 1.
    ``heuristic`` (Manhattan distance) is consistent for this move set, so the
    first time the goal is popped from the open heap its path is optimal.

    Args:
        start: (row, col) of the start cell. It is not checked for being
            inside the grid or for being an obstacle.
        goal: (row, col) of the target cell.
        grid: 2-D grid indexable as ``grid[row][col]`` (nested lists or a 2-D
            numpy array). A cell equal to 1 is an obstacle; any other value
            is walkable.

    Returns:
        A list of (row, col) tuples from ``start`` to ``goal`` inclusive, or
        ``None`` if the goal cannot be reached.
    """
    # Normalise to tuples so positions compare equal to (and hash like) the
    # neighbour tuples built below, even if the caller passed lists.
    start = tuple(start)
    goal = tuple(goal)

    rows = len(grid)
    cols = len(grid[0]) if rows else 0
    moves = ((0, 1), (1, 0), (0, -1), (-1, 0))  # right, down, left, up

    open_heap = [Node(start, 0, heuristic(start, goal))]
    # Cheapest known cost-from-start per position. A dict lookup replaces a
    # linear scan of the open heap for every neighbour (O(n^2) overall).
    best_g = {start: 0}
    closed = set()

    while open_heap:
        current = heapq.heappop(open_heap)
        if current.position in closed:
            # Stale heap entry: this cell was already expanded via a cheaper route.
            continue
        if current.position == goal:
            path = []
            node = current
            while node is not None:
                path.append(node.position)
                node = node.parent
            path.reverse()  # collected goal -> start; return start -> goal
            return path
        closed.add(current.position)

        row, col = current.position[0], current.position[1]
        for d_row, d_col in moves:
            n_row, n_col = row + d_row, col + d_col
            # Skip cells outside the grid or blocked by an obstacle.
            if not (0 <= n_row < rows and 0 <= n_col < cols):
                continue
            if grid[n_row][n_col] == 1:
                continue
            position = (n_row, n_col)
            if position in closed:
                continue
            g = current.g + 1
            # Push only when this route is strictly cheaper than any known one.
            if position in best_g and best_g[position] <= g:
                continue
            best_g[position] = g
            neighbor = Node(position, g, heuristic(position, goal))
            neighbor.parent = current
            heapq.heappush(open_heap, neighbor)

    return None  # open set exhausted: the goal is unreachable
def add_to_open(open_list, neighbor):
    """Decide whether ``neighbor`` is worth pushing onto ``open_list``.

    Returns False when the open list already holds a node at the same position
    whose cost-from-start ``g`` is no greater than ``neighbor.g`` (that entry is
    at least as good); otherwise returns True.
    """
    already_covered = any(
        neighbor.position == entry.position and neighbor.g >= entry.g
        for entry in open_list
    )
    return not already_covered
- 测试:
# Example usage: run A* on a 16x16 maze and print the resulting path.
import numpy as np
# Maze grid: 1 marks an obstacle, 0 a walkable cell (a_star treats any value
# equal to 1 as blocked).
grid = np.array(
    [
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
        [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1],
        [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
        [0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1],
        [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
        [0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0],
    ]
)
start = (0, 0)  # start cell: first row, first column
goal = (15, 15)  # goal cell: last row, last column
# List of (row, col) cells from start to goal, or None if unreachable.
path = a_star(start, goal, grid)
print("找到的路径:", path)
找到的路径: [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (4, 1), (4, 2), (3, 2), (2, 2), (1, 2), (0, 2), (0, 3), (0, 4), (1, 4), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (1, 8), (0, 8), (0, 9), (0, 10), (0, 11), (0, 12), (0, 13), (0, 14), (1, 14), (2, 14), (3, 14), (4, 14), (4, 13), (4, 12), (4, 11), (4, 10), (4, 9), (4, 8), (4, 7), (4, 6), (5, 6), (6, 6), (6, 5), (6, 4), (7, 4), (8, 4), (8, 3), (8, 2), (9, 2), (10, 2), (11, 2), (12, 2), (12, 3), (12, 4), (12, 5), (12, 6), (11, 6), (10, 6), (9, 6), (8, 6), (8, 7), (8, 8), (7, 8), (6, 8), (6, 9), (6, 10), (6, 11), (6, 12), (7, 12), (8, 12), (8, 11), (8, 10), (9, 10), (10, 10), (10, 9), (10, 8), (11, 8), (12, 8), (12, 9), (12, 10), (13, 10), (14, 10), (14, 11), (14, 12), (13, 12), (12, 12), (11, 12), (10, 12), (10, 13), (10, 14), (11, 14), (12, 14), (13, 14), (13, 15), (14, 15), (15, 15)]
- 可视化:
def visualize_path(maze, path):
    """Show the maze and a path through it side by side with matplotlib.

    Left panel ("map"): obstacles (cells equal to 1) are drawn black and
    walkable cells white. Right panel ("shortest path"): the same map with the
    path cells given the highest value, so they are brightest under ``gray``.

    Args:
        maze: 2-D grid (nested lists or a numpy array) where 1 marks an
            obstacle. It is no longer modified in place.
        path: sequence of (row, col) cells to highlight, e.g. the result of
            ``a_star``. ``None`` (no path found) highlights nothing instead of
            raising.
    """
    import numpy as np

    # NOTE(review): ``plt`` (matplotlib.pyplot) is not imported anywhere in
    # the visible source; this assumes it is in scope at call time.
    cells = np.asarray(maze)
    # Build a fresh image rather than rewriting ``maze``: obstacle -> 0, free -> 1.
    map_image = np.where(cells == 1, 0, 1)
    path_image = map_image.copy()
    if path is not None:
        for position in path:
            path_image[position[0], position[1]] = 2

    _, axes = plt.subplots(1, 2)
    panels = ((axes[0], map_image, "map"), (axes[1], path_image, "shortest path"))
    for ax, image, title in panels:
        ax.imshow(image, cmap="gray")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
    plt.show()
# Draw on a copy so the example grid itself is left untouched.
visualize_path(maze=grid.copy(), path=path)