A PyTorch network on the CPU can be passed between processes through an `mp.Manager().Queue()` (shared memory).

import torch.multiprocessing as mp

from tools.arguments import get_args
from agent import MASTERAgent
from tools import utilslite
import gym
import time
import random
from collections import namedtuple
import torch.nn as nn

# Bundle of Manager-backed shared objects (values, locks, queues) built in main()
# and handed to every worker through ``args.comm``.
Comms = namedtuple(
    'Comms',
    'master_step master_step_lock master_num_steps master_net_lock '
    'master_queue comm_queue comm_flag comm_lock',
)

def changemaster(rank, args):
    """Repeatedly stamp ``rank`` into element [0, 0] of the master's first weight.

    Obtains the master object via ``utilslite.process_comm(args, 'master')``
    (presumably the agent published by ``putmaster`` -- confirm in utilslite)
    and attaches ``args.comm`` to it. Then, forever: while holding
    ``master._net_lock``, it clones the current value of
    ``acnet.feats[0].weight.data[0, 0]``, overwrites that element with ``rank``,
    and records the weight's device. It prints the old and new values and
    sleeps a random interval in [0, 4) seconds so the workers interleave.

    Never returns.
    """
    agent = utilslite.process_comm(args, 'master')
    agent.set_comm(args.comm)
    while True:
        with agent._net_lock:
            weight = agent.acnet.feats[0].weight
            previous = weight.data[0, 0].clone()  # snapshot before overwriting
            weight.data[0, 0] = rank
            where = weight.device
        print("{}: master value: {} ----> {} ({})".format(rank, previous, rank, where))
        time.sleep(4 * random.random())

def putmaster(rank, args):
    """Build a CPU ``MASTERAgent`` sized for Acrobot-v1 and pass it to ``process_comm``.

    Args:
        rank: Task index. It is unused here but kept so every pool task shares
            the ``(rank, args)`` calling convention.
        args: Run configuration. It is forwarded to ``MASTERAgent`` and to
            ``utilslite.process_comm(args, 'master', master)``.
    """
    env = gym.make('Acrobot-v1')
    try:
        # The env is only inspected for its spaces; a leading dim of 1 is
        # prepended to the observation shape for the agent.
        obsv_shape = (1, *env.observation_space.shape)
        action_space = env.action_space
    finally:
        # Never stepped, so release it now instead of leaking it for the
        # lifetime of the worker.
        env.close()
    master = MASTERAgent(args, obsv_shape, action_space, selflock=False, device='cpu')
    print('here put master')
    utilslite.process_comm(args, 'master', master)
    print('put master over')

def main():
    """Run one ``putmaster`` task and ``num_process`` ``changemaster`` tasks in a pool.

    Every Manager-backed shared object is bundled into a ``Comms`` tuple and
    attached as ``args.comm``, so each worker receives it along with ``args``.
    The function then blocks on every task's result. ``changemaster`` loops
    forever, so in normal operation this call does not return. If any task
    raises, the error propagates from ``result.get()``, and the pool and
    manager are torn down on the way out.
    """
    args = get_args()
    num_process = 4
    args.num_processes = num_process
    mp.set_start_method('spawn')
    # No custom type is registered on the manager here. BaseManager.register()
    # only affects managers started *after* the call, so registering on an
    # already-running manager never reaches its server process.
    with mp.Manager() as manager:
        args.comm = Comms(
            master_step=manager.Value('l', 0),
            master_step_lock=manager.Lock(),
            master_num_steps=manager.Value('l', args.num_steps),
            master_net_lock=manager.Lock(),
            master_queue=manager.Queue(maxsize=args.num_episode_rewards),
            comm_queue=manager.Queue(maxsize=args.num_processes),
            comm_flag=manager.Value('i', -1),
            comm_lock=manager.Lock(),
        )
        # Exiting the pool context terminates the (otherwise endless) workers.
        with mp.Pool(processes=num_process + 2) as pool:
            results = [pool.apply_async(putmaster, (0, args))]
            results.extend(
                pool.apply_async(changemaster, (i + 1, args)) for i in range(num_process)
            )
            for result in results:
                result.get()

if __name__ == '__main__':
    # Required with the 'spawn' start method: child processes re-import this
    # module and must not re-run main().
    main()

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值