✅ Features
- Trains a shortest-path policy on a 2D grid using Q-learning
- Supports starting from any start coordinates (start_x, start_y)
- Uses the saved Q-table model to compute and print:
  - the shortest-path coordinate sequence
  - the total number of steps
  - a console visualization animation
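Both scripts rely on the standard tabular Q-learning update, which the training loop applies after every move:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$

Here $\alpha = 0.1$ is the learning rate, $\gamma = 0.95$ the discount factor, and actions are chosen epsilon-greedily with $\epsilon = 0.2$, matching the hyperparameters in the code below.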
✅ File 1: train the model and save it (train_q_robot.py)
import numpy as np
import pickle
import random
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

# Grid and environment settings
GRID_SIZE = 5
GOAL_POS = (4, 4)
n_states = GRID_SIZE * GRID_SIZE
n_actions = 4
actions = ['up', 'down', 'left', 'right']

# Initialize the Q-table (shared across threads)
Q = np.zeros((n_states, n_actions))
Q_lock = Lock()

# Hyperparameters
alpha = 0.1       # learning rate
gamma = 0.95      # discount factor
epsilon = 0.2     # exploration rate for epsilon-greedy
episodes = 100000
num_threads = 16  # number of concurrent threads

def pos_to_state(x, y):
    return x * GRID_SIZE + y

def state_to_pos(state):
    return divmod(state, GRID_SIZE)

def step(state, action):
    x, y = state_to_pos(state)
    if action == 0 and x > 0: x -= 1              # up
    if action == 1 and x < GRID_SIZE - 1: x += 1  # down
    if action == 2 and y > 0: y -= 1              # left
    if action == 3 and y < GRID_SIZE - 1: y += 1  # right
    new_state = pos_to_state(x, y)
    reward = 1 if (x, y) == GOAL_POS else 0
    done = reward == 1
    return new_state, reward, done

# Training logic for a single episode
def train_episode():
    state = pos_to_state(0, 0)
    done = False
    while not done:
        if random.uniform(0, 1) < epsilon:
            action = random.randint(0, 3)
        else:
            with Q_lock:
                action = np.argmax(Q[state])
        next_state, reward, done = step(state, action)
        with Q_lock:
            Q[state][action] += alpha * (
                reward + gamma * np.max(Q[next_state]) - Q[state][action]
            )
        state = next_state

# Multithreaded training entry point (note: CPython's GIL means the threads
# interleave rather than run truly in parallel; the lock keeps Q-table reads
# and updates consistent)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
    executor.map(lambda _: train_episode(), range(episodes))

# Save the Q-table
with open("q_table.pkl", "wb") as f:
    pickle.dump(Q, f)

print("✅ Multithreaded Q-learning training complete; saved as q_table.pkl")
✅ File 2: use the model to find the shortest path (run_q_robot.py)
import numpy as np
import pickle
import os
import sys
import time

GRID_SIZE = 5
GOAL_POS = (4, 4)  # goal cell (the charging station)

# Convert coordinates to a state index
def pos_to_state(x, y):
    return x * GRID_SIZE + y

# Convert a state index back to coordinates
def state_to_pos(state):
    return divmod(state, GRID_SIZE)

# Draw the grid in the console, marking the path and the current position
def draw_grid(current, path_set):
    os.system('cls' if os.name == 'nt' else 'clear')  # clear the screen
    for i in range(GRID_SIZE):
        row = ''
        for j in range(GRID_SIZE):
            if (i, j) == current:
                row += '[🤖]'  # current robot position
            elif (i, j) == GOAL_POS:
                row += '[🔋]'  # charging station
            elif (i, j) in path_set:
                row += '[• ]'  # visited path
            else:
                row += '[  ]'
        print(row)
    time.sleep(0.3)

# Load the saved Q-table model
with open("q_table.pkl", "rb") as f:
    Q = pickle.load(f)

# Read the start coordinates from the user
try:
    start_x = int(input("Start X coordinate (0-4): "))
    start_y = int(input("Start Y coordinate (0-4): "))
except ValueError:
    print("❌ Invalid input, please enter integers")
    sys.exit(1)

# Validate the coordinates
if not (0 <= start_x < GRID_SIZE and 0 <= start_y < GRID_SIZE):
    print("❌ Start position is outside the grid")
    sys.exit(1)

# Initialize the state and the path
state = pos_to_state(start_x, start_y)
path = [state]

# Follow the greedy policy toward the goal
while True:
    current_pos = state_to_pos(state)
    draw_grid(current_pos, set(state_to_pos(s) for s in path))
    # Stop once the goal is reached
    if current_pos == GOAL_POS:
        print("✅ Reached the charging station")
        break
    # Pick the greedy (highest-value) action
    action = np.argmax(Q[state])
    x, y = current_pos
    if action == 0: x -= 1
    elif action == 1: x += 1
    elif action == 2: y -= 1
    elif action == 3: y += 1
    # Clamp the position to the grid
    x = max(0, min(GRID_SIZE - 1, x))
    y = max(0, min(GRID_SIZE - 1, y))
    next_state = pos_to_state(x, y)
    # Guard against loops (revisiting a cell means the policy is stuck)
    if next_state == state or next_state in path:
        print("❌ Stuck in a loop, stopping")
        break
    path.append(next_state)
    state = next_state

# Print the final result
print("🧭 Shortest path coordinates:", [state_to_pos(s) for s in path])
print(f"📏 Total steps: {len(path) - 1}")
✅ Example run
User input:
Start X coordinate (0-4): 0
Start Y coordinate (0-4): 1
Output:
[  ][• ][  ][  ][  ]
[  ][• ][  ][  ][  ]
[  ][• ][  ][  ][  ]
[  ][• ][  ][  ][  ]
[  ][• ][• ][• ][🤖]
✅ Reached the charging station
🧭 Shortest path coordinates: [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (4, 2), (4, 3), (4, 4)]
📏 Total steps: 7
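Because the grid has no obstacles, an optimal policy needs exactly the Manhattan distance |x - gx| + |y - gy| steps from any start cell. A tiny check (a hypothetical helper, not part of the original scripts) confirms the sample run above is optimal:

# Hypothetical sanity check: on an obstacle-free grid the shortest path
# length equals the Manhattan distance to the goal.
def manhattan(start, goal=(4, 4)):
    return abs(start[0] - goal[0]) + abs(start[1] - goal[1])

assert manhattan((0, 1)) == 7  # matches the 7 steps reported above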