Reset()：重新随机生成城市坐标，并重置路径、当前位置、完成标志和累计距离。
reset()：保留现有城市坐标，仅重置路径、当前位置、完成标志和累计距离。
这里的 render() 参考了之前读过的一篇文章（具体出处已记不清，当时未收藏），使用 matplotlib 实现。
import warnings
import numpy as np
import matplotlib.pyplot as plt
class TSPEnvironment:
    """Travelling-salesman environment over randomly placed cities.

    Cities are sampled with ``np.random.rand`` (uniform in ``[0, 1)``) and
    scaled by ``box_size``. An episode starts at one city; each ``step``
    moves to an unvisited city and yields the negative Euclidean distance
    travelled as reward. Once every city has been visited the episode is
    done, and the closing leg back to the start city is added to the final
    reward and to the total distance.

    ``Reset()`` regenerates the city coordinates; ``reset()`` keeps them and
    only restarts the tour.

    ``Reset()``, ``reset()`` and ``step()`` all return
    ``(state, reward, done)`` where ``state`` is ``(coordinates, path, valid)``:

    * ``coordinates`` -- copy of the ``(num_cities, coordinate_dimension)``
      coordinate array,
    * ``path`` -- copy of the list of visited city indices, in visiting order,
    * ``valid`` -- ``(remaining_coordinates, remaining_indices)`` for the
      cities not visited yet (see ``get_valid_cities``).
    """

    def __init__(self, num_cities, coordinate_dimension=2, box_size=1.0):
        """Create the environment and generate a first random instance.

        Args:
            num_cities: number of cities; must be >= 1.
            coordinate_dimension: dimensionality of each city coordinate;
                must be >= 2 (``render`` draws the first two axes).
            box_size: scale factor applied to the ``[0, 1)`` samples.
        """
        assert coordinate_dimension >= 2, "coordinate_dimension must >= 2 !"
        assert num_cities >= 1, "num_cities must >= 1 !"
        self.num_cities = num_cities
        self.coordinate_dimension = coordinate_dimension
        self.box_size = box_size
        self.coordinates = None          # (num_cities, coordinate_dimension) array
        self.cities_coordinates = None   # city index -> coordinate row
        self.path = None                 # visited city indices, in order
        self.now_location = None         # index of the current city
        self.done = False
        self.total_distance = 0.0
        self.Reset()

    def reset(self, start_city=None):
        """Restart the tour on the current cities (coordinates are kept).

        Args:
            start_city: index of the first city, in ``[0, num_cities)``;
                chosen uniformly at random when ``None``.

        Returns:
            ``((coordinates, path, valid), 0.0, done)``.
        """
        self._check_start_city(start_city)
        if start_city is None:
            start_city = int(np.random.choice(self.num_cities))
        self.now_location = start_city
        self.path = [self.now_location]
        self.total_distance = 0.0
        # A single-city instance is already a complete tour.
        self.done = self._all_visited()
        return self._observation(), 0.0, self.done

    def Reset(self, start_city=None):
        """Generate new random city coordinates, then restart the tour.

        Args:
            start_city: index of the first city, in ``[0, num_cities)``;
                chosen uniformly at random when ``None``.

        Returns:
            ``((coordinates, path, valid), 0.0, done)``.
        """
        # Validate first so a bad argument leaves the current instance intact.
        self._check_start_city(start_city)
        self.coordinates = (
            np.random.rand(self.num_cities, self.coordinate_dimension) * self.box_size
        )
        self.cities_coordinates = dict(enumerate(self.coordinates))
        return self.reset(start_city)

    def step(self, action: int):
        """Move from the current city to the unvisited city ``action``.

        Args:
            action: index of the next city; must be in ``[0, num_cities)``
                and not visited yet.

        Returns:
            ``((coordinates, path, valid), reward, done)`` where ``reward`` is
            minus the distance travelled (including the closing leg back to
            the start city on the final step). If the episode is already
            done, a warning is issued, nothing moves and the reward is 0.0.
        """
        if self.done:
            warn_msg = "The environment {} is done, please call Reset()/reset() or create new environment!".format(self)
            warnings.warn(warn_msg)
            # Keep the documented (state, reward, done) shape so callers can
            # always unpack three values.
            return self._observation(), 0.0, self.done

        assert self.coordinates is not None, "No coordinates, please call 'Reset()' first!"
        assert self.cities_coordinates is not None, "No cities_coordinates, please call 'Reset()' first!"
        assert self.path is not None, "No path, please call 'Reset()/reset()' first!"
        assert self.now_location is not None, "No now_location, please call 'Reset()/reset()' first!"
        next_city = action
        assert next_city < self.num_cities and next_city >= 0, "There is no city: {} !\n\t\tValid cities: {}".format(
            next_city, set(self.cities_coordinates.keys())
        )
        assert next_city not in self.path, "Wrong next city: {}, Can not be repeated access: {} !\n\t\tValid cities: {}.".format(
            next_city, set(self.path), set(self.cities_coordinates.keys()) - set(self.path)
        )

        distance = self._distance(next_city, self.now_location)
        reward = -distance
        self.total_distance += distance
        self.path.append(next_city)
        self.now_location = next_city

        self.done = self._all_visited()
        if self.done:
            # Close the tour: add the leg from the last city back to the start.
            closing = self._distance(self.path[0], self.path[-1])
            reward -= closing
            self.total_distance += closing
        return self._observation(), reward, self.done

    @staticmethod
    def euclidian_distance(x, y):
        """Euclidean distance along the last axis, keeping that axis (size 1)."""
        return np.sqrt(np.sum((x - y) ** 2, axis=-1, keepdims=True))

    def render(self):
        """Plot the cities and the current tour with matplotlib (first two axes).

        Cities are red stars, the start city blue, visited cities green and,
        once the episode is done, the end city black; the tour is drawn as a
        dashed orange line, closed back to the start when done.
        """
        assert self.coordinates is not None, "No coordinates, please call reset() first!"
        if self.coordinate_dimension != 2:
            warnings.warn("Only show the first two dimension!")
        # Offset labels relative to the box so they stay next to their marker
        # for any box_size (0.1 for the default box_size=1.0).
        label_dx = 0.1 * self.box_size

        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)
        ax.set_title("TSP environment")
        ax.scatter(self.coordinates[:, 0], self.coordinates[:, 1], c="red", s=50, marker="*")

        start_city = self.cities_coordinates[self.path[0]]
        ax.annotate("start city", xy=start_city[[0, 1]],
                    xytext=(start_city[0] + label_dx, start_city[1]), weight="bold")
        ax.scatter(start_city[0], start_city[1], c="blue", marker="*", s=50)
        ax.plot(self.coordinates[self.path, 0], self.coordinates[self.path, 1],
                c="orange", linewidth=1, linestyle="--")
        ax.scatter(self.coordinates[self.path[1:], 0], self.coordinates[self.path[1:], 1],
                   c="green", s=50, marker="*")

        if self.done:
            end_city = self.cities_coordinates[self.path[-1]]
            ax.annotate("end city", xy=end_city[[0, 1]],
                        xytext=(end_city[0] + label_dx, end_city[1]), weight="bold")
            ax.scatter(end_city[0], end_city[1], c="black", s=50, marker="*")
            ax.plot([start_city[0], end_city[0]], [start_city[1], end_city[1]],
                    c="orange", linewidth=1, linestyle="--")

        ax.set_xticks([])
        ax.set_yticks([])
        plt.show()

    def get_total_distance(self):
        """Distance travelled so far (includes the closing leg once done)."""
        return self.total_distance

    @staticmethod
    def get_valid_cities(path, coordinates):
        """Return ``(coordinates_of_unvisited, indices_of_unvisited)``.

        Args:
            path: visited city indices (rows of ``coordinates``).
            coordinates: array whose first axis indexes cities.
        """
        visited = set(path)  # O(1) membership instead of scanning the list
        return (
            np.delete(coordinates, path, axis=0),
            [i for i in range(coordinates.shape[0]) if i not in visited],
        )

    def _check_start_city(self, start_city):
        """Assert that an explicit start city index is in range (None = random)."""
        if start_city is not None:
            assert start_city < self.num_cities, "Start city must < num of city !!!"
            assert start_city >= 0, "Start city must >= 0 !!!"

    def _all_visited(self):
        """True once every city index appears in the path."""
        return set(self.path) == set(self.cities_coordinates)

    def _distance(self, city_a, city_b):
        """Scalar Euclidean distance between two cities given by index."""
        return self.euclidian_distance(
            self.cities_coordinates[city_a], self.cities_coordinates[city_b]
        )[0]

    def _observation(self):
        """Build the state tuple ``(coordinates, path, valid)`` from copies."""
        valid = self.get_valid_cities(self.path, self.coordinates)
        return self.coordinates.copy(), list(self.path), valid