python实现A*算法,显示步骤

import numpy as np
import queue
import prettytable as pt

'''
输入格式:0表示白色棋子,1表示黑色棋子,-1表示空位。
初始状态:             目标状态:
01011                11111
110-11               01111
01110                00-111
01010                00001
00100                00000
'''
#初始状态的数据可以依据上图zi'xi
start_data = np.array([[1,0,1,1,0],[0,1,-1,1,1], [1,0,1,1,1],[0,1,0,0,1],[0,0,0,0,0]])
#目标状态的数据
end_data = np.array([[1,1,1,1,1], [0,1,1,1,1], [0,0,-1,1,1],[0,0,0,0,1],[0,0,0,0,0]])


# 找空格(-1)号元素在哪的函数
def find_zero(num):
    tmp_x, tmp_y = np.where(num == -1)
    # 返回0所在的x坐标与y坐标
    return tmp_x[0], tmp_y[0]


# 交换位置的函数 移动的时候要判断一下是否可以移动
# 记空格为-1号,则每次移动一个数字可以看做对空格(-1)的移动,总共有八种可能
def swap(num_data, direction):
    x, y = find_zero(num_data)
    num = np.copy(num_data)
    #print(direction)
    #'left1down2', 'left1up2', 'left2down1', 'left2up1', 'right1down2', 'right1up2', 'right2down1', 'right2up1'
    if direction == 'left1down2':
        if y == 0 or x>=3:
            #print('不能左移')
            return num
        num[x][y] = num[x+2][y - 1]
        num[x+2][y - 1] = -1
        return num

    if direction == 'left1up2':
        if y == 0 or x<=1:
           # print('不能左移')
            return num
        num[x][y] = num[x-2][y - 1]
        num[x-2][y - 1] = -1
        return num

    if direction == 'left2down1':
        if y <= 1 or x>3:
            #print('不能左移')
            return num
        num[x][y] = num[x+1][y-2]
        num[x+1][y - 2] = -1
        return num

    if direction == 'left2up1':
        if y <= 1 or x<1:
            #print('不能左移')
            return num
        num[x][y] = num[x-1][y - 2]
        num[x-1][y-2] = -1
        return num


    if direction == 'right1down2':
        if y == 4 or x>=3:
            #print('不能右移')
            return num
        num[x][y] = num[x+2][y + 1]
        num[x+2][y + 1] = -1
        return num

    if direction == 'right1up2':
        if y == 4 or x<=1:
            #print('不能右移')
            return num
        num[x][y] = num[x-2][y + 1]
        num[x-2][y + 1] = -1
        return num

    if direction == 'right2down1':
        #print(x,y)
        if y >= 3 or x>3:
            #print('不能右移')
            return num
        num[x][y] = num[x+1][y+2]
        num[x+1][y + 2] = -1
        return num

    if direction == 'right2up1':
        if y >=3 or x<1:
            #print('不能右移')
            return num
        num[x][y] = num[x-1][y+2]
        num[x-1][y + 2] = -1
        return num
'''f(n)=d(n)+w(n)
其中d(n)为搜索树的深度
w(n)为当前状态到目标状态的实际最小费用的估计值
编写一个用来计算w(n)的函数
'''
def cal_wcost(num):
    '''
    计算w(n)的值,及放错元素的个数
    :param num: 要比较的数组的值
    :return: 返回w(n)的值
    '''
    # return sum(sum(num != end_data)) - int(num[1][1] != 0)
    con = 0
    for i in range(5):
        for j in range(5):
            tmp_num = num[i][j]
            compare_num = end_data[i][j]
            if tmp_num != -1:
                con += tmp_num != compare_num
    return con


# 将data转化为不一样的数字类似于hash,一遍存入close字典中
def data_to_int(num):
    value = 0
    for i in num:
        for j in i:
            value = value * 10 + j
    return value


# 编写一个给open表排序的函数
def sorte_by_floss():
    tmp_open = opened.queue.copy()
    length = len(tmp_open)
    # 排序,从小到大,当一样的时候按照step的大小排序
    for i in range(length):
        for j in range(length):
            if tmp_open[i].f_loss < tmp_open[j].f_loss:
                tmp = tmp_open[i]
                tmp_open[i] = tmp_open[j]
                tmp_open[j] = tmp
            if tmp_open[i].f_loss == tmp_open[j].f_loss:
                if tmp_open[i].step > tmp_open[j].step:
                    tmp = tmp_open[i]
                    tmp_open[i] = tmp_open[j]
                    tmp_open[j] = tmp
    opened.queue = tmp_open


# 编写一个比较当前节点是否在open表中,如果在,根据f(n)的大小来判断去留
def refresh_open(now_node):
    '''
    :param now_node: 当前的节点
    :return:
    '''
    tmp_open = opened.queue.copy()  # 复制一份open表的内容
    for i in range(len(tmp_open)):
        '''这里要比较一下node和now_node的区别,并决定是否更新'''
        data = tmp_open[i]
        now_data = now_node.data
        if (data == now_data).all():
            data_f_loss = tmp_open[i].f_loss
            now_data_f_loss = now_node.f_loss
            if data_f_loss <= now_data_f_loss:
                return False
            else:
                print('')
                tmp_open[i] = now_node
                opened.queue = tmp_open  # 更新之后的open表还原
                return True
    tmp_open.append(now_node)
    opened.queue = tmp_open  # 更新之后的open表还原
    return True

'''
f_loss 记录f(n)的值
'''
# 创建Node类 (包含当前数据内容,父节点,步数)
class Node:
    f_loss = -1  # 启发值
    step = 0  # 初始状态到当前状态的距离(步数)
    parent = None,  # 父节点

    # 用状态和步数构造节点对象
    def __init__(self, data, step, parent):
        self.data = data  # 当前状态数值
        self.step = step
        self.parent = parent
        # 计算f(n)的值
        self.f_loss = cal_wcost(data) + step


'算法'
opened = queue.Queue()  # open表
start_node = Node(start_data, 0, None)
opened.put(start_node)
closed = {}  # close表

def method_a_function():
    con = 0
    while len(opened.queue) != 0:
        node = opened.get()
        if (node.data == end_data).all():
            print(f'总共耗费{con}轮')
            return node
        closed[data_to_int(node.data)] = 1  # 奖取出的点加入closed表中
        # 八种移动方法
        for action in ['left1down2','left1up2', 'left2down1','left2up1','right1down2','right1up2', 'right2down1','right2up1']:
            # 创建子节点
            child_node = Node(swap(node.data, action),node.step + 1,node)
            index = data_to_int(child_node.data)
            if index not in closed:
                if refresh_open(child_node):
                    con=con+1
        '''为open表进行排序,根据其中的f_loss值'''
        sorte_by_floss()
        if con ==15:
            print("15步不能运行出来!")
            return node
result_node = method_a_function()

def output_result(node):
    all_node = [node]
    for i in range(node.step):
        father_node = node.parent
        all_node.append(father_node)
        node = father_node
    return reversed(all_node)

node_list = list(output_result(result_node))
tb = pt.PrettyTable()
tb.field_names = ['step','data','f_loss']
for node in node_list:
    num = node.data
    tb.add_row([node.step, num, node.f_loss])
    if node != node_list[-1]:
        tb.add_row(['---','--------','---'])
print(tb)

运行结果:

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值