【强化学习环境】TSP问题强化学习环境搭建

13 篇文章 0 订阅
12 篇文章 0 订阅

Reset():完全重置环境,包括重新随机生成所有城市的坐标点。

reset():只重置当前回合(起点、路径、累计距离),不重新生成城市坐标点。

这里的 render() 使用 matplotlib 实现,参考了之前看过的一篇文章(具体是哪一篇已经记不清了,当时没有收藏)。

import warnings
import numpy as np
import matplotlib.pyplot as plt

class TSPEnvironment:
    """Travelling-salesman environment with a reset/step style interface.

    Cities are points drawn uniformly at random inside a hyper-cube of side
    ``box_size``.  An episode starts at one city; every ``step`` moves to a
    not-yet-visited city and is rewarded with the negative Euclidean distance
    travelled.  When the last city is reached, the edge back to the start city
    is added to that final reward and to the total distance.

    ``reset()``, ``Reset()`` and ``step()`` all return ``(state, reward, done)``
    where ``state`` is ``(coordinates, path, valid)``:

    * ``coordinates`` -- copy of the ``(num_cities, coordinate_dimension)`` array,
    * ``path`` -- copy of the list of visited city indices, in visiting order,
    * ``valid`` -- ``(coordinates of unvisited cities, indices of unvisited cities)``.

    Args:
        num_cities: number of cities to generate.
        coordinate_dimension: dimensionality of each city coordinate (>= 2).
        box_size: side length of the box the coordinates are sampled from.
    """

    def __init__(self, num_cities, coordinate_dimension=2, box_size=1.0):
        assert coordinate_dimension >= 2, "coordinate_dimension must >= 2 !"
        self.num_cities = num_cities
        self.coordinate_dimension = coordinate_dimension
        self.box_size = box_size
        self.coordinates = None
        self.cities_coordinates = None
        self.path = None
        self.now_location = None
        self.done = False
        self.total_distance = 0.0
        # Generate the first set of cities so the environment is usable immediately.
        self.Reset()

    def reset(self, start_city=None):
        """Start a new episode on the SAME cities.

        Args:
            start_city: index of the first city; chosen at random when ``None``.

        Returns:
            ``((coordinates, path, valid), 0.0, done)``.
        """
        return self._start_episode(start_city)

    def Reset(self, start_city=None):
        """Sample NEW city coordinates and start a new episode.

        Args:
            start_city: index of the first city; chosen at random when ``None``.

        Returns:
            ``((coordinates, path, valid), 0.0, done)``.
        """
        # Validate first so a bad argument does not throw away the current cities.
        self._check_start_city(start_city)
        self.coordinates = np.random.rand(self.num_cities, self.coordinate_dimension) * self.box_size
        self.cities_coordinates = dict(enumerate(self.coordinates))
        return self._start_episode(start_city)

    def _check_start_city(self, start_city):
        """Assert that ``start_city`` is ``None`` or a valid city index."""
        if start_city is not None:
            # The lower bound matters: a negative index would pass a plain
            # "< num_cities" check and then fail later on the dict lookup.
            assert 0 <= start_city < self.num_cities, "Start city must < num of city !!!"

    def _start_episode(self, start_city):
        """Reset the per-episode bookkeeping and return the initial transition."""
        self._check_start_city(start_city)
        if start_city is None:
            start_city = int(np.random.choice(self.num_cities))
        self.now_location = start_city
        self.path = [start_city]
        self.done = False
        self.total_distance = 0.0
        return self._observation(), 0.0, self.done

    def _observation(self):
        """Build the ``(coordinates, path, valid)`` state from copies of internal data."""
        valid = self.get_valid_cities(self.path, self.coordinates)
        return self.coordinates.copy(), list(self.path), valid

    def step(self, action: int):
        """Move to city ``action`` and return ``(state, reward, done)``.

        The reward is the negative distance from the current city to ``action``;
        on the step that visits the last city the distance back to the start
        city is included as well.

        Calling ``step`` on a finished episode emits a warning, changes nothing
        and returns the current state with a reward of ``0.0``.

        Raises:
            AssertionError: if the environment is uninitialised, or ``action``
                is out of range or already visited.
        """
        if self.done:
            warn_msg = "The environment {} is done, please call Reset()/reset() or create new environment!".format(self)
            warnings.warn(warn_msg)
            # Keep the documented (state, reward, done) shape so callers that
            # unpack three values keep working after the episode has ended.
            return self._observation(), 0.0, self.done

        assert self.coordinates is not None, "No coordinates, please call 'Reset()' first!"
        assert self.cities_coordinates is not None, "No cities_coordinates, please call 'Reset()' first!"
        assert self.path is not None, "No path, please call 'Reset()/reset()' first!"
        assert self.now_location is not None, "No now_location, please call 'Reset()/reset()' first!"

        next_city = action

        assert 0 <= next_city < self.num_cities, "There is no city: {} !\n\t\tValid cities: {}".format(
            next_city, set(self.cities_coordinates.keys())
        )
        assert next_city not in self.path, "Wrong next city: {}, Can not be repeated access: {} !\n\t\tValid cities: {}.".format(
            next_city, set(self.path), set(self.cities_coordinates.keys()) - set(self.path)
        )

        # euclidian_distance keeps the reduced axis, hence the [0] to get a scalar.
        distance = self.euclidian_distance(
            self.cities_coordinates[next_city], self.cities_coordinates[self.now_location]
        )[0]
        reward = -distance
        self.total_distance += distance
        self.path.append(next_city)
        self.now_location = next_city

        # The path never holds duplicates, so a full-length path means every city was visited.
        if len(self.path) == len(self.cities_coordinates):
            self.done = True
            # Close the tour: add the edge from the last city back to the start.
            closing = self.euclidian_distance(
                self.cities_coordinates[self.path[0]], self.cities_coordinates[self.path[-1]]
            )[0]
            reward -= closing
            self.total_distance += closing

        return self._observation(), reward, self.done

    @staticmethod
    def euclidian_distance(x, y):
        """Return the Euclidean distance between ``x`` and ``y`` along the last axis.

        The reduced axis is kept, so two 1-D points yield an array of shape ``(1,)``.
        """
        return np.sqrt(np.sum((x - y) ** 2, axis=-1, keepdims=True))

    def render(self):
        """Plot the cities and the current path with matplotlib.

        Only the first two coordinate dimensions are drawn.  Unvisited cities
        are red, the start city blue, visited cities green, and the path is a
        dashed orange line.  When the episode is done the end city is drawn in
        black and the closing edge back to the start is added.
        """
        assert self.coordinates is not None, "No coordinates, please call reset() first!"
        if self.coordinate_dimension != 2:
            warnings.warn("Only show the first two dimension!")
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)
        ax.set_title("TSP environment")
        ax.scatter(self.coordinates[:, 0], self.coordinates[:, 1], c="red", s=50, marker="*")

        # plot start city as color blue
        start_city = self.cities_coordinates[self.path[0]]
        text = start_city[0] + 0.1, start_city[1]
        ax.annotate("start city", xy=start_city[[0, 1]], xytext=text, weight="bold")
        ax.scatter(start_city[0], start_city[1], c="blue", marker="*", s=50)

        # plot path as color orange, access cities as color green
        ax.plot(self.coordinates[self.path, 0], self.coordinates[self.path, 1], c="orange", linewidth=1, linestyle="--")
        ax.scatter(self.coordinates[self.path[1:], 0], self.coordinates[self.path[1:], 1], c="green", s=50, marker="*")

        if self.done:
            end_city = self.cities_coordinates[self.path[-1]]
            text = end_city[0] + 0.1, end_city[1]
            ax.annotate("end city", xy=end_city[[0, 1]], xytext=text, weight="bold")
            ax.scatter(end_city[0], end_city[1], c="black", s=50, marker="*")
            ax.plot([start_city[0], end_city[0]], [start_city[1], end_city[1]], c="orange", linewidth=1, linestyle="--")
        plt.xticks([])
        plt.yticks([])
        plt.show()

    def get_total_distance(self):
        """Return the distance travelled so far in the current episode."""
        return self.total_distance

    @staticmethod
    def get_valid_cities(path, coordinates):
        """Return the cities that have not been visited yet.

        Args:
            path: indices of the cities already visited.
            coordinates: array whose rows are the city coordinates.

        Returns:
            ``(coordinates of unvisited cities, indices of unvisited cities)``.
        """
        visited = set(path)  # O(1) membership instead of scanning the list per city
        return (
            np.delete(coordinates, path, axis=0),
            [i for i in range(coordinates.shape[0]) if i not in visited]
        )


  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值