# Reinforcement learning with PPO: 0/1 knapsack problem with a limit on the number of items taken

import gymnasium as gym  # 导入gym
from gymnasium import Env
from gymnasium.spaces import Discrete, Box, Dict, Tuple, MultiBinary, MultiDiscrete
import numpy as np
import random
import os
from stable_baselines3 import PPO, DQN ,A2C
from stable_baselines3.common.vec_env import VecFrameStack  # 堆叠操作,提高训练效率
from stable_baselines3.common.evaluation import evaluate_policy

model_name = "PPO3"  # file name (without extension) used to save/load the trained model below
# 定义环境
class KnapsackEnv(Env):
    """Gymnasium environment for the 0/1 knapsack problem with a cap on
    the number of items that may be taken.

    Items are offered one at a time in index order; at each step the agent
    decides to take (action 1) or skip (action 0) the current item.  The
    episode ends when every item has been offered or the number of taken
    items reaches ``limit``.

    Observation (dict):
        "select_id": shape (1,) -- index of the item currently offered
        "selected":  shape (n_items,) -- 0/1 mask of items taken so far
    """

    def __init__(self, weights, values, capacity, limit):
        """
        Args:
            weights:  per-item weights.
            values:   per-item values (same length as ``weights``).
            capacity: knapsack weight capacity.
            limit:    maximum number of items that may be taken.
        """
        super().__init__()
        self.weights = weights
        self.values = values
        self.capacity = capacity
        self.n_items = len(weights)
        self.limit_num = limit

        self.current_weight = 0   # total weight of items taken so far
        self.current_value = 0    # total value of items taken so far
        self.current_index = 0

        self.action_space = Discrete(2)  # 0: skip, 1: take
        self.observation_space = Dict({
            "select_id": Box(0, self.n_items, (1,), np.int64),
            "selected": Box(0, 1, (self.n_items,), np.int64),
        })

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Fix: the Gymnasium API requires ``seed``/``options`` to default to
        ``None`` and ``super().reset(seed=seed)`` to be called (the original
        used ``""`` defaults and never seeded the RNG).

        Returns:
            (observation, info) per the Gymnasium API.
        """
        super().reset(seed=seed)
        self.current_weight = 0
        self.current_value = 0
        self.current_index = 0
        self.state = {"select_id": np.zeros(1, dtype=int),
                      "selected": np.zeros(self.n_items, dtype=int)}
        return self.state, {}

    def step(self, action):
        """Apply ``action`` to the currently offered item and advance.

        Args:
            action: 0 to skip the current item, 1 to take it.

        Returns:
            (observation, reward, terminated, truncated, info).
        """
        reward = 0
        done = False
        truncated = False
        select_id = self.state["select_id"][0]

        if action == 1:
            if self.current_weight + self.weights[select_id] > self.capacity:
                # Over capacity: the item is NOT taken, so it must not be
                # marked as selected.  (Bug fix: the original set
                # selected[select_id] = 1 here, so a rejected item wrongly
                # counted toward limit_num and could end the episode early.)
                reward = -100
            else:
                self.state["selected"][select_id] = 1
                self.current_weight += self.weights[select_id]
                self.current_value += self.values[select_id]

        select_id += 1
        # Terminate when all items were offered or the take limit is hit.
        # A terminal reward proportional to total value overrides any
        # step penalty issued above (as in the original).
        if select_id == self.n_items or np.sum(self.state["selected"]) == self.limit_num:
            reward = self.current_value * 100
            done = True

        self.state["select_id"][0] = select_id
        return self.state, reward, done, truncated, {}

    def render(self, mode='human'):
        # No visualization; required by the Env interface.
        pass

# --- training -------------------------------------------------------------
# 4 items: weights [1, 4, 6, 4], values [3, 4, 10, 10], capacity 9, at most
# 2 items may be taken (optimal: items 2 and 3, value 20).
env = KnapsackEnv([1, 4, 6, 4], [3, 4, 10, 10], 9, 2)

model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=200000)
model.save(model_name)

# --- evaluation -----------------------------------------------------------
model = PPO.load(model_name)
episodes = 20
for episode in range(1, episodes + 1):
    obs, _ = env.reset()
    done = False
    truncated = False
    step = 0
    # Fix: also stop on truncation — Gymnasium episodes may end with
    # truncated=True while done stays False, which would loop forever here.
    # (Also dropped the unused `score` local from the original.)
    while not (done or truncated):
        env.render()
        action, _ = model.predict(obs)
        obs, reward, done, truncated, info = env.step(action)
        step += 1
    print('Episode:{} Score:{} Step:{}'.format(episode, env.current_value, step))

# (Removed: CSDN blog-platform footer — like/collect counters and donation
#  widget text scraped along with the article; not part of the source code.)