import torch.multiprocessing as mp
from tools.arguments import get_args
from agent import MASTERAgent
from tools import utilslite
import gym
import time
import random
from collections import namedtuple
import torch.nn as nn
# Bundle of shared handles that main() builds from an mp.Manager and attaches
# to ``args.comm`` so every pool worker receives the same proxies.
Comms = namedtuple(
    'Comms',
    [
        'master_step',
        'master_step_lock',
        'master_num_steps',
        'master_net_lock',
        'master_queue',
        'comm_queue',
        'comm_flag',
        'comm_lock',
    ],
)
def changemaster(rank, args, max_iters=None):
    """Repeatedly overwrite one weight of the master network with ``rank``.

    Obtains the master agent via ``utilslite.process_comm(args, 'master')``
    (presumably the object stored by ``putmaster`` -- confirm) and attaches
    ``args.comm`` to it. Then, while holding ``master._net_lock``, it:
    - reads the current value of ``acnet.feats[0].weight[0, 0]``;
    - writes this worker's id into that weight;
    - prints the transition together with the tensor's device.

    Between iterations the worker sleeps for a random 0-4 s, outside the
    lock, so other workers can take their turn.

    Args:
        rank: Worker id; it is also the value written into the weight.
        args: Parsed arguments; ``args.comm`` holds the shared ``Comms``.
        max_iters: Number of write iterations to perform. ``None`` (the
            default) loops forever, preserving the original behaviour.
    """
    master = utilslite.process_comm(args, 'master')
    master.set_comm(args.comm)
    pid = rank
    iteration = 0
    while max_iters is None or iteration < max_iters:
        with master._net_lock:
            weight = master.acnet.feats[0].weight
            value = weight.data[0, 0].clone()
            weight.data[0, 0] = pid
            device = weight.device
            print("{}: master value: {} ----> {} ({})".format(pid, value, pid, device))
        # Never hold the lock while sleeping; that would serialize the workers.
        time.sleep(random.random() * 4)
        iteration += 1
def putmaster(rank, args):
    """Build the master agent on CPU and hand it to ``utilslite.process_comm``.

    A throw-away ``Acrobot-v1`` environment is created only to read the
    observation/action spaces the agent is constructed with. It is closed
    right afterwards so its resources are not leaked in the pool worker.

    Args:
        rank: Pool task index; unused, kept so all pool tasks share a signature.
        args: Parsed arguments forwarded to ``MASTERAgent`` and ``process_comm``.
    """
    env = gym.make('Acrobot-v1')
    try:
        # Prepend a leading dimension of size 1 (presumably a batch axis).
        obsv_shape = (1, *env.observation_space.shape)
        action_space = env.action_space
    finally:
        env.close()
    master = MASTERAgent(args, obsv_shape, action_space, selflock=False, device='cpu')
    print('here put master')
    utilslite.process_comm(args, 'master', master)
    print('put master over')
def main():
    """Run one ``putmaster`` task and ``num_process`` ``changemaster`` tasks.

    Sets up the shared state:
    - builds a ``Comms`` bundle of manager-backed values, locks and queues;
    - stores it on ``args.comm`` so every worker receives the same proxies;
    - forces the ``spawn`` start method.

    It then waits on every task result. Any exception raised in a worker is
    re-raised here by ``result.get()``.
    """
    args = get_args()
    num_process = 4
    args.num_processes = num_process
    mp.set_start_method('spawn')
    # Both the manager and the pool are context managers. The pool is
    # terminated and the manager's server process is shut down once every
    # result has been collected, or if one of the tasks raises.
    # (A former ``manager.register('master', nn.Linear)`` call was removed:
    # registering on an already-started manager never reaches its server.)
    with mp.Manager() as manager:
        args.comm = Comms(
            master_step=manager.Value('l', 0),
            master_step_lock=manager.Lock(),
            master_num_steps=manager.Value('l', args.num_steps),
            master_net_lock=manager.Lock(),
            master_queue=manager.Queue(maxsize=args.num_episode_rewards),
            comm_queue=manager.Queue(maxsize=args.num_processes),
            comm_flag=manager.Value('i', -1),
            comm_lock=manager.Lock(),
        )
        # One worker for putmaster plus one per changemaster task.
        with mp.Pool(processes=num_process + 1) as pool:
            results = [pool.apply_async(putmaster, (0, args))]
            results.extend(
                pool.apply_async(changemaster, (i + 1, args))
                for i in range(num_process)
            )
            for result in results:
                result.get()
if __name__ == '__main__':
    # Required with the 'spawn' start method: child processes re-import this
    # module, and the guard stops them from re-running main().
    main()