import math
import random
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
# Define the city class.
class City:
    """A named city located at planar coordinates (x, y).

    Instances compare by identity (no ``__eq__`` is defined).
    """

    def __init__(self, name, x, y):
        """Create a city.

        Args:
            name: Identifier used as the key in distance tables and tour lists.
            x: Horizontal coordinate.
            y: Vertical coordinate.
        """
        self.name = name
        self.x = x
        self.y = y

    def __repr__(self):
        return f"{type(self).__name__}({self.name!r}, {self.x!r}, {self.y!r})"
# Define the traveling-salesman problem class.
class TravelingSalesman:
    """A TSP instance: a list of cities plus their pairwise Euclidean distances."""

    def __init__(self, cities):
        """Store *cities* and precompute the distance table.

        Args:
            cities: Sequence of objects exposing ``name``, ``x`` and ``y``
                (e.g. ``City``). Names are used as distance-table keys, so they
                should be unique; a repeated name overwrites the earlier entry.
        """
        self.cities = cities
        self.distances = self.calculate_distances()

    def calculate_distances(self):
        """Return a nested dict where ``distances[a][b]`` is the Euclidean distance.

        Only pairs of distinct city objects are stored: ``distances[a][a]`` is
        absent, so looking up a city's distance to itself raises ``KeyError``.
        """
        distances = {}
        for origin in self.cities:
            row = distances[origin.name] = {}
            for target in self.cities:
                # Skip only the very same object (City defines no __eq__, so
                # the original `!=` was already an identity test).
                if origin is not target:
                    row[target.name] = math.hypot(origin.x - target.x, origin.y - target.y)
        return distances
# Define the MCTS node class.
class Node:
    """A node in the Monte Carlo search tree over partial tours."""

    def __init__(self, state, parent=None):
        """Create a node with no children and no recorded visits.

        Args:
            state: List of city names visited so far, in order (the partial tour).
            parent: Parent node, or ``None`` for the root.
        """
        self.state = state
        self.parent = parent
        self.children = []
        self.visits = 0  # number of simulation results pushed through this node
        self.value = 0   # running sum of those simulation results

    def __repr__(self):
        return (
            f"{type(self).__name__}(state={self.state!r}, "
            f"visits={self.visits}, value={self.value})"
        )
# Selection: descend from the given node to a leaf using UCB1.
def select(node, C=1.0):
    """Walk down the tree from *node* and return the leaf to expand next.

    At each level the child with the highest UCB1 score is taken. Node values
    accumulate tour *lengths* (lower is better for the TSP), so the
    exploitation term is the negated mean length. An unvisited child scores
    ``+inf`` so it is tried before any visited sibling.

    Args:
        node: Node to start from (usually the root).
        C: Exploration weight; larger values favour less-visited children.

    Returns:
        The first node reached that has no children (possibly *node* itself).
    """

    def ucb(child, log_parent_visits):
        # Untried children first; also avoids division by zero below.
        if child.visits == 0:
            return math.inf
        mean_length = child.value / child.visits
        exploration = C * math.sqrt(2.0 * log_parent_visits / child.visits)
        return -mean_length + exploration

    while node.children:
        # log(0) is undefined; treat an unvisited parent as visited once.
        log_n = math.log(max(node.visits, 1))
        node = max(node.children, key=lambda child, log_n=log_n: ucb(child, log_n))
    return node
# Expansion: add one new child city to a node.
def expand(node, available_cities):
    """Attach a child to *node* that extends its tour by one random city.

    The chosen city is removed from *available_cities* in place, so the
    caller's list shrinks by one on every call.

    Args:
        node: Node whose partial tour ``node.state`` is extended.
        available_cities: Mutable list of city names that may still be added.
            Names already on ``node.state`` are never chosen.

    Returns:
        The newly created child ``Node``.

    Raises:
        IndexError: If no city in *available_cities* can extend the tour.
    """
    visited = set(node.state)
    # Filter out cities already on this path so a tour never repeats a city.
    candidates = [name for name in available_cities if name not in visited]
    city = random.choice(candidates)  # raises IndexError when candidates is empty
    available_cities.remove(city)
    child = Node(node.state + [city], parent=node)
    node.children.append(child)
    return child
# Simulation (rollout): complete the tour at random and measure its length.
def simulate(node, traveling_salesman):
    """Return the length of one random closed tour that extends ``node.state``.

    The cities already fixed on the node's path keep their order. The
    remaining cities of *traveling_salesman* are appended in random order, and
    the tour is closed back to its first city. ``node.state`` is not modified.

    Args:
        node: Node whose ``state`` (list of city names) is the tour prefix.
        traveling_salesman: Object exposing ``cities`` (items with ``name``)
            and ``distances`` (nested dict, ``distances[a][b]``).

    Returns:
        Total closed-tour length; ``0.0`` when the tour has fewer than two cities.
    """
    visited = set(node.state)
    remaining = [c.name for c in traveling_salesman.cities if c.name not in visited]
    random.shuffle(remaining)  # only the undecided suffix is randomized
    tour = list(node.state) + remaining
    if len(tour) < 2:
        return 0.0
    dist = traveling_salesman.distances
    # Pair each city with its successor, wrapping the last back to the first.
    return sum(dist[a][b] for a, b in zip(tour, tour[1:] + tour[:1]))
# Backpropagation: push a simulation result up to the root.
def backpropagate(node, value):
    """Add one visit and *value* to *node* and to every ancestor up to the root.

    Args:
        node: Node where the simulation started.
        value: Simulation result to accumulate (in this script, the tour
            length returned by ``simulate``).
    """
    current = node
    # Explicit None check: the walk ends at the root, whose parent is None.
    while current is not None:
        current.visits += 1
        current.value += value
        current = current.parent
# Visualize the search tree.
def visualize_tree(root):
    """Draw the search tree rooted at *root* with networkx and show it via matplotlib.

    Nodes are numbered in breadth-first order (root = 0) and labelled with
    their comma-separated partial tour.
    """
    graph = nx.DiGraph()

    def label(node):
        return ", ".join(map(str, node.state))

    ids = {root: 0}
    graph.add_node(0, label=label(root))
    queue = deque([root])  # deque: O(1) popleft, unlike list.pop(0)
    while queue:
        node = queue.popleft()
        for child in node.children:
            child_id = len(ids)
            ids[child] = child_id
            graph.add_node(child_id, label=label(child))
            graph.add_edge(ids[node], child_id)
            queue.append(child)

    pos = nx.spring_layout(graph, seed=42)  # fixed seed for a reproducible layout
    labels = nx.get_node_attributes(graph, 'label')
    nx.draw(graph, pos, labels=labels, with_labels=True, node_size=2000, node_color='lightblue', font_size=10)
    plt.show()
if __name__ == "__main__":
    cities = [
        City("A", 0, 0),
        City("B", 1, 2),
        City("C", 3, 1),
        City("D", 2, 4),
        City("E", 4, 3),
    ]
    iterations = 10
    initial_state = ["A"]  # the tour starts from city A
    root = Node(initial_state)
    # Cities still to be placed in the tree; expand() removes from this shared list.
    available_cities = [city.name for city in cities if city.name != initial_state[0]]
    traveling_salesman = TravelingSalesman(cities)
    # NOTE(review): because available_cities is shared by the whole tree, at most
    # len(cities) - 1 expansions happen and each one extends the single current
    # leaf, so the tree is one chain. A branching search would need a per-node
    # list of untried cities.
    for _ in range(iterations):
        if not available_cities:
            break  # nothing left to expand; further iterations would be no-ops
        node_to_expand = select(root)
        expanded_node = expand(node_to_expand, available_cities)
        value = simulate(expanded_node, traveling_salesman)
        backpropagate(expanded_node, value)
    visualize_tree(root)  # visualize the search tree
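# 旅行商问题-MCTS (Traveling Salesman Problem solved with Monte Carlo Tree Search)
# 最新推荐文章于 2024-10-20 19:52:39 发布 (source-page note: latest recommended article published 2024-10-20 19:52:39)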