题目
通过优先队列搜索算法（A* 搜索）求解八数码拼图问题
代码：
import numpy as np
import heapq
import copy
class VertexNode(object):
    """A search-tree node holding a board and its priority-queue bookkeeping.

    Public fields:
      key   -- priority used for heap ordering (smaller pops first),
      state -- the board,
      space -- index of the blank cell within ``state``,
      num   -- a per-node counter supplied by the caller (the search in this
               file stores the number of moves taken so far).
    """

    def __init__(self, state=None, key=0, space=0, num=0):
        self.key, self.state = key, state
        self.space, self.num = space, num

    def __lt__(self, other):
        # heapq only needs "<"; nodes are ranked purely by their key.
        return self.key < other.key
class Heap(object):
    """Minimal min-heap: a plain list kept in heap order by :mod:`heapq`."""

    def __init__(self):
        self._queue = []

    def push(self, item):
        """Insert ``item`` while preserving the heap invariant."""
        heapq.heappush(self._queue, item)

    def pop(self):
        """Remove and return the smallest item, or ``None`` if the heap is empty."""
        return heapq.heappop(self._queue) if self._queue else None

    def heapify(self):
        """Restore the heap invariant after the list was modified directly."""
        heapq.heapify(self._queue)

    @property
    def queue(self):
        """The underlying list itself (a live reference, not a copy)."""
        return self._queue
def dn_num(state, goal, blank=-1):  # number of misplaced tiles
    """Return how many tiles of ``state`` are not on their ``goal`` cell.

    The blank cell (marked by ``blank``, -1 in this program) is not a tile
    and is not counted. Each move relocates exactly one tile, so excluding
    the blank keeps this count a lower bound on the remaining moves. That
    makes it admissible (and consistent) as an A* heuristic. Counting the
    blank, as before, could overestimate by one move.

    Args:
        state: current board as a flat sequence (row-major).
        goal: target board, same length as ``state``.
        blank: value that marks the empty cell.

    Returns:
        The number of non-blank cells whose value differs from ``goal``.
    """
    return sum(1 for tile, want in zip(state, goal)
               if tile != blank and tile != want)
def dn_state(all_state, new_state):  # has this state not been visited yet?
    """Return True if ``new_state`` is not yet in ``all_state``, else False.

    ``all_state`` may be any container that supports ``in``. With a list of
    boards this is a linear scan. Passing a set of tuples instead makes each
    check O(1), but then ``new_state`` must be hashable (e.g. a tuple).
    """
    return new_state not in all_state
def _misplaced_tiles(board, goal, blank):
    # A* heuristic: count of non-blank tiles that are off their goal cell.
    return sum(1 for tile, want in zip(board, goal)
               if tile != blank and tile != want)


def _inversion_parity(board, blank):
    # Parity of the number of out-of-order tile pairs, blank excluded.
    tiles = [tile for tile in board if tile != blank]
    inversions = sum(
        1
        for i, first in enumerate(tiles)
        for second in tiles[i + 1:]
        if first > second
    )
    return inversions % 2


def heap_search(start, goal, space):
    """Return the fewest moves that turn the 3x3 board ``start`` into ``goal``.

    Runs A* search over a ``heapq`` priority queue. Boards are flat
    sequences of 9 cells in row-major order. ``space`` is the index of the
    blank in ``start``, and ``start[space]`` is taken as the blank marker.

    The heuristic is the number of misplaced tiles, not counting the blank.
    A move relocates exactly one tile, so the heuristic is consistent. Every
    board's best-known depth is tracked and stale queue entries are skipped,
    so the first time the goal is popped its depth is minimal.

    Args:
        start: initial board (9 cells, row-major).
        goal: target board (9 cells, row-major).
        space: index of the blank cell in ``start``.

    Returns:
        The minimum number of moves, or -1 if ``goal`` is unreachable from
        ``start``.

    Raises:
        ValueError: if either board does not have exactly 9 cells.
    """
    width = 3
    if len(start) != width * width or len(goal) != width * width:
        raise ValueError("start and goal must both have 9 cells")
    start = tuple(start)  # tuples are hashable, so boards can be dict keys
    goal = tuple(goal)
    blank = start[space]

    if start == goal:
        return 0
    # Boards built from different tiles can never be rearranged into each other.
    if sorted(start) != sorted(goal):
        return -1
    # On an odd-width board no move changes the inversion parity, and every
    # board with the same parity is reachable. Rejecting the other half here
    # avoids exhausting the whole state space before giving up.
    if _inversion_parity(start, blank) != _inversion_parity(goal, blank):
        return -1

    best_g = {start: 0}  # shortest known move count to each board
    tie = 0  # unique tie-breaker so heapq never has to compare boards
    frontier = [(_misplaced_tiles(start, goal, blank), tie, 0, start, space)]
    while frontier:
        _, _, g, board, pos = heapq.heappop(frontier)
        if board == goal:
            return g
        if g > best_g[board]:
            continue  # stale entry: a shorter path to this board was queued later
        row, col = divmod(pos, width)
        # Slide the blank up, down, left, right (when it stays on the board).
        for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            r, c = row + d_row, col + d_col
            if not (0 <= r < width and 0 <= c < width):
                continue
            target = r * width + c
            cells = list(board)
            cells[pos], cells[target] = cells[target], cells[pos]
            child = tuple(cells)
            child_g = g + 1
            if child in best_g and best_g[child] <= child_g:
                continue  # already reached at least as cheaply
            best_g[child] = child_g
            tie += 1
            f = child_g + _misplaced_tiles(child, goal, blank)
            heapq.heappush(frontier, (f, tie, child_g, child, target))
    return -1
def main():
    """Read one puzzle from stdin and print its minimum number of moves.

    Input format:
      line 1:    ``m n`` - 0-based row and column where the blank must end up;
      lines 2-4: the start board, one row per line, with -1 as the blank.

    The goal board is 0..8 in row-major order, with the cell at (m, n)
    replaced by the blank (-1).

    Raises:
        ValueError: if (m, n) is off the board, the start board does not have
            9 cells, or it contains no blank.
    """
    m, n = map(int, input().split())
    # Without this check a negative index would silently wrap around.
    if not (0 <= m < 3 and 0 <= n < 3):
        raise ValueError("blank position ({}, {}) is outside the 3x3 board".format(m, n))
    start = []
    for _ in range(3):
        start.extend(map(int, input().split()))
    if len(start) != 9:
        raise ValueError("expected 9 cells on the start board, got {}".format(len(start)))
    try:
        space = start.index(-1)
    except ValueError:
        raise ValueError("start board has no blank cell (-1)") from None
    goal = list(range(9))
    goal[3 * m + n] = -1
    print(heap_search(start, goal, space))
if __name__ == '__main__':
    main()
# Sample inputs, kept as unused string literals for reference.
# Line 1 is "m n": the 0-based row and column where the blank must end up.
# Lines 2-4 are the start board, row by row, with -1 marking the blank.
"""
2 2
0 1 2
3 -1 4
6 7 5
"""
"""
1 1
-1 0 2
3 1 5
6 7 8
"""
"""
1 1
2 5 -1
1 3 7
6 0 8
"""