【算法 python A*算法的实现 可视化】

时间:2024-12-02 07:36:26

 - 算法实现:

import heapq


class Node:
    def __init__(self, position, g=0, h=0):
        self.position = position  # 节点的位置
        self.g = g  # 从起点到当前节点的成本
        self.h = h  # 从当前节点到终点的启发式估计成本
        self.f = g + h  # 总成本
        self.parent = None  # 父节点

    def __lt__(self, other):
        return self.f < other.f


def heuristic(a, b):
    # 使用曼哈顿距离作为启发式函数
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def a_star(start, goal, grid):
    open_list = []
    closed_list = set()

    start_node = Node(start, 0, heuristic(start, goal))
    goal_node = Node(goal)

    heapq.heappush(open_list, start_node)

    while open_list:
        current_node = heapq.heappop(open_list)

        if current_node.position == goal_node.position:
            path = []
            while current_node:
                path.append(current_node.position)
                current_node = current_node.parent
            return path[::-1]  # 返回反转后的路径

        closed_list.add(current_node.position)

        # 获取相邻节点
        neighbors = [(0, 1), (1, 0), (0, -1), (-1, 0)]  # 右,下,左,上
        for new_position in neighbors:
            node_position = (
                current_node.position[0] + new_position[0],
                current_node.position[1] + new_position[1],
            )

            # 检查节点是否在网格内
            if (0 <= node_position[0] < len(grid)) and (
                0 <= node_position[1] < len(grid[0])
            ):
                # 检查节点是否是障碍物
                if grid[node_position[0]][node_position[1]] == 1:
                    continue

                # 创建新的节点
                neighbor_node = Node(node_position)
                if neighbor_node.position in closed_list:
                    continue

                # 计算 g, h, f 值
                neighbor_node.g = current_node.g + 1
                neighbor_node.h = heuristic(neighbor_node.position, goal_node.position)
                neighbor_node.f = neighbor_node.g + neighbor_node.h
                neighbor_node.parent = current_node

                # 检查节点是否在开放列表中
                if add_to_open(open_list, neighbor_node):
                    heapq.heappush(open_list, neighbor_node)

    return None  # 如果没有找到路径


def add_to_open(open_list, neighbor):
    for node in open_list:
        if neighbor.position == node.position and neighbor.g >= node.g:
            return False
    return True

 - 测试:

# 示例使用

import numpy as np
grid = np.array(
    [
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1],
        [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1],
        [0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1],
        [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1],
        [0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1],
        [0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1],
        [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1],
        [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
        [0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0],
    ]
)

start = (0, 0)  # 起点
goal = (15, 15)  # 终点

path = a_star(start, goal, grid)
print("找到的路径:", path)

找到的路径: [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (4, 1), (4, 2), (3, 2), (2, 2), (1, 2), (0, 2), (0, 3), (0, 4), (1, 4), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (1, 8), (0, 8), (0, 9), (0, 10), (0, 11), (0, 12), (0, 13), (0, 14), (1, 14), (2, 14), (3, 14), (4, 14), (4, 13), (4, 12), (4, 11), (4, 10), (4, 9), (4, 8), (4, 7), (4, 6), (5, 6), (6, 6), (6, 5), (6, 4), (7, 4), (8, 4), (8, 3), (8, 2), (9, 2), (10, 2), (11, 2), (12, 2), (12, 3), (12, 4), (12, 5), (12, 6), (11, 6), (10, 6), (9, 6), (8, 6), (8, 7), (8, 8), (7, 8), (6, 8), (6, 9), (6, 10), (6, 11), (6, 12), (7, 12), (8, 12), (8, 11), (8, 10), (9, 10), (10, 10), (10, 9), (10, 8), (11, 8), (12, 8), (12, 9), (12, 10), (13, 10), (14, 10), (14, 11), (14, 12), (13, 12), (12, 12), (11, 12), (10, 12), (10, 13), (10, 14), (11, 14), (12, 14), (13, 14), (13, 15), (14, 15), (15, 15)]

 - 可视化

def visualize_path(maze, path):
    fig, axis = plt.subplots(1, 2)
    mask = maze == 1
    maze[mask] = 0
    maze[~mask] = 1
    axis[0].imshow(maze, cmap="gray")
    axis[0].set_xticks([])
    axis[0].set_yticks([])
    axis[0].set_title("map")

    for position in path:
        maze[position[0]][position[1]] = 2
    axis[1].imshow(maze, cmap="gray")
    axis[1].set_xticks([])
    axis[1].set_yticks([])
    axis[1].set_title("shortest path")
    plt.show()


visualize_path(grid.copy(), path)