1. Distributed Reinforcement Learning using RPC and RRef
距离说明,如何通过RPC与RRef实现强化学习模型(CartPole-v1)。
主要实现的功能
展示如何通过RPC在多个workers见进行数据传输。
展示如何使用 RRef 来代表 remote objects。
通过 torch.distributed.rpc 可以得到原生支持以及优化。
模型代码如下
import torch.nn as nn
import torch.nn.functional as F
classPolicy(nn.Module):def__init__(self):super(Policy, self).__init__()
self.affine1 = nn.Linear(4,128)
self.dropout = nn.Dropout(p=0.6)
self.affine2 = nn.Linear(128,2)
self.saved_log_probs =[]
self.rewards =[]defforward(self, x):
x = self.affine1(x)
x = self.dropout(x)
x = F.relu(x)
action_scores = self.affine2(x)return F.softmax(action_scores, dim=1)
构建 helepr function 用来远程调用 RRef 中 owner worker 的函数。
from torch.distributed.rpc import rpc_sync
def_call_method(method, rref,*args,**kwargs):return method(rref.local_value(),*args,**kwargs)def_remote_method(method, rref,*args,**kwargs):
args =[method, rref]+list(args)return rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)# to call a function on an rref, we could do the following# _remote_method(some_func, rref, *args)
介绍observer。
每个observer创建自己的环境,等待agent来运行episode。
每个episode中,每个observer循环 n_steps 次。
每次循环使用RPC将环境状态传递给agent并得到反馈。
将反馈结果用于当前环境,得到reward以及下一个环境状态。
之后,observer利用另外一个rpc连接将reward传递给agent。
import argparse
import gym
import torch.distributed.rpc as rpc
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,)
parser.add_argument('--world_size', default=2,help='Number of workers')
parser.add_argument('--log_interval', default=1,help='Log every log_interval episodes')
parser.add_argument('--gamma', default=0.1,help='how much to value future rewards')
parser.add_argument('--seed', default=1,help='random seed for reproducibility')
args = parser.parse_args()classObserver:def__init__(self):
self.id= rpc.get_worker_info().id
self.env = gym.make('CartPole-v1')
self.env.seed(args.seed)defrun_episode(self, agent_rref, n_steps):
state, ep_reward = self.env.reset(),0for step inrange(n_steps):# send the state to the agent to get an action
action = _remote_method(Agent.select_action, agent_rref, self.id, state)# apply the action to the environment, and get the reward
state, reward, done, _ = self.env.step(action)# report the reward to the agent for training purpose
_remote_method(Agent.report_reward, agent_rref, self.id, reward)if done:break