MAML-RL Pytorch 代码解读 (17) -- maml_rl/metalearner.py

MAML-RL Pytorch 代码解读 (17) – maml_rl/metalearner.py

基本介绍

在网上看到的元学习 MAML 的代码大多是跟图像相关的,强化学习这边的代码比较少。

因为自己的思路跟 MAML-RL 相关,所以打算读一些源码。

MAML 的原始代码是基于 tensorflow 的,在 Github 上找到了基于 Pytorch 源码包,学习这个包。

源码链接

https://github.com/dragen1860/MAML-Pytorch-RL

文件路径

./maml_rl/sampler.py

import

import  gym
import  torch
import  multiprocessing as mp

from    maml_rl.envs.subproc_vec_env import SubprocVecEnv
from    maml_rl.episode import BatchEpisodes

make_env() 函数

#### 对gym库的.make(env_name)做了一个简单的包装。
def make_env(env_name):
	"""
	return a function
	:param env_name:
	:return:
	"""
	def _make_env():
		return gym.make(env_name)

	return _make_env

BatchSampler()

class BatchSampler:

	def __init__(self, env_name, batch_size, num_workers=mp.cpu_count()):
		"""

		:param env_name:
		:param batch_size: fast batch size
		:param num_workers:
		"""
        
        #### 将环境名字self.env_name、一批任务的数量self.batch_size和能参与工作的cpu数量self.num_workers初始化,初始化多线程队列。
		self.env_name = env_name
		self.batch_size = batch_size
		self.num_workers = num_workers
		self.queue = mp.Queue()
        
        #### 对于self.num_workers数量做迭代,也就是为每个线程开辟一个环境。最后通过列表的形式存储到env_factorys变量中。
		# [lambda function]
		env_factorys = [make_env(env_name) for _ in range(num_workers)]
        
        #### 创建父进程,用于管理self.num_workers数量的线程。最后再创建一个环境,应该是用于内环更新结束后的用于测试的环境。
		# this is the main process manager, and it will be in charge of num_workers sub-processes interacting with environment.
		self.envs = SubprocVecEnv(env_factorys, queue_=self.queue)
		self._env = gym.make(env_name)

	def sample(self, policy, params=None, gamma=0.95, device='cpu'):
		"""

		:param policy:
		:param params:
		:param gamma:
		:param device:
		:return:
		"""
        
        #### 创建一个批处理实例。现在队列中加入批任务大小的数字,然后再加入self.num_workers数量的None这样应该是做一个标志。
		episodes = BatchEpisodes(batch_size=self.batch_size, gamma=gamma, device=device)
		for i in range(self.batch_size):
			self.queue.put(i)
		for _ in range(self.num_workers):
			self.queue.put(None)

        #### 对所有环境做初始化命令。得到每个子线程的观测和任务号。标记done"是否完成"为否。
		observations, batch_ids = self.envs.reset()
		dones = [False]
        
        #### 如果所有队列都没有完成"not all(dones)"且队列没有空,就说明还有队列。
		while (not all(dones)) or (not self.queue.empty()): # if all done and queue is empty
			# for reinforcement learning, the forward process requires no-gradient
            
            #### 接下来做的是强化学习执行任务过程。因为这本身是输出结果,是前馈过程,那么就不需要导数。
			with torch.no_grad():
				# convert observation to cuda
				# compute policy on cuda
				# convert action to cpu
                
                #### 经典强化学习过程。先得到观测向量,然后获得动作张量,再转成动作array。
				observations_tensor = torch.from_numpy(observations).to(device=device)
				# forward via policy network
				# policy network will return Categorical(logits=logits)
				actions_tensor = policy(observations_tensor, params=params).sample()
				actions = actions_tensor.cpu().numpy()

            #### 最后执行step()函数,得到新的观测、奖励、是否完成信息以及新的批任务号。最后将这些加入episodes的经验池子中。最后做一个更新。
			new_observations, rewards, dones, new_batch_ids, _ = self.envs.step(actions)
			# here is observations NOT new_observations, batch_ids NOT new_batch_ids
			episodes.append(observations, actions, rewards, batch_ids)
			observations, batch_ids = new_observations, new_batch_ids

		return episodes
    
	#### 重置任务进行新的回合。
	def reset_task(self, task):
		tasks = [task for _ in range(self.num_workers)]
		reset = self.envs.reset_task(tasks)
		return all(reset)

    #### 通过各种分布获得一批任务。
	def sample_tasks(self, num_tasks):
		tasks = self._env.unwrapped.sample_tasks(num_tasks)
		return tasks
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
以下是使用PyTorch实现的MAML元学习的示例代码: ```python import torch import torch.nn as nn import torch.optim as optim class MAML(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(MAML, self).__init__() self.input_size = input_size self.hidden_size = hidden_size self.output_size = output_size self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU() self.fc2 = nn.Linear(hidden_size, output_size) def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x def clone(self, device=None): clone = MAML(self.input_size, self.hidden_size, self.output_size) if device is not None: clone.to(device) clone.load_state_dict(self.state_dict()) return clone class MetaLearner(nn.Module): def __init__(self, model, lr): super(MetaLearner, self).__init__() self.model = model self.optimizer = optim.Adam(self.model.parameters(), lr=lr) def forward(self, x): return self.model(x) def meta_update(self, task_gradients): for param, gradient in zip(self.model.parameters(), task_gradients): param.grad = gradient self.optimizer.step() self.optimizer.zero_grad() def train_task(model, data_loader, lr_inner, num_updates_inner): model.train() task_loss = 0.0 for i, (input, target) in enumerate(data_loader): input = input.to(device) target = target.to(device) clone = model.clone(device) meta_optimizer = MetaLearner(clone, lr_inner) for j in range(num_updates_inner): output = clone(input) loss = nn.functional.mse_loss(output, target) grad = torch.autograd.grad(loss, clone.parameters(), create_graph=True) fast_weights = [param - lr_inner * g for param, g in zip(clone.parameters(), grad)] clone.load_state_dict({name: param for name, param in zip(clone.state_dict(), fast_weights)}) output = clone(input) loss = nn.functional.mse_loss(output, target) task_loss += loss.item() grad = torch.autograd.grad(loss, model.parameters()) task_gradients = [-lr_inner * g for g in grad] meta_optimizer.meta_update(task_gradients) return task_loss / len(data_loader) # Example usage device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') input_size = 1 hidden_size = 20 output_size = 1 model = MAML(input_size, hidden_size, output_size) model.to(device) data_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(torch.randn(100, input_size), torch.randn(100, output_size)), batch_size=10, shuffle=True) meta_optimizer = MetaLearner(model, lr=0.001) for i in range(100): task_loss = train_task(model, data_loader, lr_inner=0.01, num_updates_inner=5) print('Task loss:', task_loss) meta_optimizer.zero_grad() task_gradients = torch.autograd.grad(task_loss, model.parameters()) meta_optimizer.meta_update(task_gradients) ``` 在这个示例中,我们定义了两个类,MAML和MetaLearner。MAML是一个普通的神经网络,而MetaLearner包含了用于更新MAML的元优化器。在每个任务上,我们使用MAML的副本进行内部更新,然后使用元优化器来更新MAML的权重。在元学习的过程中,我们首先通过调用train_task函数来训练一个任务,然后通过调用meta_update函数来更新MAML的权重。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ctrl+Alt+L

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值