MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.py

MAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py

基本介绍

在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。

因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。

MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。

源码链接

https://github.com/dragen1860/MAML-Pytorch-RL

文件路径

./maml_rl/sampler.py

import

import  gym
import  torch
import  multiprocessing as mp

from    maml_rl.envs.subproc_vec_env import SubprocVecEnv
from    maml_rl.episode import BatchEpisodes

make_env() 函数

#### 对gym库的.make(env_name)做了一个简单的包装。
def make_env(env_name):
	"""
	return a function
	:param env_name:
	:return:
	"""
	def _make_env():
		return gym.make(env_name)

	return _make_env

BatchSampler()

class BatchSampler:

	def __init__(self, env_name, batch_size, num_workers=mp.cpu_count()):
		"""

		:param env_name:
		:param batch_size: fast batch size
		:param num_workers:
		"""
        
        #### 将环境名字self.env_name、一批任务的数量self.batch_size和能参与工作的cpu数量self.num_workers初始化,初始化多线程队列。
		self.env_name = env_name
		self.batch_size = batch_size
		self.num_workers = num_workers
		self.queue = mp.Queue()
        
        #### 对于self.num_workers数量做迭代,也就是为每个线程开辟一个环境。最后通过列表的形式存储到env_factorys变量中。
		# [lambda function]
		env_factorys = [make_env(env_name) for _ in range(num_workers)]
        
        #### 创建父进程,用于管理self.num_workers数量的线程。最后再创建一个环境,应该是用于内环更新结束后的用于测试的环境。
		# this is the main process manager, and it will be in charge of num_workers sub-processes interacting with environment.
		self.envs = SubprocVecEnv(env_factorys, queue_=self.queue)
		self._env = gym.make(env_name)

	def sample(self, policy, params=None, gamma=0.95, device='cpu'):
		"""

		:param policy:
		:param params:
		:param gamma:
		:param device:
		:return:
		"""
        
        #### 创建一个批处理实例。现在队列中加入批任务大小的数字,然后再加入self.num_workers数量的None这样应该是做一个标志。
		episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device)
		for i in range(self.batch_size):
			self.queue.put(i)
		for _ in range(self.num_workers):
			self.queue.put(None)

        #### 对所有环境做初始化命令。得到每个子线程的观测和任务号。标记done"是否完成"为否。
		observations, batch_ids = self.envs.reset()
		dones = [False]
        
        #### 如果所有队列都没有完成"not all(dones)"且队列没有空,就说明还有队列。
		while (not all(dones)) or (not self.queue.empty()): # if all done and queue is empty
			# for reinforcement learning, the forward process requires no-gradient
            
            #### 接下来做的是强化学习执行任务过程。因为这本身是输出结果,是前馈过程,那么就不需要导数。
			with torch.no_grad():
				# convert observation to cuda
				# compute policy on cuda
				# convert action to cpu
                
                #### 经典强化学习过程。先得到观测向量,然后获得动作张量,再转成动作array。
				observations_tensor = torch.from_numpy(observations).to(device=device)
				# forward via policy network
				# policy network will return Categorical(logits=logits)
				actions_tensor = policy(observations_tensor, params=params).sample()
				actions = actions_tensor.cpu().numpy()

            #### 最后执行step()函数,得到新的观测、奖励、是否完成信息以及新的批任务号。最后将这些加入episodes的经验池子中。最后做一个更新。
			new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions)
			# here is observations NOT new_observations, batch_ids NOT new_batch_ids
			episodes.append(observations, actions, rewards, batch_ids)
			observations, batch_ids = new_observations, new_batch_ids

		return episodes
    
	#### 重置任务进行新的回合。
	def reset_task(self, task):
		tasks = [task for _ in range(self.num_workers)]
		reset = self.envs.reset_task(tasks)
		return all(reset)

    #### 通过各种分布获得一批任务。
	def sample_tasks(self, num_tasks):
		tasks = self._env.unwrapped.sample_tasks(num_tasks)
		return tasks
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ctrl+Alt+L

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值